1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 // - Parallel region merging.
15 // - Transforming generic-mode device kernels to SPMD mode.
16 // - Specializing the state machine for generic-mode device kernels.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/IPO/OpenMPOpt.h"
21 
22 #include "llvm/ADT/EnumeratedArray.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/Statistic.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Analysis/CallGraph.h"
28 #include "llvm/Analysis/CallGraphSCCPass.h"
29 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
30 #include "llvm/Analysis/ValueTracking.h"
31 #include "llvm/Frontend/OpenMP/OMPConstants.h"
32 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
33 #include "llvm/IR/Assumptions.h"
34 #include "llvm/IR/DiagnosticInfo.h"
35 #include "llvm/IR/GlobalValue.h"
36 #include "llvm/IR/Instruction.h"
37 #include "llvm/IR/IntrinsicInst.h"
38 #include "llvm/IR/IntrinsicsAMDGPU.h"
39 #include "llvm/IR/IntrinsicsNVPTX.h"
40 #include "llvm/InitializePasses.h"
41 #include "llvm/Support/CommandLine.h"
42 #include "llvm/Transforms/IPO.h"
43 #include "llvm/Transforms/IPO/Attributor.h"
44 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
45 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
46 #include "llvm/Transforms/Utils/CodeExtractor.h"
47 
48 #include <algorithm>
49 
50 using namespace llvm;
51 using namespace omp;
52 
53 #define DEBUG_TYPE "openmp-opt"
54 
55 static cl::opt<bool> DisableOpenMPOptimizations(
56     "openmp-opt-disable", cl::ZeroOrMore,
57     cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
58     cl::init(false));
59 
60 static cl::opt<bool> EnableParallelRegionMerging(
61     "openmp-opt-enable-merging", cl::ZeroOrMore,
62     cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
63     cl::init(false));
64 
65 static cl::opt<bool>
66     DisableInternalization("openmp-opt-disable-internalization", cl::ZeroOrMore,
67                            cl::desc("Disable function internalization."),
68                            cl::Hidden, cl::init(false));
69 
70 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
71                                     cl::Hidden);
72 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
73                                         cl::init(false), cl::Hidden);
74 
75 static cl::opt<bool> HideMemoryTransferLatency(
76     "openmp-hide-memory-transfer-latency",
77     cl::desc("[WIP] Tries to hide the latency of host to device memory"
78              " transfers"),
79     cl::Hidden, cl::init(false));
80 
81 static cl::opt<bool> DisableOpenMPOptDeglobalization(
82     "openmp-opt-disable-deglobalization", cl::ZeroOrMore,
83     cl::desc("Disable OpenMP optimizations involving deglobalization."),
84     cl::Hidden, cl::init(false));
85 
86 static cl::opt<bool> DisableOpenMPOptSPMDization(
87     "openmp-opt-disable-spmdization", cl::ZeroOrMore,
88     cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
89     cl::Hidden, cl::init(false));
90 
91 static cl::opt<bool> DisableOpenMPOptFolding(
92     "openmp-opt-disable-folding", cl::ZeroOrMore,
93     cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
94     cl::init(false));
95 
96 static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
97     "openmp-opt-disable-state-machine-rewrite", cl::ZeroOrMore,
98     cl::desc("Disable OpenMP optimizations that replace the state machine."),
99     cl::Hidden, cl::init(false));
100 
101 static cl::opt<bool> PrintModuleAfterOptimizations(
102     "openmp-opt-print-module", cl::ZeroOrMore,
103     cl::desc("Print the current module after OpenMP optimizations."),
104     cl::Hidden, cl::init(false));
105 
106 static cl::opt<bool> AlwaysInlineDeviceFunctions(
107     "openmp-opt-inline-device", cl::ZeroOrMore,
    cl::desc("Inline all applicable functions on the device."), cl::Hidden,
109     cl::init(false));
110 
111 static cl::opt<bool>
112     EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::ZeroOrMore,
113                          cl::desc("Enables more verbose remarks."), cl::Hidden,
114                          cl::init(false));
115 
116 static cl::opt<unsigned>
117     SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
118                           cl::desc("Maximal number of attributor iterations."),
119                           cl::init(256));
120 
121 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
122           "Number of OpenMP runtime calls deduplicated");
123 STATISTIC(NumOpenMPParallelRegionsDeleted,
124           "Number of OpenMP parallel regions deleted");
125 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
126           "Number of OpenMP runtime functions identified");
127 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
128           "Number of OpenMP runtime function uses identified");
129 STATISTIC(NumOpenMPTargetRegionKernels,
130           "Number of OpenMP target region entry points (=kernels) identified");
131 STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
132           "Number of OpenMP target region entry points (=kernels) executed in "
133           "SPMD-mode instead of generic-mode");
134 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
135           "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode without a state machine");
137 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
138           "Number of OpenMP target region entry points (=kernels) executed in "
139           "generic-mode with customized state machines with fallback");
140 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
141           "Number of OpenMP target region entry points (=kernels) executed in "
142           "generic-mode with customized state machines without fallback");
143 STATISTIC(
144     NumOpenMPParallelRegionsReplacedInGPUStateMachine,
145     "Number of OpenMP parallel regions replaced with ID in GPU state machines");
146 STATISTIC(NumOpenMPParallelRegionsMerged,
147           "Number of OpenMP parallel regions merged");
148 STATISTIC(NumBytesMovedToSharedMemory,
149           "Amount of memory pushed to shared memory");
150 
151 #if !defined(NDEBUG)
152 static constexpr auto TAG = "[" DEBUG_TYPE "]";
153 #endif
154 
155 namespace {
156 
157 struct AAHeapToShared;
158 
159 struct AAICVTracker;
160 
161 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
162 /// Attributor runs.
163 struct OMPInformationCache : public InformationCache {
164   OMPInformationCache(Module &M, AnalysisGetter &AG,
165                       BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
166                       KernelSet &Kernels)
167       : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
168         Kernels(Kernels) {
169 
170     OMPBuilder.initialize();
171     initializeRuntimeFunctions();
172     initializeInternalControlVars();
173   }
174 
175   /// Generic information that describes an internal control variable.
176   struct InternalControlVarInfo {
177     /// The kind, as described by InternalControlVar enum.
178     InternalControlVar Kind;
179 
180     /// The name of the ICV.
181     StringRef Name;
182 
183     /// Environment variable associated with this ICV.
184     StringRef EnvVarName;
185 
186     /// Initial value kind.
187     ICVInitValue InitKind;
188 
189     /// Initial value.
190     ConstantInt *InitValue;
191 
192     /// Setter RTL function associated with this ICV.
193     RuntimeFunction Setter;
194 
195     /// Getter RTL function associated with this ICV.
196     RuntimeFunction Getter;
197 
198     /// RTL Function corresponding to the override clause of this ICV
199     RuntimeFunction Clause;
200   };
201 
202   /// Generic information that describes a runtime function
203   struct RuntimeFunctionInfo {
204 
205     /// The kind, as described by the RuntimeFunction enum.
206     RuntimeFunction Kind;
207 
208     /// The name of the function.
209     StringRef Name;
210 
211     /// Flag to indicate a variadic function.
212     bool IsVarArg;
213 
214     /// The return type of the function.
215     Type *ReturnType;
216 
217     /// The argument types of the function.
218     SmallVector<Type *, 8> ArgumentTypes;
219 
220     /// The declaration if available.
221     Function *Declaration = nullptr;
222 
223     /// Uses of this runtime function per function containing the use.
224     using UseVector = SmallVector<Use *, 16>;
225 
226     /// Clear UsesMap for runtime function.
227     void clearUsesMap() { UsesMap.clear(); }
228 
229     /// Boolean conversion that is true if the runtime function was found.
230     operator bool() const { return Declaration; }
231 
232     /// Return the vector of uses in function \p F.
233     UseVector &getOrCreateUseVector(Function *F) {
234       std::shared_ptr<UseVector> &UV = UsesMap[F];
235       if (!UV)
236         UV = std::make_shared<UseVector>();
237       return *UV;
238     }
239 
240     /// Return the vector of uses in function \p F or `nullptr` if there are
241     /// none.
242     const UseVector *getUseVector(Function &F) const {
243       auto I = UsesMap.find(&F);
244       if (I != UsesMap.end())
245         return I->second.get();
246       return nullptr;
247     }
248 
249     /// Return how many functions contain uses of this runtime function.
250     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
251 
252     /// Return the number of arguments (or the minimal number for variadic
253     /// functions).
254     size_t getNumArgs() const { return ArgumentTypes.size(); }
255 
256     /// Run the callback \p CB on each use and forget the use if the result is
257     /// true. The callback will be fed the function in which the use was
258     /// encountered as second argument.
259     void foreachUse(SmallVectorImpl<Function *> &SCC,
260                     function_ref<bool(Use &, Function &)> CB) {
261       for (Function *F : SCC)
262         foreachUse(CB, F);
263     }
264 
265     /// Run the callback \p CB on each use within the function \p F and forget
266     /// the use if the result is true.
267     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
268       SmallVector<unsigned, 8> ToBeDeleted;
269       ToBeDeleted.clear();
270 
271       unsigned Idx = 0;
272       UseVector &UV = getOrCreateUseVector(F);
273 
274       for (Use *U : UV) {
275         if (CB(*U, *F))
276           ToBeDeleted.push_back(Idx);
277         ++Idx;
278       }
279 
      // Remove the to-be-deleted indices in reverse order; replacing an
      // element with the last one and popping it does not shift the smaller
      // indices that still have to be removed.
282       while (!ToBeDeleted.empty()) {
283         unsigned Idx = ToBeDeleted.pop_back_val();
284         UV[Idx] = UV.back();
285         UV.pop_back();
286       }
287     }
288 
289   private:
290     /// Map from functions to all uses of this runtime function contained in
291     /// them.
292     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
293 
294   public:
295     /// Iterators for the uses of this runtime function.
296     decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
297     decltype(UsesMap)::iterator end() { return UsesMap.end(); }
298   };
299 
300   /// An OpenMP-IR-Builder instance
301   OpenMPIRBuilder OMPBuilder;
302 
303   /// Map from runtime function kind to the runtime function description.
304   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
305                   RuntimeFunction::OMPRTL___last>
306       RFIs;
307 
308   /// Map from function declarations/definitions to their runtime enum type.
309   DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
310 
311   /// Map from ICV kind to the ICV description.
312   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
313                   InternalControlVar::ICV___last>
314       ICVs;
315 
316   /// Helper to initialize all internal control variable information for those
317   /// defined in OMPKinds.def.
318   void initializeInternalControlVars() {
319 #define ICV_RT_SET(_Name, RTL)                                                 \
320   {                                                                            \
321     auto &ICV = ICVs[_Name];                                                   \
322     ICV.Setter = RTL;                                                          \
323   }
324 #define ICV_RT_GET(Name, RTL)                                                  \
325   {                                                                            \
326     auto &ICV = ICVs[Name];                                                    \
327     ICV.Getter = RTL;                                                          \
328   }
329 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
330   {                                                                            \
331     auto &ICV = ICVs[Enum];                                                    \
332     ICV.Name = _Name;                                                          \
333     ICV.Kind = Enum;                                                           \
334     ICV.InitKind = Init;                                                       \
335     ICV.EnvVarName = _EnvVarName;                                              \
336     switch (ICV.InitKind) {                                                    \
337     case ICV_IMPLEMENTATION_DEFINED:                                           \
338       ICV.InitValue = nullptr;                                                 \
339       break;                                                                   \
340     case ICV_ZERO:                                                             \
341       ICV.InitValue = ConstantInt::get(                                        \
342           Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
343       break;                                                                   \
344     case ICV_FALSE:                                                            \
345       ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
346       break;                                                                   \
347     case ICV_LAST:                                                             \
348       break;                                                                   \
349     }                                                                          \
350   }
351 #include "llvm/Frontend/OpenMP/OMPKinds.def"
352   }
353 
354   /// Returns true if the function declaration \p F matches the runtime
355   /// function types, that is, return type \p RTFRetType, and argument types
356   /// \p RTFArgTypes.
357   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
358                                   SmallVector<Type *, 8> &RTFArgTypes) {
359     // TODO: We should output information to the user (under debug output
360     //       and via remarks).
361 
362     if (!F)
363       return false;
364     if (F->getReturnType() != RTFRetType)
365       return false;
366     if (F->arg_size() != RTFArgTypes.size())
367       return false;
368 
369     auto *RTFTyIt = RTFArgTypes.begin();
370     for (Argument &Arg : F->args()) {
371       if (Arg.getType() != *RTFTyIt)
372         return false;
373 
374       ++RTFTyIt;
375     }
376 
377     return true;
378   }
379 
380   // Helper to collect all uses of the declaration in the UsesMap.
381   unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
382     unsigned NumUses = 0;
383     if (!RFI.Declaration)
384       return NumUses;
385     OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
386 
387     if (CollectStats) {
388       NumOpenMPRuntimeFunctionsIdentified += 1;
389       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
390     }
391 
392     // TODO: We directly convert uses into proper calls and unknown uses.
393     for (Use &U : RFI.Declaration->uses()) {
394       if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
395         if (ModuleSlice.count(UserI->getFunction())) {
396           RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
397           ++NumUses;
398         }
399       } else {
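        // The user is not an instruction, e.g., the use is in a constant
        // expression or a global initializer; collect it under the nullptr
        // key.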
400         RFI.getOrCreateUseVector(nullptr).push_back(&U);
401         ++NumUses;
402       }
403     }
404     return NumUses;
405   }
406 
407   // Helper function to recollect uses of a runtime function.
408   void recollectUsesForFunction(RuntimeFunction RTF) {
409     auto &RFI = RFIs[RTF];
410     RFI.clearUsesMap();
411     collectUses(RFI, /*CollectStats*/ false);
412   }
413 
414   // Helper function to recollect uses of all runtime functions.
415   void recollectUses() {
416     for (int Idx = 0; Idx < RFIs.size(); ++Idx)
417       recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
418   }
419 
420   // Helper function to inherit the calling convention of the function callee.
421   void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
422     if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
423       CI->setCallingConv(Fn->getCallingConv());
424   }
425 
426   /// Helper to initialize all runtime function information for those defined
  /// in OMPKinds.def.
428   void initializeRuntimeFunctions() {
429     Module &M = *((*ModuleSlice.begin())->getParent());
430 
431     // Helper macros for handling __VA_ARGS__ in OMP_RTL
432 #define OMP_TYPE(VarName, ...)                                                 \
433   Type *VarName = OMPBuilder.VarName;                                          \
434   (void)VarName;
435 
436 #define OMP_ARRAY_TYPE(VarName, ...)                                           \
437   ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
438   (void)VarName##Ty;                                                           \
439   PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
440   (void)VarName##PtrTy;
441 
442 #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
443   FunctionType *VarName = OMPBuilder.VarName;                                  \
444   (void)VarName;                                                               \
445   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
446   (void)VarName##Ptr;
447 
448 #define OMP_STRUCT_TYPE(VarName, ...)                                          \
449   StructType *VarName = OMPBuilder.VarName;                                    \
450   (void)VarName;                                                               \
451   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
452   (void)VarName##Ptr;
453 
454 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
455   {                                                                            \
456     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
457     Function *F = M.getFunction(_Name);                                        \
458     RTLFunctions.insert(F);                                                    \
459     if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
460       RuntimeFunctionIDMap[F] = _Enum;                                         \
461       F->removeFnAttr(Attribute::NoInline);                                    \
462       auto &RFI = RFIs[_Enum];                                                 \
463       RFI.Kind = _Enum;                                                        \
464       RFI.Name = _Name;                                                        \
465       RFI.IsVarArg = _IsVarArg;                                                \
466       RFI.ReturnType = OMPBuilder._ReturnType;                                 \
467       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
468       RFI.Declaration = F;                                                     \
469       unsigned NumUses = collectUses(RFI);                                     \
470       (void)NumUses;                                                           \
471       LLVM_DEBUG({                                                             \
472         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
473                << " found\n";                                                  \
474         if (RFI.Declaration)                                                   \
475           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
476                  << RFI.getNumFunctionsWithUses()                              \
477                  << " different functions.\n";                                 \
478       });                                                                      \
479     }                                                                          \
480   }
481 #include "llvm/Frontend/OpenMP/OMPKinds.def"
482 
483     // TODO: We should attach the attributes defined in OMPKinds.def.
484   }
485 
486   /// Collection of known kernels (\see Kernel) in the module.
487   KernelSet &Kernels;
488 
  /// Collection of known OpenMP runtime functions.
490   DenseSet<const Function *> RTLFunctions;
491 };
492 
493 template <typename Ty, bool InsertInvalidates = true>
494 struct BooleanStateWithSetVector : public BooleanState {
495   bool contains(const Ty &Elem) const { return Set.contains(Elem); }
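  /// Insert \p Elem into the set. If InsertInvalidates is set, this also moves
  /// the boolean state to a pessimistic fixpoint.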
496   bool insert(const Ty &Elem) {
497     if (InsertInvalidates)
498       BooleanState::indicatePessimisticFixpoint();
499     return Set.insert(Elem);
500   }
501 
502   const Ty &operator[](int Idx) const { return Set[Idx]; }
503   bool operator==(const BooleanStateWithSetVector &RHS) const {
504     return BooleanState::operator==(RHS) && Set == RHS.Set;
505   }
506   bool operator!=(const BooleanStateWithSetVector &RHS) const {
507     return !(*this == RHS);
508   }
509 
510   bool empty() const { return Set.empty(); }
511   size_t size() const { return Set.size(); }
512 
513   /// "Clamp" this state with \p RHS.
514   BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
515     BooleanState::operator^=(RHS);
516     Set.insert(RHS.Set.begin(), RHS.Set.end());
517     return *this;
518   }
519 
520 private:
521   /// A set to keep track of elements.
522   SetVector<Ty> Set;
523 
524 public:
525   typename decltype(Set)::iterator begin() { return Set.begin(); }
526   typename decltype(Set)::iterator end() { return Set.end(); }
527   typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
528   typename decltype(Set)::const_iterator end() const { return Set.end(); }
529 };
530 
531 template <typename Ty, bool InsertInvalidates = true>
532 using BooleanStateWithPtrSetVector =
533     BooleanStateWithSetVector<Ty *, InsertInvalidates>;
534 
535 struct KernelInfoState : AbstractState {
536   /// Flag to track if we reached a fixpoint.
537   bool IsAtFixpoint = false;
538 
539   /// The parallel regions (identified by the outlined parallel functions) that
540   /// can be reached from the associated function.
541   BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
542       ReachedKnownParallelRegions;
543 
544   /// State to track what parallel region we might reach.
545   BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
546 
  /// State to track if we are in SPMD-mode, assumed or known, and why we
  /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
  /// also be false.
550   BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
551 
552   /// The __kmpc_target_init call in this kernel, if any. If we find more than
553   /// one we abort as the kernel is malformed.
554   CallBase *KernelInitCB = nullptr;
555 
556   /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
557   /// one we abort as the kernel is malformed.
558   CallBase *KernelDeinitCB = nullptr;
559 
560   /// Flag to indicate if the associated function is a kernel entry.
561   bool IsKernelEntry = false;
562 
563   /// State to track what kernel entries can reach the associated function.
564   BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
565 
566   /// State to indicate if we can track parallel level of the associated
567   /// function. We will give up tracking if we encounter unknown caller or the
568   /// caller is __kmpc_parallel_51.
569   BooleanStateWithSetVector<uint8_t> ParallelLevels;
570 
571   /// Abstract State interface
572   ///{
573 
574   KernelInfoState() {}
575   KernelInfoState(bool BestState) {
576     if (!BestState)
577       indicatePessimisticFixpoint();
578   }
579 
580   /// See AbstractState::isValidState(...)
581   bool isValidState() const override { return true; }
582 
583   /// See AbstractState::isAtFixpoint(...)
584   bool isAtFixpoint() const override { return IsAtFixpoint; }
585 
586   /// See AbstractState::indicatePessimisticFixpoint(...)
587   ChangeStatus indicatePessimisticFixpoint() override {
588     IsAtFixpoint = true;
589     ReachingKernelEntries.indicatePessimisticFixpoint();
590     SPMDCompatibilityTracker.indicatePessimisticFixpoint();
591     ReachedKnownParallelRegions.indicatePessimisticFixpoint();
592     ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
593     return ChangeStatus::CHANGED;
594   }
595 
596   /// See AbstractState::indicateOptimisticFixpoint(...)
597   ChangeStatus indicateOptimisticFixpoint() override {
598     IsAtFixpoint = true;
599     ReachingKernelEntries.indicateOptimisticFixpoint();
600     SPMDCompatibilityTracker.indicateOptimisticFixpoint();
601     ReachedKnownParallelRegions.indicateOptimisticFixpoint();
602     ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
603     return ChangeStatus::UNCHANGED;
604   }
605 
606   /// Return the assumed state
607   KernelInfoState &getAssumed() { return *this; }
608   const KernelInfoState &getAssumed() const { return *this; }
609 
610   bool operator==(const KernelInfoState &RHS) const {
611     if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
612       return false;
613     if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
614       return false;
615     if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
616       return false;
617     if (ReachingKernelEntries != RHS.ReachingKernelEntries)
618       return false;
619     return true;
620   }
621 
  /// Returns true if this kernel may contain any OpenMP parallel regions.
623   bool mayContainParallelRegion() {
624     return !ReachedKnownParallelRegions.empty() ||
625            !ReachedUnknownParallelRegions.empty();
626   }
627 
  /// Return the best possible (fully optimistic) state.
629   static KernelInfoState getBestState() { return KernelInfoState(true); }
630 
631   static KernelInfoState getBestState(KernelInfoState &KIS) {
632     return getBestState();
633   }
634 
  /// Return the worst possible (pessimistic fixpoint) state.
636   static KernelInfoState getWorstState() { return KernelInfoState(false); }
637 
638   /// "Clamp" this state with \p KIS.
639   KernelInfoState operator^=(const KernelInfoState &KIS) {
640     // Do not merge two different _init and _deinit call sites.
641     if (KIS.KernelInitCB) {
642       if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
643         llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
644                          "assumptions.");
645       KernelInitCB = KIS.KernelInitCB;
646     }
647     if (KIS.KernelDeinitCB) {
648       if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
649         llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
650                          "assumptions.");
651       KernelDeinitCB = KIS.KernelDeinitCB;
652     }
653     SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
654     ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
655     ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
656     return *this;
657   }
658 
659   KernelInfoState operator&=(const KernelInfoState &KIS) {
660     return (*this ^= KIS);
661   }
662 
663   ///}
664 };
665 
666 /// Used to map the values physically (in the IR) stored in an offload
/// array to a vector in memory.
668 struct OffloadArray {
669   /// Physical array (in the IR).
670   AllocaInst *Array = nullptr;
671   /// Mapped values.
672   SmallVector<Value *, 8> StoredValues;
673   /// Last stores made in the offload array.
674   SmallVector<StoreInst *, 8> LastAccesses;
675 
676   OffloadArray() = default;
677 
678   /// Initializes the OffloadArray with the values stored in \p Array before
679   /// instruction \p Before is reached. Returns false if the initialization
680   /// fails.
681   /// This MUST be used immediately after the construction of the object.
682   bool initialize(AllocaInst &Array, Instruction &Before) {
683     if (!Array.getAllocatedType()->isArrayTy())
684       return false;
685 
686     if (!getValues(Array, Before))
687       return false;
688 
689     this->Array = &Array;
690     return true;
691   }
692 
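  /// Argument positions of the device ID and the offload arrays (base
  /// pointers, pointers, sizes) in a __tgt_target_data_begin_mapper call.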
693   static const unsigned DeviceIDArgNum = 1;
694   static const unsigned BasePtrsArgNum = 3;
695   static const unsigned PtrsArgNum = 4;
696   static const unsigned SizesArgNum = 5;
697 
698 private:
699   /// Traverses the BasicBlock where \p Array is, collecting the stores made to
700   /// \p Array, leaving StoredValues with the values stored before the
701   /// instruction \p Before is reached.
702   bool getValues(AllocaInst &Array, Instruction &Before) {
703     // Initialize container.
704     const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
705     StoredValues.assign(NumValues, nullptr);
706     LastAccesses.assign(NumValues, nullptr);
707 
708     // TODO: This assumes the instruction \p Before is in the same
709     //  BasicBlock as Array. Make it general, for any control flow graph.
710     BasicBlock *BB = Array.getParent();
711     if (BB != Before.getParent())
712       return false;
713 
714     const DataLayout &DL = Array.getModule()->getDataLayout();
715     const unsigned int PointerSize = DL.getPointerSize();
716 
717     for (Instruction &I : *BB) {
718       if (&I == &Before)
719         break;
720 
721       if (!isa<StoreInst>(&I))
722         continue;
723 
724       auto *S = cast<StoreInst>(&I);
725       int64_t Offset = -1;
726       auto *Dst =
727           GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
728       if (Dst == &Array) {
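        // The offload array holds pointer-sized elements, so the constant byte
        // offset from the array base directly yields the element index.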
729         int64_t Idx = Offset / PointerSize;
730         StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
731         LastAccesses[Idx] = S;
732       }
733     }
734 
735     return isFilled();
736   }
737 
738   /// Returns true if all values in StoredValues and
739   /// LastAccesses are not nullptrs.
740   bool isFilled() {
741     const unsigned NumValues = StoredValues.size();
742     for (unsigned I = 0; I < NumValues; ++I) {
743       if (!StoredValues[I] || !LastAccesses[I])
744         return false;
745     }
746 
747     return true;
748   }
749 };
750 
751 struct OpenMPOpt {
752 
753   using OptimizationRemarkGetter =
754       function_ref<OptimizationRemarkEmitter &(Function *)>;
755 
756   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
757             OptimizationRemarkGetter OREGetter,
758             OMPInformationCache &OMPInfoCache, Attributor &A)
759       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
760         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
761 
762   /// Check if any remarks are enabled for openmp-opt
763   bool remarksEnabled() {
764     auto &Ctx = M.getContext();
765     return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
766   }
767 
768   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
769   bool run(bool IsModulePass) {
770     if (SCC.empty())
771       return false;
772 
773     bool Changed = false;
774 
775     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
776                       << " functions in a slice with "
777                       << OMPInfoCache.ModuleSlice.size() << " functions\n");
778 
779     if (IsModulePass) {
780       Changed |= runAttributor(IsModulePass);
781 
782       // Recollect uses, in case Attributor deleted any.
783       OMPInfoCache.recollectUses();
784 
785       // TODO: This should be folded into buildCustomStateMachine.
786       Changed |= rewriteDeviceCodeStateMachine();
787 
788       if (remarksEnabled())
789         analysisGlobalization();
790     } else {
791       if (PrintICVValues)
792         printICVs();
793       if (PrintOpenMPKernels)
794         printKernels();
795 
796       Changed |= runAttributor(IsModulePass);
797 
798       // Recollect uses, in case Attributor deleted any.
799       OMPInfoCache.recollectUses();
800 
801       Changed |= deleteParallelRegions();
802 
803       if (HideMemoryTransferLatency)
804         Changed |= hideMemTransfersLatency();
805       Changed |= deduplicateRuntimeCalls();
806       if (EnableParallelRegionMerging) {
807         if (mergeParallelRegions()) {
808           deduplicateRuntimeCalls();
809           Changed = true;
810         }
811       }
812     }
813 
814     return Changed;
815   }
816 
817   /// Print initial ICV values for testing.
818   /// FIXME: This should be done from the Attributor once it is added.
819   void printICVs() const {
820     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
821                                  ICV_proc_bind};
822 
823     for (Function *F : OMPInfoCache.ModuleSlice) {
824       for (auto ICV : ICVs) {
825         auto ICVInfo = OMPInfoCache.ICVs[ICV];
826         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
827           return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
828                      << " Value: "
829                      << (ICVInfo.InitValue
830                              ? toString(ICVInfo.InitValue->getValue(), 10, true)
831                              : "IMPLEMENTATION_DEFINED");
832         };
833 
834         emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
835       }
836     }
837   }
838 
839   /// Print OpenMP GPU kernels for testing.
840   void printKernels() const {
841     for (Function *F : SCC) {
842       if (!OMPInfoCache.Kernels.count(F))
843         continue;
844 
845       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
846         return ORA << "OpenMP GPU kernel "
847                    << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
848       };
849 
850       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
851     }
852   }
853 
854   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
  /// given, it has to be the callee or nullptr is returned.
856   static CallInst *getCallIfRegularCall(
857       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
858     CallInst *CI = dyn_cast<CallInst>(U.getUser());
859     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
860         (!RFI ||
861          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
862       return CI;
863     return nullptr;
864   }
865 
  /// Return the call if \p V is a regular call. If \p RFI is given, it has to
  /// be the callee or nullptr is returned.
868   static CallInst *getCallIfRegularCall(
869       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
870     CallInst *CI = dyn_cast<CallInst>(&V);
871     if (CI && !CI->hasOperandBundles() &&
872         (!RFI ||
873          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
874       return CI;
875     return nullptr;
876   }
877 
878 private:
879   /// Merge parallel regions when it is safe.
880   bool mergeParallelRegions() {
881     const unsigned CallbackCalleeOperand = 2;
882     const unsigned CallbackFirstArgOperand = 3;
883     using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
884 
885     // Check if there are any __kmpc_fork_call calls to merge.
886     OMPInformationCache::RuntimeFunctionInfo &RFI =
887         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
888 
889     if (!RFI.Declaration)
890       return false;
891 
892     // Unmergable calls that prevent merging a parallel region.
893     OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
894         OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
895         OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
896     };
897 
898     bool Changed = false;
899     LoopInfo *LI = nullptr;
900     DominatorTree *DT = nullptr;
901 
902     SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
903 
904     BasicBlock *StartBB = nullptr, *EndBB = nullptr;
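    // Body generation callback for the merged parallel region: hook the
    // previously isolated region [StartBB, EndBB] into the newly emitted code.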
905     auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
906                          BasicBlock &ContinuationIP) {
907       BasicBlock *CGStartBB = CodeGenIP.getBlock();
908       BasicBlock *CGEndBB =
909           SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
910       assert(StartBB != nullptr && "StartBB should not be null");
911       CGStartBB->getTerminator()->setSuccessor(0, StartBB);
912       assert(EndBB != nullptr && "EndBB should not be null");
913       EndBB->getTerminator()->setSuccessor(0, CGEndBB);
914     };
915 
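    // Privatization callback: values are used as-is, no copies are created for
    // the merged region.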
916     auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
917                       Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
918       ReplacementValue = &Inner;
919       return CodeGenIP;
920     };
921 
922     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
923 
924     /// Create a sequential execution region within a merged parallel region,
925     /// encapsulated in a master construct with a barrier for synchronization.
926     auto CreateSequentialRegion = [&](Function *OuterFn,
927                                       BasicBlock *OuterPredBB,
928                                       Instruction *SeqStartI,
929                                       Instruction *SeqEndI) {
930       // Isolate the instructions of the sequential region to a separate
931       // block.
932       BasicBlock *ParentBB = SeqStartI->getParent();
933       BasicBlock *SeqEndBB =
934           SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
935       BasicBlock *SeqAfterBB =
936           SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
937       BasicBlock *SeqStartBB =
938           SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
939 
940       assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
941              "Expected a different CFG");
942       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
943       ParentBB->getTerminator()->eraseFromParent();
944 
945       auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
946                            BasicBlock &ContinuationIP) {
947         BasicBlock *CGStartBB = CodeGenIP.getBlock();
948         BasicBlock *CGEndBB =
949             SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
950         assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
951         CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
952         assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
953         SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
954       };
955       auto FiniCB = [&](InsertPointTy CodeGenIP) {};
956 
957       // Find outputs from the sequential region to outside users and
958       // broadcast their values to them.
959       for (Instruction &I : *SeqStartBB) {
960         SmallPtrSet<Instruction *, 4> OutsideUsers;
961         for (User *Usr : I.users()) {
962           Instruction &UsrI = *cast<Instruction>(Usr);
          // Ignore outputs to lifetime intrinsics; code extraction for the merged
964           // parallel region will fix them.
965           if (UsrI.isLifetimeStartOrEnd())
966             continue;
967 
968           if (UsrI.getParent() != SeqStartBB)
969             OutsideUsers.insert(&UsrI);
970         }
971 
972         if (OutsideUsers.empty())
973           continue;
974 
975         // Emit an alloca in the outer region to store the broadcasted
976         // value.
977         const DataLayout &DL = M.getDataLayout();
978         AllocaInst *AllocaI = new AllocaInst(
979             I.getType(), DL.getAllocaAddrSpace(), nullptr,
980             I.getName() + ".seq.output.alloc", &OuterFn->front().front());
981 
982         // Emit a store instruction in the sequential BB to update the
983         // value.
984         new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
985 
986         // Emit a load instruction and replace the use of the output value
987         // with it.
988         for (Instruction *UsrI : OutsideUsers) {
989           LoadInst *LoadI = new LoadInst(
990               I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
991           UsrI->replaceUsesOfWith(&I, LoadI);
992         }
993       }
994 
995       OpenMPIRBuilder::LocationDescription Loc(
996           InsertPointTy(ParentBB, ParentBB->end()), DL);
997       InsertPointTy SeqAfterIP =
998           OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
999 
1000       OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
1001 
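      // Branch from the end of the emitted master/barrier region to the code
      // that follows the original sequential region.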
1002       BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1003 
1004       LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1005                         << "\n");
1006     };
1007 
1008     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1009     // contained in BB and only separated by instructions that can be
1010     // redundantly executed in parallel. The block BB is split before the first
1011     // call (in MergableCIs) and after the last so the entire region we merge
1012     // into a single parallel region is contained in a single basic block
1013     // without any other instructions. We use the OpenMPIRBuilder to outline
1014     // that block and call the resulting function via __kmpc_fork_call.
1015     auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1016                      BasicBlock *BB) {
      // TODO: Change the interface to allow single CIs expanded, e.g., to
1018       // include an outer loop.
1019       assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1020 
1021       auto Remark = [&](OptimizationRemark OR) {
1022         OR << "Parallel region merged with parallel region"
1023            << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1024         for (auto *CI : llvm::drop_begin(MergableCIs)) {
1025           OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1026           if (CI != MergableCIs.back())
1027             OR << ", ";
1028         }
1029         return OR << ".";
1030       };
1031 
1032       emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1033 
1034       Function *OriginalFn = BB->getParent();
1035       LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1036                         << " parallel regions in " << OriginalFn->getName()
1037                         << "\n");
1038 
1039       // Isolate the calls to merge in a separate block.
1040       EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1041       BasicBlock *AfterBB =
1042           SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1043       StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1044                            "omp.par.merged");
1045 
1046       assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1047       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1048       BB->getTerminator()->eraseFromParent();
1049 
1050       // Create sequential regions for sequential instructions that are
1051       // in-between mergable parallel regions.
1052       for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1053            It != End; ++It) {
1054         Instruction *ForkCI = *It;
1055         Instruction *NextForkCI = *(It + 1);
1056 
        // Continue if there are no in-between instructions.
1058         if (ForkCI->getNextNode() == NextForkCI)
1059           continue;
1060 
1061         CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1062                                NextForkCI->getPrevNode());
1063       }
1064 
1065       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1066                                                DL);
1067       IRBuilder<>::InsertPoint AllocaIP(
1068           &OriginalFn->getEntryBlock(),
1069           OriginalFn->getEntryBlock().getFirstInsertionPt());
1070       // Create the merged parallel region with default proc binding, to
1071       // avoid overriding binding settings, and without explicit cancellation.
1072       InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1073           Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1074           OMP_PROC_BIND_default, /* IsCancellable */ false);
1075       BranchInst::Create(AfterBB, AfterIP.getBlock());
1076 
1077       // Perform the actual outlining.
1078       OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1079 
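      // After outlining, the mergable calls live inside the outlined function,
      // so their caller is the callback generated for the merged region.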
1080       Function *OutlinedFn = MergableCIs.front()->getCaller();
1081 
1082       // Replace the __kmpc_fork_call calls with direct calls to the outlined
1083       // callbacks.
1084       SmallVector<Value *, 8> Args;
1085       for (auto *CI : MergableCIs) {
1086         Value *Callee =
1087             CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
1088         FunctionType *FT =
1089             cast<FunctionType>(Callee->getType()->getPointerElementType());
1090         Args.clear();
1091         Args.push_back(OutlinedFn->getArg(0));
1092         Args.push_back(OutlinedFn->getArg(1));
1093         for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1094              ++U)
1095           Args.push_back(CI->getArgOperand(U));
1096 
1097         CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1098         if (CI->getDebugLoc())
1099           NewCI->setDebugLoc(CI->getDebugLoc());
1100 
1101         // Forward parameter attributes from the callback to the callee.
1102         for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1103              ++U)
1104           for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1105             NewCI->addParamAttr(
1106                 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1107 
1108         // Emit an explicit barrier to replace the implicit fork-join barrier.
1109         if (CI != MergableCIs.back()) {
1110           // TODO: Remove barrier if the merged parallel region includes the
1111           // 'nowait' clause.
1112           OMPInfoCache.OMPBuilder.createBarrier(
1113               InsertPointTy(NewCI->getParent(),
1114                             NewCI->getNextNode()->getIterator()),
1115               OMPD_parallel);
1116         }
1117 
1118         CI->eraseFromParent();
1119       }
1120 
1121       assert(OutlinedFn != OriginalFn && "Outlining failed");
1122       CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1123       CGUpdater.reanalyzeFunction(*OriginalFn);
1124 
1125       NumOpenMPParallelRegionsMerged += MergableCIs.size();
1126 
1127       return true;
1128     };
1129 
    // Helper function that identifies sequences of
1131     // __kmpc_fork_call uses in a basic block.
1132     auto DetectPRsCB = [&](Use &U, Function &F) {
1133       CallInst *CI = getCallIfRegularCall(U, &RFI);
1134       BB2PRMap[CI->getParent()].insert(CI);
1135 
1136       return false;
1137     };
1138 
1139     BB2PRMap.clear();
1140     RFI.foreachUse(SCC, DetectPRsCB);
1141     SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1142     // Find mergable parallel regions within a basic block that are
1143     // safe to merge, that is any in-between instructions can safely
1144     // execute in parallel after merging.
1145     // TODO: support merging across basic-blocks.
1146     for (auto &It : BB2PRMap) {
1147       auto &CIs = It.getSecond();
1148       if (CIs.size() < 2)
1149         continue;
1150 
1151       BasicBlock *BB = It.getFirst();
1152       SmallVector<CallInst *, 4> MergableCIs;
1153 
1154       /// Returns true if the instruction is mergable, false otherwise.
1155       /// A terminator instruction is unmergable by definition since merging
1156       /// works within a BB. Instructions before the mergable region are
1157       /// mergable if they are not calls to OpenMP runtime functions that may
1158       /// set different execution parameters for subsequent parallel regions.
1159       /// Instructions in-between parallel regions are mergable if they are not
1160       /// calls to any non-intrinsic function since that may call a non-mergable
1161       /// OpenMP runtime function.
1162       auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1163         // We do not merge across BBs, hence return false (unmergable) if the
1164         // instruction is a terminator.
1165         if (I.isTerminator())
1166           return false;
1167 
1168         if (!isa<CallInst>(&I))
1169           return true;
1170 
1171         CallInst *CI = cast<CallInst>(&I);
1172         if (IsBeforeMergableRegion) {
1173           Function *CalledFunction = CI->getCalledFunction();
1174           if (!CalledFunction)
1175             return false;
1176           // Return false (unmergable) if the call before the parallel
1177           // region calls an explicit affinity (proc_bind) or number of
1178           // threads (num_threads) compiler-generated function. Those settings
1179           // may be incompatible with following parallel regions.
1180           // TODO: ICV tracking to detect compatibility.
1181           for (const auto &RFI : UnmergableCallsInfo) {
1182             if (CalledFunction == RFI.Declaration)
1183               return false;
1184           }
1185         } else {
1186           // Return false (unmergable) if there is a call instruction
1187           // in-between parallel regions when it is not an intrinsic. It
1188           // may call an unmergable OpenMP runtime function in its callpath.
1189           // TODO: Keep track of possible OpenMP calls in the callpath.
1190           if (!isa<IntrinsicInst>(CI))
1191             return false;
1192         }
1193 
1194         return true;
1195       };
1196       // Find maximal number of parallel region CIs that are safe to merge.
1197       for (auto It = BB->begin(), End = BB->end(); It != End;) {
1198         Instruction &I = *It;
1199         ++It;
1200 
1201         if (CIs.count(&I)) {
1202           MergableCIs.push_back(cast<CallInst>(&I));
1203           continue;
1204         }
1205 
1206         // Continue expanding if the instruction is mergable.
1207         if (IsMergable(I, MergableCIs.empty()))
1208           continue;
1209 
1210         // Forward the instruction iterator to skip the next parallel region
1211         // since there is an unmergable instruction which can affect it.
1212         for (; It != End; ++It) {
1213           Instruction &SkipI = *It;
1214           if (CIs.count(&SkipI)) {
1215             LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1216                               << " due to " << I << "\n");
1217             ++It;
1218             break;
1219           }
1220         }
1221 
1222         // Store mergable regions found.
1223         if (MergableCIs.size() > 1) {
1224           MergableCIsVector.push_back(MergableCIs);
1225           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1226                             << " parallel regions in block " << BB->getName()
1227                             << " of function " << BB->getParent()->getName()
1228                             << "\n";);
1229         }
1230 
1231         MergableCIs.clear();
1232       }
1233 
1234       if (!MergableCIsVector.empty()) {
1235         Changed = true;
1236 
1237         for (auto &MergableCIs : MergableCIsVector)
1238           Merge(MergableCIs, BB);
1239         MergableCIsVector.clear();
1240       }
1241     }
1242 
1243     if (Changed) {
      /// Re-collect uses for fork calls, emitted barrier calls, and
1245       /// any emitted master/end_master calls.
1246       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1247       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1248       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1249       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1250     }
1251 
1252     return Changed;
1253   }
1254 
1255   /// Try to delete parallel regions if possible.
1256   bool deleteParallelRegions() {
1257     const unsigned CallbackCalleeOperand = 2;
1258 
1259     OMPInformationCache::RuntimeFunctionInfo &RFI =
1260         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1261 
1262     if (!RFI.Declaration)
1263       return false;
1264 
1265     bool Changed = false;
1266     auto DeleteCallCB = [&](Use &U, Function &) {
1267       CallInst *CI = getCallIfRegularCall(U);
1268       if (!CI)
1269         return false;
1270       auto *Fn = dyn_cast<Function>(
1271           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1272       if (!Fn)
1273         return false;
1274       if (!Fn->onlyReadsMemory())
1275         return false;
1276       if (!Fn->hasFnAttribute(Attribute::WillReturn))
1277         return false;
1278 
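      // The callback only reads memory and is guaranteed to return, so running
      // the parallel region has no observable effect; the call can be removed.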
1279       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1280                         << CI->getCaller()->getName() << "\n");
1281 
1282       auto Remark = [&](OptimizationRemark OR) {
1283         return OR << "Removing parallel region with no side-effects.";
1284       };
1285       emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1286 
1287       CGUpdater.removeCallSite(*CI);
1288       CI->eraseFromParent();
1289       Changed = true;
1290       ++NumOpenMPParallelRegionsDeleted;
1291       return true;
1292     };
1293 
1294     RFI.foreachUse(SCC, DeleteCallCB);
1295 
1296     return Changed;
1297   }
1298 
1299   /// Try to eliminate runtime calls by reusing existing ones.
1300   bool deduplicateRuntimeCalls() {
1301     bool Changed = false;
1302 
1303     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1304         OMPRTL_omp_get_num_threads,
1305         OMPRTL_omp_in_parallel,
1306         OMPRTL_omp_get_cancellation,
1307         OMPRTL_omp_get_thread_limit,
1308         OMPRTL_omp_get_supported_active_levels,
1309         OMPRTL_omp_get_level,
1310         OMPRTL_omp_get_ancestor_thread_num,
1311         OMPRTL_omp_get_team_size,
1312         OMPRTL_omp_get_active_level,
1313         OMPRTL_omp_in_final,
1314         OMPRTL_omp_get_proc_bind,
1315         OMPRTL_omp_get_num_places,
1316         OMPRTL_omp_get_num_procs,
1317         OMPRTL_omp_get_place_num,
1318         OMPRTL_omp_get_partition_num_places,
1319         OMPRTL_omp_get_partition_place_nums};
1320 
1321     // Global-tid is handled separately.
1322     SmallSetVector<Value *, 16> GTIdArgs;
1323     collectGlobalThreadIdArguments(GTIdArgs);
1324     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1325                       << " global thread ID arguments\n");
1326 
1327     for (Function *F : SCC) {
1328       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1329         Changed |= deduplicateRuntimeCalls(
1330             *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1331 
1332       // __kmpc_global_thread_num is special as we can replace it with an
1333       // argument in enough cases to make it worth trying.
1334       Value *GTIdArg = nullptr;
1335       for (Argument &Arg : F->args())
1336         if (GTIdArgs.count(&Arg)) {
1337           GTIdArg = &Arg;
1338           break;
1339         }
1340       Changed |= deduplicateRuntimeCalls(
1341           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1342     }
1343 
1344     return Changed;
1345   }
1346 
1347   /// Tries to hide the latency of runtime calls that involve host to
1348   /// device memory transfers by splitting them into their "issue" and "wait"
1349   /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1350   /// moved downwards as much as possible. The "issue" issues the memory transfer
1351   /// asynchronously, returning a handle. The "wait" waits on the returned
1352   /// handle for the memory transfer to finish.
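  ///
  /// Illustrative sketch of the intended effect (argument lists abbreviated,
  /// names made up):
  ///   call void @__tgt_target_data_begin_mapper(..., i8** %baseptrs, ...)
  ///   <transfer-independent code>
  ///   <first use of the mapped memory>
  /// becomes
  ///   %handle = alloca %struct.__tgt_async_info
  ///   call void @__tgt_target_data_begin_mapper_issue(..., %handle)
  ///   <transfer-independent code>
  ///   call void @__tgt_target_data_begin_mapper_wait(i64 %device_id, %handle)
  ///   <first use of the mapped memory>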
1353   bool hideMemTransfersLatency() {
1354     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1355     bool Changed = false;
1356     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1357       auto *RTCall = getCallIfRegularCall(U, &RFI);
1358       if (!RTCall)
1359         return false;
1360 
1361       OffloadArray OffloadArrays[3];
1362       if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1363         return false;
1364 
1365       LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1366 
1367       // TODO: Check if can be moved upwards.
1368       bool WasSplit = false;
1369       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1370       if (WaitMovementPoint)
1371         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1372 
1373       Changed |= WasSplit;
1374       return WasSplit;
1375     };
1376     RFI.foreachUse(SCC, SplitMemTransfers);
1377 
1378     return Changed;
1379   }
1380 
1381   void analysisGlobalization() {
1382     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1383 
1384     auto CheckGlobalization = [&](Use &U, Function &Decl) {
1385       if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1386         auto Remark = [&](OptimizationRemarkMissed ORM) {
1387           return ORM
1388                  << "Found thread data sharing on the GPU. "
1389                  << "Expect degraded performance due to data globalization.";
1390         };
1391         emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1392       }
1393 
1394       return false;
1395     };
1396 
1397     RFI.foreachUse(SCC, CheckGlobalization);
1398   }
1399 
1400   /// Maps the values stored in the offload arrays passed as arguments to
1401   /// \p RuntimeCall into the offload arrays in \p OAs.
1402   bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1403                                 MutableArrayRef<OffloadArray> OAs) {
1404     assert(OAs.size() == 3 && "Need space for three offload arrays!");
1405 
1406     // A runtime call that involves memory offloading looks something like:
1407     // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1408     //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1409     // ...)
1410     // So, the idea is to access the allocas that allocate space for these
1411     // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1412     // Therefore:
1413     // i8** %offload_baseptrs.
1414     Value *BasePtrsArg =
1415         RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1416     // i8** %offload_ptrs.
1417     Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1418     // i8** %offload_sizes.
1419     Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1420 
1421     // Get values stored in **offload_baseptrs.
1422     auto *V = getUnderlyingObject(BasePtrsArg);
1423     if (!isa<AllocaInst>(V))
1424       return false;
1425     auto *BasePtrsArray = cast<AllocaInst>(V);
1426     if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1427       return false;
1428 
1429     // Get values stored in **offload_ptrs.
1430     V = getUnderlyingObject(PtrsArg);
1431     if (!isa<AllocaInst>(V))
1432       return false;
1433     auto *PtrsArray = cast<AllocaInst>(V);
1434     if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1435       return false;
1436 
1437     // Get values stored in **offload_sizes.
1438     V = getUnderlyingObject(SizesArg);
1439     // If it's a [constant] global array don't analyze it.
1440     if (isa<GlobalValue>(V))
1441       return isa<Constant>(V);
1442     if (!isa<AllocaInst>(V))
1443       return false;
1444 
1445     auto *SizesArray = cast<AllocaInst>(V);
1446     if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1447       return false;
1448 
1449     return true;
1450   }
1451 
1452   /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1453   /// For now this is a way to test that the function getValuesInOffloadArrays
1454   /// is working properly.
1455   /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1456   void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1457     assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1458 
1459     LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1460     std::string ValuesStr;
1461     raw_string_ostream Printer(ValuesStr);
1462     std::string Separator = " --- ";
1463 
1464     for (auto *BP : OAs[0].StoredValues) {
1465       BP->print(Printer);
1466       Printer << Separator;
1467     }
1468     LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1469     ValuesStr.clear();
1470 
1471     for (auto *P : OAs[1].StoredValues) {
1472       P->print(Printer);
1473       Printer << Separator;
1474     }
1475     LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1476     ValuesStr.clear();
1477 
1478     for (auto *S : OAs[2].StoredValues) {
1479       S->print(Printer);
1480       Printer << Separator;
1481     }
1482     LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1483   }
1484 
1485   /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1486   /// moved. Returns nullptr if the movement is not possible, or not worth it.
1487   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1488     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1489     //  Make it traverse the CFG.
1490 
1491     Instruction *CurrentI = &RuntimeCall;
1492     bool IsWorthIt = false;
1493     while ((CurrentI = CurrentI->getNextNode())) {
1494 
1495       // TODO: Once we detect the regions to be offloaded we should use the
1496       //  alias analysis manager to check if CurrentI may modify one of
1497       //  the offloaded regions.
1498       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1499         if (IsWorthIt)
1500           return CurrentI;
1501 
1502         return nullptr;
1503       }
1504 
1505       // FIXME: For now, moving it over anything without side effects is
1506       //  considered worth it.
1507       IsWorthIt = true;
1508     }
1509 
1510     // Return end of BasicBlock.
1511     return RuntimeCall.getParent()->getTerminator();
1512   }
1513 
1514   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1515   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1516                                Instruction &WaitMovementPoint) {
1517     // Create a stack-allocated handle (__tgt_async_info) at the beginning of
1518     // the function. It stores information about the async transfer, allowing us
1519     // to wait on it later.
1520     auto &IRBuilder = OMPInfoCache.OMPBuilder;
1521     auto *F = RuntimeCall.getCaller();
1522     Instruction *FirstInst = &(F->getEntryBlock().front());
1523     AllocaInst *Handle = new AllocaInst(
1524         IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1525 
1526     // Add the "issue" runtime call declaration (argument list abbreviated):
1527     // declare void @__tgt_target_data_begin_mapper_issue(
1528     //   <__tgt_target_data_begin_mapper args>, %struct.__tgt_async_info*)
1529     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1530         M, OMPRTL___tgt_target_data_begin_mapper_issue);
1531 
1532     // Replace the RuntimeCall call site with its asynchronous version.
1533     SmallVector<Value *, 16> Args;
1534     for (auto &Arg : RuntimeCall.args())
1535       Args.push_back(Arg.get());
1536     Args.push_back(Handle);
1537 
1538     CallInst *IssueCallsite =
1539         CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1540     OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1541     RuntimeCall.eraseFromParent();
1542 
1543     // Add the "wait" runtime call declaration:
1544     // declare void @__tgt_target_data_begin_mapper_wait(i64, %struct.__tgt_async_info*)
1545     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1546         M, OMPRTL___tgt_target_data_begin_mapper_wait);
1547 
1548     Value *WaitParams[2] = {
1549         IssueCallsite->getArgOperand(
1550             OffloadArray::DeviceIDArgNum), // device_id.
1551         Handle                             // handle to wait on.
1552     };
1553     CallInst *WaitCallsite = CallInst::Create(
1554         WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1555     OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1556 
1557     return true;
1558   }
1559 
1560   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1561                                     bool GlobalOnly, bool &SingleChoice) {
1562     if (CurrentIdent == NextIdent)
1563       return CurrentIdent;
1564 
1565     // TODO: Figure out how to actually combine multiple debug locations. For
1566     //       now we just keep an existing one if there is a single choice.
1567     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1568       SingleChoice = !CurrentIdent;
1569       return NextIdent;
1570     }
1571     return nullptr;
1572   }
1573 
1574   /// Return a `struct ident_t*` value that represents the ones used in the
1575   /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1576   /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1577   /// return value we create one from scratch. We also do not yet combine
1578   /// information, e.g., the source locations, see combinedIdentStruct.
1579   Value *
1580   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1581                                  Function &F, bool GlobalOnly) {
1582     bool SingleChoice = true;
1583     Value *Ident = nullptr;
1584     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1585       CallInst *CI = getCallIfRegularCall(U, &RFI);
1586       if (!CI || &F != &Caller)
1587         return false;
1588       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1589                                   /* GlobalOnly */ true, SingleChoice);
1590       return false;
1591     };
1592     RFI.foreachUse(SCC, CombineIdentStruct);
1593 
1594     if (!Ident || !SingleChoice) {
1595       // The IRBuilder uses the insertion block to get to the module; this is
1596       // unfortunate, but we work around it for now.
1597       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1598         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1599             &F.getEntryBlock(), F.getEntryBlock().begin()));
1600       // Create a fallback location if none was found.
1601       // TODO: Use the debug locations of the calls instead.
1602       uint32_t SrcLocStrSize;
1603       Constant *Loc =
1604           OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1605       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1606     }
1607     return Ident;
1608   }
1609 
1610   /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1611   /// \p ReplVal if given.
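  ///
  /// Illustrative sketch (names made up): two calls
  ///   %a = call i32 @omp_get_num_threads()
  ///   ...
  ///   %b = call i32 @omp_get_num_threads()
  /// in \p F are reduced to a single call that is moved to the beginning of
  /// the function (or right after the kernel init call if \p F is a kernel),
  /// with its result replacing all uses of %a and %b.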
1612   bool deduplicateRuntimeCalls(Function &F,
1613                                OMPInformationCache::RuntimeFunctionInfo &RFI,
1614                                Value *ReplVal = nullptr) {
1615     auto *UV = RFI.getUseVector(F);
1616     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1617       return false;
1618 
1619     LLVM_DEBUG(
1620         dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1621                << (ReplVal ? " with an existing value" : "") << "\n");
1622 
1623     assert((!ReplVal || (isa<Argument>(ReplVal) &&
1624                          cast<Argument>(ReplVal)->getParent() == &F)) &&
1625            "Unexpected replacement value!");
1626 
1627     // TODO: Use dominance to find a good position instead.
1628     auto CanBeMoved = [this](CallBase &CB) {
1629       unsigned NumArgs = CB.arg_size();
1630       if (NumArgs == 0)
1631         return true;
1632       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1633         return false;
1634       for (unsigned U = 1; U < NumArgs; ++U)
1635         if (isa<Instruction>(CB.getArgOperand(U)))
1636           return false;
1637       return true;
1638     };
1639 
1640     if (!ReplVal) {
1641       for (Use *U : *UV)
1642         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1643           if (!CanBeMoved(*CI))
1644             continue;
1645 
1646           // If the function is a kernel, dedup will move
1647           // the runtime call right after the kernel init callsite. Otherwise,
1648           // it will move it to the beginning of the caller function.
1649           if (isKernel(F)) {
1650             auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
1651             auto *KernelInitUV = KernelInitRFI.getUseVector(F);
1652 
1653             if (KernelInitUV->empty())
1654               continue;
1655 
1656             assert(KernelInitUV->size() == 1 &&
1657                    "Expected a single __kmpc_target_init in kernel\n");
1658 
1659             CallInst *KernelInitCI =
1660                 getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
1661             assert(KernelInitCI &&
1662                    "Expected a call to __kmpc_target_init in kernel\n");
1663 
1664             CI->moveAfter(KernelInitCI);
1665           } else
1666             CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1667           ReplVal = CI;
1668           break;
1669         }
1670       if (!ReplVal)
1671         return false;
1672     }
1673 
1674     // If we use a call as a replacement value we need to make sure the ident is
1675     // valid at the new location. For now we just pick a global one, either
1676     // existing and used by one of the calls, or created from scratch.
1677     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1678       if (!CI->arg_empty() &&
1679           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1680         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1681                                                       /* GlobalOnly */ true);
1682         CI->setArgOperand(0, Ident);
1683       }
1684     }
1685 
1686     bool Changed = false;
1687     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1688       CallInst *CI = getCallIfRegularCall(U, &RFI);
1689       if (!CI || CI == ReplVal || &F != &Caller)
1690         return false;
1691       assert(CI->getCaller() == &F && "Unexpected call!");
1692 
1693       auto Remark = [&](OptimizationRemark OR) {
1694         return OR << "OpenMP runtime call "
1695                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1696       };
1697       if (CI->getDebugLoc())
1698         emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1699       else
1700         emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1701 
1702       CGUpdater.removeCallSite(*CI);
1703       CI->replaceAllUsesWith(ReplVal);
1704       CI->eraseFromParent();
1705       ++NumOpenMPRuntimeCallsDeduplicated;
1706       Changed = true;
1707       return true;
1708     };
1709     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1710 
1711     return Changed;
1712   }
1713 
1714   /// Collect arguments that represent the global thread id in \p GTIdArgs.
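  ///
  /// Illustrative sketch (names made up): given
  ///   %gtid = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  ///   call void @bar(i32 %gtid)
  /// the first argument of the internal function @bar is collected as a
  /// global thread ID argument, provided every call site of @bar passes a
  /// known GTId (or a __kmpc_global_thread_num result) in that position.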
1715   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1716     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1717     //       initialization. We could define an AbstractAttribute instead and
1718     //       run the Attributor here once it can be run as an SCC pass.
1719 
1720     // Helper to check the argument \p ArgNo at all call sites of \p F for
1721     // a GTId.
1722     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1723       if (!F.hasLocalLinkage())
1724         return false;
1725       for (Use &U : F.uses()) {
1726         if (CallInst *CI = getCallIfRegularCall(U)) {
1727           Value *ArgOp = CI->getArgOperand(ArgNo);
1728           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1729               getCallIfRegularCall(
1730                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1731             continue;
1732         }
1733         return false;
1734       }
1735       return true;
1736     };
1737 
1738     // Helper to identify uses of a GTId as GTId arguments.
1739     auto AddUserArgs = [&](Value &GTId) {
1740       for (Use &U : GTId.uses())
1741         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1742           if (CI->isArgOperand(&U))
1743             if (Function *Callee = CI->getCalledFunction())
1744               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1745                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1746     };
1747 
1748     // The argument users of __kmpc_global_thread_num calls are GTIds.
1749     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1750         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1751 
1752     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1753       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1754         AddUserArgs(*CI);
1755       return false;
1756     });
1757 
1758     // Transitively search for more arguments by looking at the users of the
1759     // ones we know already. During the search the GTIdArgs vector is extended
1760     // so we cannot cache the size nor can we use a range based for.
1761     for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1762       AddUserArgs(*GTIdArgs[U]);
1763   }
1764 
1765   /// Kernel (=GPU) optimizations and utility functions
1766   ///
1767   ///{{
1768 
1769   /// Check if \p F is a kernel, hence entry point for target offloading.
1770   bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
1771 
1772   /// Cache to remember the unique kernel for a function.
1773   DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
1774 
1775   /// Find the unique kernel that will execute \p F, if any.
1776   Kernel getUniqueKernelFor(Function &F);
1777 
1778   /// Find the unique kernel that will execute \p I, if any.
1779   Kernel getUniqueKernelFor(Instruction &I) {
1780     return getUniqueKernelFor(*I.getFunction());
1781   }
1782 
1783   /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1784   /// the cases where we can avoid taking the address of a function.
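  ///
  /// Illustrative sketch (names made up): a state machine comparison such as
  ///   %is.par = icmp eq i8* %work_fn, bitcast (void ()* @par_body to i8*)
  /// is rewritten to compare against a new private global instead,
  ///   %is.par = icmp eq i8* %work_fn, @par_body.ID
  /// so the address of @par_body is no longer taken and only direct calls to
  /// it remain.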
1785   bool rewriteDeviceCodeStateMachine();
1786 
1787   ///
1788   ///}}
1789 
1790   /// Emit a remark generically
1791   ///
1792   /// This template function can be used to generically emit a remark. The
1793   /// RemarkKind should be one of the following:
1794   ///   - OptimizationRemark to indicate a successful optimization attempt
1795   ///   - OptimizationRemarkMissed to report a failed optimization attempt
1796   ///   - OptimizationRemarkAnalysis to provide additional information about an
1797   ///     optimization attempt
1798   ///
1799   /// The remark is built using a callback function provided by the caller that
1800   /// takes a RemarkKind as input and returns a RemarkKind.
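  ///
  /// Example usage (mirroring the remark call sites elsewhere in this file):
  ///   auto Remark = [&](OptimizationRemark OR) {
  ///     return OR << "Removing parallel region with no side-effects.";
  ///   };
  ///   emitRemark<OptimizationRemark>(CI, "OMP160", Remark);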
1801   template <typename RemarkKind, typename RemarkCallBack>
1802   void emitRemark(Instruction *I, StringRef RemarkName,
1803                   RemarkCallBack &&RemarkCB) const {
1804     Function *F = I->getParent()->getParent();
1805     auto &ORE = OREGetter(F);
1806 
1807     if (RemarkName.startswith("OMP"))
1808       ORE.emit([&]() {
1809         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
1810                << " [" << RemarkName << "]";
1811       });
1812     else
1813       ORE.emit(
1814           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
1815   }
1816 
1817   /// Emit a remark on a function.
1818   template <typename RemarkKind, typename RemarkCallBack>
1819   void emitRemark(Function *F, StringRef RemarkName,
1820                   RemarkCallBack &&RemarkCB) const {
1821     auto &ORE = OREGetter(F);
1822 
1823     if (RemarkName.startswith("OMP"))
1824       ORE.emit([&]() {
1825         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
1826                << " [" << RemarkName << "]";
1827       });
1828     else
1829       ORE.emit(
1830           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
1831   }
1832 
1833   /// RAII struct to temporarily change an RTL function's linkage to external.
1834   /// This prevents it from being mistakenly removed by other optimizations.
1835   struct ExternalizationRAII {
1836     ExternalizationRAII(OMPInformationCache &OMPInfoCache,
1837                         RuntimeFunction RFKind)
1838         : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
1839       if (!Declaration)
1840         return;
1841 
1842       LinkageType = Declaration->getLinkage();
1843       Declaration->setLinkage(GlobalValue::ExternalLinkage);
1844     }
1845 
1846     ~ExternalizationRAII() {
1847       if (!Declaration)
1848         return;
1849 
1850       Declaration->setLinkage(LinkageType);
1851     }
1852 
1853     Function *Declaration;
1854     GlobalValue::LinkageTypes LinkageType;
1855   };
1856 
1857   /// The underlying module.
1858   Module &M;
1859 
1860   /// The SCC we are operating on.
1861   SmallVectorImpl<Function *> &SCC;
1862 
1863   /// Callback to update the call graph; the first argument is a removed call,
1864   /// the second an optional replacement call.
1865   CallGraphUpdater &CGUpdater;
1866 
1867   /// Callback to get an OptimizationRemarkEmitter from a Function *
1868   OptimizationRemarkGetter OREGetter;
1869 
1870   /// OpenMP-specific information cache. Also used for Attributor runs.
1871   OMPInformationCache &OMPInfoCache;
1872 
1873   /// Attributor instance.
1874   Attributor &A;
1875 
1876   /// Helper function to run Attributor on SCC.
1877   bool runAttributor(bool IsModulePass) {
1878     if (SCC.empty())
1879       return false;
1880 
1881     // Temporarily make these functions have external linkage so the Attributor
1882     // doesn't remove them when we try to look them up later.
1883     ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
1884     ExternalizationRAII EndParallel(OMPInfoCache,
1885                                     OMPRTL___kmpc_kernel_end_parallel);
1886     ExternalizationRAII BarrierSPMD(OMPInfoCache,
1887                                     OMPRTL___kmpc_barrier_simple_spmd);
1888     ExternalizationRAII BarrierGeneric(OMPInfoCache,
1889                                        OMPRTL___kmpc_barrier_simple_generic);
1890     ExternalizationRAII ThreadId(OMPInfoCache,
1891                                  OMPRTL___kmpc_get_hardware_thread_id_in_block);
1892     ExternalizationRAII WarpSize(OMPInfoCache, OMPRTL___kmpc_get_warp_size);
1893 
1894     registerAAs(IsModulePass);
1895 
1896     ChangeStatus Changed = A.run();
1897 
1898     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
1899                       << " functions, result: " << Changed << ".\n");
1900 
1901     return Changed == ChangeStatus::CHANGED;
1902   }
1903 
1904   void registerFoldRuntimeCall(RuntimeFunction RF);
1905 
1906   /// Populate the Attributor with abstract attribute opportunities in the
1907   /// function.
1908   void registerAAs(bool IsModulePass);
1909 };
1910 
1911 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
1912   if (!OMPInfoCache.ModuleSlice.count(&F))
1913     return nullptr;
1914 
1915   // Use a scope to keep the lifetime of the CachedKernel short.
1916   {
1917     Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
1918     if (CachedKernel)
1919       return *CachedKernel;
1920 
1921     // TODO: We should use an AA to create an (optimistic and callback
1922     //       call-aware) call graph. For now we stick to simple patterns that
1923     //       are less powerful, basically the worst fixpoint.
1924     if (isKernel(F)) {
1925       CachedKernel = Kernel(&F);
1926       return *CachedKernel;
1927     }
1928 
1929     CachedKernel = nullptr;
1930     if (!F.hasLocalLinkage()) {
1931 
1932       // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
1933       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1934         return ORA << "Potentially unknown OpenMP target region caller.";
1935       };
1936       emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
1937 
1938       return nullptr;
1939     }
1940   }
1941 
1942   auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
1943     if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
1944       // Allow use in equality comparisons.
1945       if (Cmp->isEquality())
1946         return getUniqueKernelFor(*Cmp);
1947       return nullptr;
1948     }
1949     if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
1950       // Allow direct calls.
1951       if (CB->isCallee(&U))
1952         return getUniqueKernelFor(*CB);
1953 
1954       OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1955           OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1956       // Allow the use in __kmpc_parallel_51 calls.
1957       if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
1958         return getUniqueKernelFor(*CB);
1959       return nullptr;
1960     }
1961     // Disallow every other use.
1962     return nullptr;
1963   };
1964 
1965   // TODO: In the future we want to track more than just a unique kernel.
1966   SmallPtrSet<Kernel, 2> PotentialKernels;
1967   OMPInformationCache::foreachUse(F, [&](const Use &U) {
1968     PotentialKernels.insert(GetUniqueKernelForUse(U));
1969   });
1970 
1971   Kernel K = nullptr;
1972   if (PotentialKernels.size() == 1)
1973     K = *PotentialKernels.begin();
1974 
1975   // Cache the result.
1976   UniqueKernelMap[&F] = K;
1977 
1978   return K;
1979 }
1980 
1981 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1982   OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1983       OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1984 
1985   bool Changed = false;
1986   if (!KernelParallelRFI)
1987     return Changed;
1988 
1989   // If we have disabled state machine changes, exit
1990   if (DisableOpenMPOptStateMachineRewrite)
1991     return Changed;
1992 
1993   for (Function *F : SCC) {
1994 
1995     // Check if the function is used in a __kmpc_parallel_51 call at
1996     // all.
1997     bool UnknownUse = false;
1998     bool KernelParallelUse = false;
1999     unsigned NumDirectCalls = 0;
2000 
2001     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2002     OMPInformationCache::foreachUse(*F, [&](Use &U) {
2003       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2004         if (CB->isCallee(&U)) {
2005           ++NumDirectCalls;
2006           return;
2007         }
2008 
2009       if (isa<ICmpInst>(U.getUser())) {
2010         ToBeReplacedStateMachineUses.push_back(&U);
2011         return;
2012       }
2013 
2014       // Find wrapper functions that represent parallel kernels.
2015       CallInst *CI =
2016           OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2017       const unsigned int WrapperFunctionArgNo = 6;
2018       if (!KernelParallelUse && CI &&
2019           CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2020         KernelParallelUse = true;
2021         ToBeReplacedStateMachineUses.push_back(&U);
2022         return;
2023       }
2024       UnknownUse = true;
2025     });
2026 
2027     // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2028     // use.
2029     if (!KernelParallelUse)
2030       continue;
2031 
2032     // If this ever hits, we should investigate.
2033     // TODO: Checking the number of uses is not a necessary restriction and
2034     // should be lifted.
2035     if (UnknownUse || NumDirectCalls != 1 ||
2036         ToBeReplacedStateMachineUses.size() > 2) {
2037       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2038         return ORA << "Parallel region is used in "
2039                    << (UnknownUse ? "unknown" : "unexpected")
2040                    << " ways. Will not attempt to rewrite the state machine.";
2041       };
2042       emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2043       continue;
2044     }
2045 
2046     // Even if we have __kmpc_parallel_51 calls, we (for now) give
2047     // up if the function is not called from a unique kernel.
2048     Kernel K = getUniqueKernelFor(*F);
2049     if (!K) {
2050       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2051         return ORA << "Parallel region is not called from a unique kernel. "
2052                       "Will not attempt to rewrite the state machine.";
2053       };
2054       emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2055       continue;
2056     }
2057 
2058     // We now know F is a parallel body function called only from the kernel K.
2059     // We also identified the state machine uses in which we replace the
2060     // function pointer by a new global symbol for identification purposes. This
2061     // ensures only direct calls to the function are left.
2062 
2063     Module &M = *F->getParent();
2064     Type *Int8Ty = Type::getInt8Ty(M.getContext());
2065 
2066     auto *ID = new GlobalVariable(
2067         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2068         UndefValue::get(Int8Ty), F->getName() + ".ID");
2069 
2070     for (Use *U : ToBeReplacedStateMachineUses)
2071       U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2072           ID, U->get()->getType()));
2073 
2074     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2075 
2076     Changed = true;
2077   }
2078 
2079   return Changed;
2080 }
2081 
2082 /// Abstract Attribute for tracking ICV values.
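///
/// Illustrative sketch for the nthreads ICV (assuming omp_set_num_threads and
/// omp_get_max_threads as its setter/getter pair): after
///   call void @omp_set_num_threads(i32 4)
///   %n = call i32 @omp_get_max_threads()
/// the getter can be replaced by the value 4 passed to the setter, provided no
/// intervening call may have changed the ICV.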
2083 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2084   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2085   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2086 
2087   void initialize(Attributor &A) override {
2088     Function *F = getAnchorScope();
2089     if (!F || !A.isFunctionIPOAmendable(*F))
2090       indicatePessimisticFixpoint();
2091   }
2092 
2093   /// Returns true if value is assumed to be tracked.
2094   bool isAssumedTracked() const { return getAssumed(); }
2095 
2096   /// Returns true if value is known to be tracked.
2097   bool isKnownTracked() const { return getKnown(); }
2098 
2099   /// Create an abstract attribute view for the position \p IRP.
2100   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2101 
2102   /// Return the value with which \p I can be replaced for the specific \p ICV.
2103   virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
2104                                                 const Instruction *I,
2105                                                 Attributor &A) const {
2106     return None;
2107   }
2108 
2109   /// Return an assumed unique ICV value if a single candidate is found. If
2110   /// there cannot be one, return nullptr. If it is not clear yet, return
2111   /// None.
2112   virtual Optional<Value *>
2113   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2114 
2115   // Currently only nthreads is being tracked.
2116   // This array will only grow over time.
2117   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2118 
2119   /// See AbstractAttribute::getName()
2120   const std::string getName() const override { return "AAICVTracker"; }
2121 
2122   /// See AbstractAttribute::getIdAddr()
2123   const char *getIdAddr() const override { return &ID; }
2124 
2125   /// This function should return true if the type of the \p AA is AAICVTracker
2126   static bool classof(const AbstractAttribute *AA) {
2127     return (AA->getIdAddr() == &ID);
2128   }
2129 
2130   static const char ID;
2131 };
2132 
2133 struct AAICVTrackerFunction : public AAICVTracker {
2134   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2135       : AAICVTracker(IRP, A) {}
2136 
2137   // FIXME: come up with better string.
2138   const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2139 
2140   // FIXME: come up with some stats.
2141   void trackStatistics() const override {}
2142 
2143   /// We don't manifest anything for this AA.
2144   ChangeStatus manifest(Attributor &A) override {
2145     return ChangeStatus::UNCHANGED;
2146   }
2147 
2148   // Map of ICVs to their values at specific program points.
2149   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2150                   InternalControlVar::ICV___last>
2151       ICVReplacementValuesMap;
2152 
2153   ChangeStatus updateImpl(Attributor &A) override {
2154     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2155 
2156     Function *F = getAnchorScope();
2157 
2158     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2159 
2160     for (InternalControlVar ICV : TrackableICVs) {
2161       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2162 
2163       auto &ValuesMap = ICVReplacementValuesMap[ICV];
2164       auto TrackValues = [&](Use &U, Function &) {
2165         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2166         if (!CI)
2167           return false;
2168 
2169         // FIXME: Handle setters with more than one argument.
2170         // Track the new value.
2171         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2172           HasChanged = ChangeStatus::CHANGED;
2173 
2174         return false;
2175       };
2176 
2177       auto CallCheck = [&](Instruction &I) {
2178         Optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2179         if (ReplVal.hasValue() &&
2180             ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2181           HasChanged = ChangeStatus::CHANGED;
2182 
2183         return true;
2184       };
2185 
2186       // Track all changes of an ICV.
2187       SetterRFI.foreachUse(TrackValues, F);
2188 
2189       bool UsedAssumedInformation = false;
2190       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2191                                 UsedAssumedInformation,
2192                                 /* CheckBBLivenessOnly */ true);
2193 
2194       // TODO: Figure out a way to avoid adding an entry in
2195       //       ICVReplacementValuesMap.
2196       Instruction *Entry = &F->getEntryBlock().front();
2197       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2198         ValuesMap.insert(std::make_pair(Entry, nullptr));
2199     }
2200 
2201     return HasChanged;
2202   }
2203 
2204   /// Helper to check if \p I is a call and get the value for it if it is
2205   /// unique.
2206   Optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2207                                     InternalControlVar &ICV) const {
2208 
2209     const auto *CB = dyn_cast<CallBase>(&I);
2210     if (!CB || CB->hasFnAttr("no_openmp") ||
2211         CB->hasFnAttr("no_openmp_routines"))
2212       return None;
2213 
2214     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2215     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2216     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2217     Function *CalledFunction = CB->getCalledFunction();
2218 
2219     // Indirect call, assume ICV changes.
2220     if (CalledFunction == nullptr)
2221       return nullptr;
2222     if (CalledFunction == GetterRFI.Declaration)
2223       return None;
2224     if (CalledFunction == SetterRFI.Declaration) {
2225       if (ICVReplacementValuesMap[ICV].count(&I))
2226         return ICVReplacementValuesMap[ICV].lookup(&I);
2227 
2228       return nullptr;
2229     }
2230 
2231     // Since we don't know, assume it changes the ICV.
2232     if (CalledFunction->isDeclaration())
2233       return nullptr;
2234 
2235     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2236         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2237 
2238     if (ICVTrackingAA.isAssumedTracked()) {
2239       Optional<Value *> URV = ICVTrackingAA.getUniqueReplacementValue(ICV);
2240       if (!URV || (*URV && AA::isValidAtPosition(**URV, I, OMPInfoCache)))
2241         return URV;
2242     }
2243 
2244     // If we don't know, assume it changes.
2245     return nullptr;
2246   }
2247 
2248   // We don't check unique value for a function, so return None.
2249   Optional<Value *>
2250   getUniqueReplacementValue(InternalControlVar ICV) const override {
2251     return None;
2252   }
2253 
2254   /// Return the value with which \p I can be replaced for the specific \p ICV.
2255   Optional<Value *> getReplacementValue(InternalControlVar ICV,
2256                                         const Instruction *I,
2257                                         Attributor &A) const override {
2258     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2259     if (ValuesMap.count(I))
2260       return ValuesMap.lookup(I);
2261 
2262     SmallVector<const Instruction *, 16> Worklist;
2263     SmallPtrSet<const Instruction *, 16> Visited;
2264     Worklist.push_back(I);
2265 
2266     Optional<Value *> ReplVal;
2267 
2268     while (!Worklist.empty()) {
2269       const Instruction *CurrInst = Worklist.pop_back_val();
2270       if (!Visited.insert(CurrInst).second)
2271         continue;
2272 
2273       const BasicBlock *CurrBB = CurrInst->getParent();
2274 
2275       // Go up and look for all potential setters/calls that might change the
2276       // ICV.
2277       while ((CurrInst = CurrInst->getPrevNode())) {
2278         if (ValuesMap.count(CurrInst)) {
2279           Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2280           // Unknown value, track new.
2281           if (!ReplVal.hasValue()) {
2282             ReplVal = NewReplVal;
2283             break;
2284           }
2285 
2286           // If we found a new value, we can't know the ICV value anymore.
2287           if (NewReplVal.hasValue())
2288             if (ReplVal != NewReplVal)
2289               return nullptr;
2290 
2291           break;
2292         }
2293 
2294         Optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2295         if (!NewReplVal.hasValue())
2296           continue;
2297 
2298         // Unknown value, track new.
2299         if (!ReplVal.hasValue()) {
2300           ReplVal = NewReplVal;
2301           break;
2302         }
2303 
2304         // NewReplVal has a value at this point; if it differs from the
2305         // tracked one, we can't know the ICV value anymore.
2306         if (ReplVal != NewReplVal)
2307           return nullptr;
2308       }
2309 
2310       // If we are in the same BB and we have a value, we are done.
2311       if (CurrBB == I->getParent() && ReplVal.hasValue())
2312         return ReplVal;
2313 
2314       // Go through all predecessors and add terminators for analysis.
2315       for (const BasicBlock *Pred : predecessors(CurrBB))
2316         if (const Instruction *Terminator = Pred->getTerminator())
2317           Worklist.push_back(Terminator);
2318     }
2319 
2320     return ReplVal;
2321   }
2322 };
2323 
2324 struct AAICVTrackerFunctionReturned : AAICVTracker {
2325   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2326       : AAICVTracker(IRP, A) {}
2327 
2328   // FIXME: come up with better string.
2329   const std::string getAsStr() const override {
2330     return "ICVTrackerFunctionReturned";
2331   }
2332 
2333   // FIXME: come up with some stats.
2334   void trackStatistics() const override {}
2335 
2336   /// We don't manifest anything for this AA.
2337   ChangeStatus manifest(Attributor &A) override {
2338     return ChangeStatus::UNCHANGED;
2339   }
2340 
2341   // Map of ICVs to their values at specific program points.
2342   EnumeratedArray<Optional<Value *>, InternalControlVar,
2343                   InternalControlVar::ICV___last>
2344       ICVReplacementValuesMap;
2345 
2346   /// Return the value with which \p I can be replaced for the specific \p ICV.
2347   Optional<Value *>
2348   getUniqueReplacementValue(InternalControlVar ICV) const override {
2349     return ICVReplacementValuesMap[ICV];
2350   }
2351 
2352   ChangeStatus updateImpl(Attributor &A) override {
2353     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2354     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2355         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2356 
2357     if (!ICVTrackingAA.isAssumedTracked())
2358       return indicatePessimisticFixpoint();
2359 
2360     for (InternalControlVar ICV : TrackableICVs) {
2361       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2362       Optional<Value *> UniqueICVValue;
2363 
2364       auto CheckReturnInst = [&](Instruction &I) {
2365         Optional<Value *> NewReplVal =
2366             ICVTrackingAA.getReplacementValue(ICV, &I, A);
2367 
2368         // If we found a second ICV value there is no unique returned value.
2369         if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
2370           return false;
2371 
2372         UniqueICVValue = NewReplVal;
2373 
2374         return true;
2375       };
2376 
2377       bool UsedAssumedInformation = false;
2378       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2379                                      UsedAssumedInformation,
2380                                      /* CheckBBLivenessOnly */ true))
2381         UniqueICVValue = nullptr;
2382 
2383       if (UniqueICVValue == ReplVal)
2384         continue;
2385 
2386       ReplVal = UniqueICVValue;
2387       Changed = ChangeStatus::CHANGED;
2388     }
2389 
2390     return Changed;
2391   }
2392 };
2393 
2394 struct AAICVTrackerCallSite : AAICVTracker {
2395   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2396       : AAICVTracker(IRP, A) {}
2397 
2398   void initialize(Attributor &A) override {
2399     Function *F = getAnchorScope();
2400     if (!F || !A.isFunctionIPOAmendable(*F))
2401       indicatePessimisticFixpoint();
2402 
2403     // We only initialize this AA for getters, so we need to know which ICV it
2404     // gets.
2405     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2406     for (InternalControlVar ICV : TrackableICVs) {
2407       auto ICVInfo = OMPInfoCache.ICVs[ICV];
2408       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2409       if (Getter.Declaration == getAssociatedFunction()) {
2410         AssociatedICV = ICVInfo.Kind;
2411         return;
2412       }
2413     }
2414 
2415     // Unknown ICV.
2416     indicatePessimisticFixpoint();
2417   }
2418 
2419   ChangeStatus manifest(Attributor &A) override {
2420     if (!ReplVal.hasValue() || !ReplVal.getValue())
2421       return ChangeStatus::UNCHANGED;
2422 
2423     A.changeValueAfterManifest(*getCtxI(), **ReplVal);
2424     A.deleteAfterManifest(*getCtxI());
2425 
2426     return ChangeStatus::CHANGED;
2427   }
2428 
2429   // FIXME: come up with better string.
2430   const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2431 
2432   // FIXME: come up with some stats.
2433   void trackStatistics() const override {}
2434 
2435   InternalControlVar AssociatedICV;
2436   Optional<Value *> ReplVal;
2437 
2438   ChangeStatus updateImpl(Attributor &A) override {
2439     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2440         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2441 
2442     // We don't have any information, so we assume it changes the ICV.
2443     if (!ICVTrackingAA.isAssumedTracked())
2444       return indicatePessimisticFixpoint();
2445 
2446     Optional<Value *> NewReplVal =
2447         ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2448 
2449     if (ReplVal == NewReplVal)
2450       return ChangeStatus::UNCHANGED;
2451 
2452     ReplVal = NewReplVal;
2453     return ChangeStatus::CHANGED;
2454   }
2455 
2456   /// Return the value with which the associated value can be replaced for the
2457   /// specific \p ICV.
2458   Optional<Value *>
2459   getUniqueReplacementValue(InternalControlVar ICV) const override {
2460     return ReplVal;
2461   }
2462 };
2463 
2464 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2465   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2466       : AAICVTracker(IRP, A) {}
2467 
2468   // FIXME: come up with better string.
2469   const std::string getAsStr() const override {
2470     return "ICVTrackerCallSiteReturned";
2471   }
2472 
2473   // FIXME: come up with some stats.
2474   void trackStatistics() const override {}
2475 
2476   /// We don't manifest anything for this AA.
2477   ChangeStatus manifest(Attributor &A) override {
2478     return ChangeStatus::UNCHANGED;
2479   }
2480 
2481   // Map of ICVs to their values at specific program points.
2482   EnumeratedArray<Optional<Value *>, InternalControlVar,
2483                   InternalControlVar::ICV___last>
2484       ICVReplacementValuesMap;
2485 
2486   /// Return the value with which the associated value can be replaced for the
2487   /// specific \p ICV.
2488   Optional<Value *>
2489   getUniqueReplacementValue(InternalControlVar ICV) const override {
2490     return ICVReplacementValuesMap[ICV];
2491   }
2492 
2493   ChangeStatus updateImpl(Attributor &A) override {
2494     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2495     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2496         *this, IRPosition::returned(*getAssociatedFunction()),
2497         DepClassTy::REQUIRED);
2498 
2499     // We don't have any information, so we assume it changes the ICV.
2500     if (!ICVTrackingAA.isAssumedTracked())
2501       return indicatePessimisticFixpoint();
2502 
2503     for (InternalControlVar ICV : TrackableICVs) {
2504       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2505       Optional<Value *> NewReplVal =
2506           ICVTrackingAA.getUniqueReplacementValue(ICV);
2507 
2508       if (ReplVal == NewReplVal)
2509         continue;
2510 
2511       ReplVal = NewReplVal;
2512       Changed = ChangeStatus::CHANGED;
2513     }
2514     return Changed;
2515   }
2516 };
2517 
2518 struct AAExecutionDomainFunction : public AAExecutionDomain {
2519   AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2520       : AAExecutionDomain(IRP, A) {}
2521 
2522   const std::string getAsStr() const override {
2523     return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2524            "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2525   }
2526 
2527   /// See AbstractAttribute::trackStatistics().
2528   void trackStatistics() const override {}
2529 
2530   void initialize(Attributor &A) override {
2531     Function *F = getAnchorScope();
2532     for (const auto &BB : *F)
2533       SingleThreadedBBs.insert(&BB);
2534     NumBBs = SingleThreadedBBs.size();
2535   }
2536 
2537   ChangeStatus manifest(Attributor &A) override {
2538     LLVM_DEBUG({
2539       for (const BasicBlock *BB : SingleThreadedBBs)
2540         dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2541                << BB->getName() << " is executed by a single thread.\n";
2542     });
2543     return ChangeStatus::UNCHANGED;
2544   }
2545 
2546   ChangeStatus updateImpl(Attributor &A) override;
2547 
2548   /// Check if an instruction is executed by a single thread.
2549   bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2550     return isExecutedByInitialThreadOnly(*I.getParent());
2551   }
2552 
2553   bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2554     return isValidState() && SingleThreadedBBs.contains(&BB);
2555   }
2556 
2557   /// Set of basic blocks that are executed by a single thread.
2558   SmallSetVector<const BasicBlock *, 16> SingleThreadedBBs;
2559 
2560   /// Total number of basic blocks in this function.
2561   long unsigned NumBBs;
2562 };
2563 
2564 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2565   Function *F = getAnchorScope();
2566   ReversePostOrderTraversal<Function *> RPOT(F);
2567   auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2568 
2569   bool AllCallSitesKnown;
2570   auto PredForCallSite = [&](AbstractCallSite ACS) {
2571     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2572         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2573         DepClassTy::REQUIRED);
2574     return ACS.isDirectCall() &&
2575            ExecutionDomainAA.isExecutedByInitialThreadOnly(
2576                *ACS.getInstruction());
2577   };
2578 
2579   if (!A.checkForAllCallSites(PredForCallSite, *this,
2580                               /* RequiresAllCallSites */ true,
2581                               AllCallSitesKnown))
2582     SingleThreadedBBs.remove(&F->getEntryBlock());
2583 
2584   auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2585   auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2586 
2587   // Check if the edge into the successor block contains a condition that only
2588   // lets the main thread execute it.
2589   auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2590     if (!Edge || !Edge->isConditional())
2591       return false;
2592     if (Edge->getSuccessor(0) != SuccessorBB)
2593       return false;
2594 
2595     auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2596     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2597       return false;
2598 
2599     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2600     if (!C)
2601       return false;
2602 
2603     // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2604     if (C->isAllOnesValue()) {
2605       auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2606       CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2607       if (!CB)
2608         return false;
2609       const int InitModeArgNo = 1;
2610       auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo));
2611       return ModeCI && (ModeCI->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC);
2612     }
2613 
2614     if (C->isZero()) {
2615       // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2616       if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2617         if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2618           return true;
2619 
2620       // Match: 0 == llvm.amdgcn.workitem.id.x()
2621       if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2622         if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2623           return true;
2624     }
2625 
2626     return false;
2627   };
2628 
2629   // Merge all the predecessor states into the current basic block. A basic
2630   // block is executed by a single thread if all of its predecessors are.
2631   auto MergePredecessorStates = [&](BasicBlock *BB) {
2632     if (pred_empty(BB))
2633       return SingleThreadedBBs.contains(BB);
2634 
2635     bool IsInitialThread = true;
2636     for (BasicBlock *PredBB : predecessors(BB)) {
2637       if (!IsInitialThreadOnly(dyn_cast<BranchInst>(PredBB->getTerminator()),
2638                                BB))
2639         IsInitialThread &= SingleThreadedBBs.contains(PredBB);
2640     }
2641 
2642     return IsInitialThread;
2643   };
2644 
2645   for (auto *BB : RPOT) {
2646     if (!MergePredecessorStates(BB))
2647       SingleThreadedBBs.remove(BB);
2648   }
2649 
2650   return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2651              ? ChangeStatus::UNCHANGED
2652              : ChangeStatus::CHANGED;
2653 }
2654 
2655 /// Try to replace memory allocation calls called by a single thread with a
2656 /// static buffer of shared memory.
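///
/// Illustrative sketch (argument lists abbreviated, names and size made up):
/// a pair
///   %p = call i8* @__kmpc_alloc_shared(i64 24)
///   ...
///   call void @__kmpc_free_shared(i8* %p, ...)
/// that is executed by a single thread is replaced by a 24-byte internal
/// global in the shared address space, and both runtime calls are removed.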
2657 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2658   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2659   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2660 
2661   /// Create an abstract attribute view for the position \p IRP.
2662   static AAHeapToShared &createForPosition(const IRPosition &IRP,
2663                                            Attributor &A);
2664 
2665   /// Returns true if HeapToShared conversion is assumed to be possible.
2666   virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
2667 
2668   /// Returns true if HeapToShared conversion is assumed and the CB is a
2669   /// callsite to a free operation to be removed.
2670   virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
2671 
2672   /// See AbstractAttribute::getName().
2673   const std::string getName() const override { return "AAHeapToShared"; }
2674 
2675   /// See AbstractAttribute::getIdAddr().
2676   const char *getIdAddr() const override { return &ID; }
2677 
2678   /// This function should return true if the type of the \p AA is
2679   /// AAHeapToShared.
2680   static bool classof(const AbstractAttribute *AA) {
2681     return (AA->getIdAddr() == &ID);
2682   }
2683 
2684   /// Unique ID (due to the unique address)
2685   static const char ID;
2686 };
2687 
2688 struct AAHeapToSharedFunction : public AAHeapToShared {
2689   AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2690       : AAHeapToShared(IRP, A) {}
2691 
2692   const std::string getAsStr() const override {
2693     return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2694            " malloc calls eligible.";
2695   }
2696 
2697   /// See AbstractAttribute::trackStatistics().
2698   void trackStatistics() const override {}
2699 
2700   /// This function finds free calls that will be removed by the
2701   /// HeapToShared transformation.
2702   void findPotentialRemovedFreeCalls(Attributor &A) {
2703     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2704     auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2705 
2706     PotentialRemovedFreeCalls.clear();
2707     // Record the unique free call user (if any) of each found malloc call.
2708     for (CallBase *CB : MallocCalls) {
2709       SmallVector<CallBase *, 4> FreeCalls;
2710       for (auto *U : CB->users()) {
2711         CallBase *C = dyn_cast<CallBase>(U);
2712         if (C && C->getCalledFunction() == FreeRFI.Declaration)
2713           FreeCalls.push_back(C);
2714       }
2715 
2716       if (FreeCalls.size() != 1)
2717         continue;
2718 
2719       PotentialRemovedFreeCalls.insert(FreeCalls.front());
2720     }
2721   }
2722 
2723   void initialize(Attributor &A) override {
2724     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2725     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2726 
2727     for (User *U : RFI.Declaration->users())
2728       if (CallBase *CB = dyn_cast<CallBase>(U))
2729         MallocCalls.insert(CB);
2730 
2731     findPotentialRemovedFreeCalls(A);
2732   }
2733 
2734   bool isAssumedHeapToShared(CallBase &CB) const override {
2735     return isValidState() && MallocCalls.count(&CB);
2736   }
2737 
2738   bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
2739     return isValidState() && PotentialRemovedFreeCalls.count(&CB);
2740   }
2741 
2742   ChangeStatus manifest(Attributor &A) override {
2743     if (MallocCalls.empty())
2744       return ChangeStatus::UNCHANGED;
2745 
2746     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2747     auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2748 
2749     Function *F = getAnchorScope();
2750     auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
2751                                             DepClassTy::OPTIONAL);
2752 
2753     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2754     for (CallBase *CB : MallocCalls) {
2755       // Skip replacing this if HeapToStack has already claimed it.
2756       if (HS && HS->isAssumedHeapToStack(*CB))
2757         continue;
2758 
2759       // Find the unique free call so it can be removed as well.
2760       SmallVector<CallBase *, 4> FreeCalls;
2761       for (auto *U : CB->users()) {
2762         CallBase *C = dyn_cast<CallBase>(U);
2763         if (C && C->getCalledFunction() == FreeCall.Declaration)
2764           FreeCalls.push_back(C);
2765       }
2766       if (FreeCalls.size() != 1)
2767         continue;
2768 
2769       auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
2770 
2771       LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
2772                         << " with " << AllocSize->getZExtValue()
2773                         << " bytes of shared memory\n");
2774 
2775       // Create a new shared memory buffer of the same size as the allocation
2776       // and replace all the uses of the original allocation with it.
2777       Module *M = CB->getModule();
2778       Type *Int8Ty = Type::getInt8Ty(M->getContext());
2779       Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
2780       auto *SharedMem = new GlobalVariable(
2781           *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
2782           UndefValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
2783           GlobalValue::NotThreadLocal,
2784           static_cast<unsigned>(AddressSpace::Shared));
2785       auto *NewBuffer =
2786           ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
2787 
2788       auto Remark = [&](OptimizationRemark OR) {
2789         return OR << "Replaced globalized variable with "
2790                   << ore::NV("SharedMemory", AllocSize->getZExtValue())
2791                   << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
2792                   << "of shared memory.";
2793       };
2794       A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
2795 
2796       MaybeAlign Alignment = CB->getRetAlign();
2797       assert(Alignment &&
2798              "HeapToShared on allocation without alignment attribute");
2799       SharedMem->setAlignment(MaybeAlign(Alignment));
2800 
2801       A.changeValueAfterManifest(*CB, *NewBuffer);
2802       A.deleteAfterManifest(*CB);
2803       A.deleteAfterManifest(*FreeCalls.front());
2804 
2805       NumBytesMovedToSharedMemory += AllocSize->getZExtValue();
2806       Changed = ChangeStatus::CHANGED;
2807     }
2808 
2809     return Changed;
2810   }
2811 
2812   ChangeStatus updateImpl(Attributor &A) override {
2813     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2814     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2815     Function *F = getAnchorScope();
2816 
2817     auto NumMallocCalls = MallocCalls.size();
2818 
2819     // Only consider malloc calls executed by a single thread with a constant size.
2820     for (User *U : RFI.Declaration->users()) {
2821       const auto &ED = A.getAAFor<AAExecutionDomain>(
2822           *this, IRPosition::function(*F), DepClassTy::REQUIRED);
2823       if (CallBase *CB = dyn_cast<CallBase>(U))
2824         if (!isa<ConstantInt>(CB->getArgOperand(0)) ||
2825             !ED.isExecutedByInitialThreadOnly(*CB))
2826           MallocCalls.remove(CB);
2827     }
2828 
2829     findPotentialRemovedFreeCalls(A);
2830 
2831     if (NumMallocCalls != MallocCalls.size())
2832       return ChangeStatus::CHANGED;
2833 
2834     return ChangeStatus::UNCHANGED;
2835   }
2836 
2837   /// Collection of all malloc calls in a function.
2838   SmallSetVector<CallBase *, 4> MallocCalls;
2839   /// Collection of potentially removed free calls in a function.
2840   SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
2841 };
2842 
2843 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
2844   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
2845   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2846 
2847   /// Statistics are tracked as part of manifest for now.
2848   void trackStatistics() const override {}
2849 
2850   /// See AbstractAttribute::getAsStr()
2851   const std::string getAsStr() const override {
2852     if (!isValidState())
2853       return "<invalid>";
2854     return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
2855                                                             : "generic") +
2856            std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
2857                                                                : "") +
2858            std::string(" #PRs: ") +
2859            (ReachedKnownParallelRegions.isValidState()
2860                 ? std::to_string(ReachedKnownParallelRegions.size())
2861                 : "<invalid>") +
2862            ", #Unknown PRs: " +
2863            (ReachedUnknownParallelRegions.isValidState()
2864                 ? std::to_string(ReachedUnknownParallelRegions.size())
2865                 : "<invalid>") +
2866            ", #Reaching Kernels: " +
2867            (ReachingKernelEntries.isValidState()
2868                 ? std::to_string(ReachingKernelEntries.size())
2869                 : "<invalid>");
2870   }
2871 
2872   /// Create an abstract attribute view for the position \p IRP.
2873   static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
2874 
2875   /// See AbstractAttribute::getName()
2876   const std::string getName() const override { return "AAKernelInfo"; }
2877 
2878   /// See AbstractAttribute::getIdAddr()
2879   const char *getIdAddr() const override { return &ID; }
2880 
2881   /// This function should return true if the type of the \p AA is AAKernelInfo
2882   static bool classof(const AbstractAttribute *AA) {
2883     return (AA->getIdAddr() == &ID);
2884   }
2885 
2886   static const char ID;
2887 };
2888 
2889 /// The function kernel info abstract attribute, basically, what can we say
2890 /// about a function with regard to the KernelInfoState.
2891 struct AAKernelInfoFunction : AAKernelInfo {
2892   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
2893       : AAKernelInfo(IRP, A) {}
2894 
2895   SmallPtrSet<Instruction *, 4> GuardedInstructions;
2896 
2897   SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
2898     return GuardedInstructions;
2899   }
2900 
2901   /// See AbstractAttribute::initialize(...).
2902   void initialize(Attributor &A) override {
2903     // This is a high-level transform that might change the constant arguments
2904     // of the init and deinit calls. We need to tell the Attributor about this
2905     // to avoid other parts using the current constant value for simplification.
2906     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2907 
2908     Function *Fn = getAnchorScope();
2909     if (!OMPInfoCache.Kernels.count(Fn))
2910       return;
2911 
2912     // Add itself to the reaching kernel and set IsKernelEntry.
2913     ReachingKernelEntries.insert(Fn);
2914     IsKernelEntry = true;
2915 
2916     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
2917         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2918     OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
2919         OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
2920 
2921     // For kernels we perform more initialization work; first we find the init
2922     // and deinit calls.
2923     auto StoreCallBase = [](Use &U,
2924                             OMPInformationCache::RuntimeFunctionInfo &RFI,
2925                             CallBase *&Storage) {
2926       CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
2927       assert(CB &&
2928              "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
2929       assert(!Storage &&
2930              "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
2931       Storage = CB;
2932       return false;
2933     };
2934     InitRFI.foreachUse(
2935         [&](Use &U, Function &) {
2936           StoreCallBase(U, InitRFI, KernelInitCB);
2937           return false;
2938         },
2939         Fn);
2940     DeinitRFI.foreachUse(
2941         [&](Use &U, Function &) {
2942           StoreCallBase(U, DeinitRFI, KernelDeinitCB);
2943           return false;
2944         },
2945         Fn);
2946 
2947     // Ignore kernels without initializers such as global constructors.
2948     if (!KernelInitCB || !KernelDeinitCB) {
2949       indicateOptimisticFixpoint();
2950       return;
2951     }
2952 
2953     // For kernels we might need to initialize/finalize the IsSPMD state and
2954     // we need to register a simplification callback so that the Attributor
2955     // knows the constant arguments to __kmpc_target_init and
2956     // __kmpc_target_deinit might actually change.
2957 
2958     Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
2959         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2960             bool &UsedAssumedInformation) -> Optional<Value *> {
2961       // IRP represents the "use generic state machine" argument of an
2962       // __kmpc_target_init call. We will answer this one with the internal
2963       // state. As long as we are not in an invalid state, we will create a
2964       // custom state machine so the value should be a `i1 false`. If we are
2965       // in an invalid state, we won't change the value that is in the IR.
2966       if (!ReachedKnownParallelRegions.isValidState())
2967         return nullptr;
2968       // If we have disabled state machine rewrites, don't make a custom one.
2969       if (DisableOpenMPOptStateMachineRewrite)
2970         return nullptr;
2971       if (AA)
2972         A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2973       UsedAssumedInformation = !isAtFixpoint();
2974       auto *FalseVal =
2975           ConstantInt::getBool(IRP.getAnchorValue().getContext(), false);
2976       return FalseVal;
2977     };
2978 
2979     Attributor::SimplifictionCallbackTy ModeSimplifyCB =
2980         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2981             bool &UsedAssumedInformation) -> Optional<Value *> {
2982       // IRP represents the execution mode argument of an __kmpc_target_init
2983       // or __kmpc_target_deinit call. We will answer this one with the
2984       // internal state of the SPMDCompatibilityTracker, i.e., SPMD if that is
2985       // assumed and generic otherwise.
2986       if (!SPMDCompatibilityTracker.isValidState())
2987         return nullptr;
2988       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2989         if (AA)
2990           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2991         UsedAssumedInformation = true;
2992       } else {
2993         UsedAssumedInformation = false;
2994       }
2995       auto *Val = ConstantInt::getSigned(
2996           IntegerType::getInt8Ty(IRP.getAnchorValue().getContext()),
2997           SPMDCompatibilityTracker.isAssumed() ? OMP_TGT_EXEC_MODE_SPMD
2998                                                : OMP_TGT_EXEC_MODE_GENERIC);
2999       return Val;
3000     };
3001 
3002     Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
3003         [&](const IRPosition &IRP, const AbstractAttribute *AA,
3004             bool &UsedAssumedInformation) -> Optional<Value *> {
3005       // IRP represents the "RequiresFullRuntime" argument of an
3006       // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
3007       // one with the internal state of the SPMDCompatibilityTracker, so if
3008       // generic then true, if SPMD then false.
3009       if (!SPMDCompatibilityTracker.isValidState())
3010         return nullptr;
3011       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3012         if (AA)
3013           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3014         UsedAssumedInformation = true;
3015       } else {
3016         UsedAssumedInformation = false;
3017       }
3018       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
3019                                        !SPMDCompatibilityTracker.isAssumed());
3020       return Val;
3021     };
3022 
3023     constexpr const int InitModeArgNo = 1;
3024     constexpr const int DeinitModeArgNo = 1;
3025     constexpr const int InitUseStateMachineArgNo = 2;
3026     constexpr const int InitRequiresFullRuntimeArgNo = 3;
3027     constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
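    // For reference, the argument indices above correspond (roughly, per the
    // device runtime interface this pass targets) to signatures of the form:
    //   int32_t __kmpc_target_init(ident_t *Ident, int8_t Mode,
    //                              bool UseGenericStateMachine,
    //                              bool RequiresFullRuntime);
    //   void __kmpc_target_deinit(ident_t *Ident, int8_t Mode,
    //                             bool RequiresFullRuntime);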
3028     A.registerSimplificationCallback(
3029         IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
3030         StateMachineSimplifyCB);
3031     A.registerSimplificationCallback(
3032         IRPosition::callsite_argument(*KernelInitCB, InitModeArgNo),
3033         ModeSimplifyCB);
3034     A.registerSimplificationCallback(
3035         IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo),
3036         ModeSimplifyCB);
3037     A.registerSimplificationCallback(
3038         IRPosition::callsite_argument(*KernelInitCB,
3039                                       InitRequiresFullRuntimeArgNo),
3040         IsGenericModeSimplifyCB);
3041     A.registerSimplificationCallback(
3042         IRPosition::callsite_argument(*KernelDeinitCB,
3043                                       DeinitRequiresFullRuntimeArgNo),
3044         IsGenericModeSimplifyCB);
3045 
3046     // Check if we know we are in SPMD-mode already.
3047     ConstantInt *ModeArg =
3048         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3049     if (ModeArg && (ModeArg->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
3050       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3051     // This is a generic region but SPMDization is disabled so stop tracking.
3052     else if (DisableOpenMPOptSPMDization)
3053       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3054   }
3055 
3056   /// Sanitize the string \p S such that it is a suitable global symbol name.
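  /// E.g., "kernel$var 1" becomes "kernel.var.1"; every character that is not
  /// alphanumeric or '_' is replaced by '.'.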
3057   static std::string sanitizeForGlobalName(std::string S) {
3058     std::replace_if(
3059         S.begin(), S.end(),
3060         [](const char C) {
3061           return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3062                    (C >= '0' && C <= '9') || C == '_');
3063         },
3064         '.');
3065     return S;
3066   }
3067 
3068   /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3069   /// finished now.
3070   ChangeStatus manifest(Attributor &A) override {
3071     // If we are not looking at a kernel with __kmpc_target_init and
3072     // __kmpc_target_deinit call we cannot actually manifest the information.
3073     if (!KernelInitCB || !KernelDeinitCB)
3074       return ChangeStatus::UNCHANGED;
3075 
3076     // If we can, we change the execution mode to SPMD-mode; otherwise we build
3077     // a custom state machine.
3078     ChangeStatus Changed = ChangeStatus::UNCHANGED;
3079     if (!changeToSPMDMode(A, Changed))
3080       return buildCustomStateMachine(A);
3081 
3082     return Changed;
3083   }
3084 
3085   bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
3086     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3087 
3088     if (!SPMDCompatibilityTracker.isAssumed()) {
3089       for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
3090         if (!NonCompatibleI)
3091           continue;
3092 
3093         // Skip diagnostics on calls to known OpenMP runtime functions for now.
3094         if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
3095           if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
3096             continue;
3097 
3098         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3099           ORA << "Value has potential side effects preventing SPMD-mode "
3100                  "execution";
3101           if (isa<CallBase>(NonCompatibleI)) {
3102             ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
3103                    "the called function to override";
3104           }
3105           return ORA << ".";
3106         };
3107         A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
3108                                                  Remark);
3109 
3110         LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
3111                           << *NonCompatibleI << "\n");
3112       }
3113 
3114       return false;
3115     }
3116 
3117     // Check if the kernel is already in SPMD mode, if so, return success.
3118     Function *Kernel = getAnchorScope();
3119     GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
3120         (Kernel->getName() + "_exec_mode").str());
3121     assert(ExecMode && "Kernel without exec mode?");
3122     assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!");
3123 
3124     // Read the current exec mode; it is adjusted to Generic-SPMD further below.
3125     assert(isa<ConstantInt>(ExecMode->getInitializer()) &&
3126            "ExecMode is not an integer!");
3127     const int8_t ExecModeVal =
3128         cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue();
3129     if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
3130       return true;
3131 
3132     // We will now unconditionally modify the IR, indicate a change.
3133     Changed = ChangeStatus::CHANGED;
3134 
3135     auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3136                                    Instruction *RegionEndI) {
3137       LoopInfo *LI = nullptr;
3138       DominatorTree *DT = nullptr;
3139       MemorySSAUpdater *MSU = nullptr;
3140       using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3141 
3142       BasicBlock *ParentBB = RegionStartI->getParent();
3143       Function *Fn = ParentBB->getParent();
3144       Module &M = *Fn->getParent();
3145 
3146       // Create all the blocks and logic.
3147       // ParentBB:
3148       //    goto RegionCheckTidBB
3149       // RegionCheckTidBB:
3150       //    Tid = __kmpc_get_hardware_thread_id_in_block()
3151       //    if (Tid != 0)
3152       //        goto RegionBarrierBB
3153       // RegionStartBB:
3154       //    <execute instructions guarded>
3155       //    goto RegionEndBB
3156       // RegionEndBB:
3157       //    <store escaping values to shared mem>
3158       //    goto RegionBarrierBB
3159       // RegionBarrierBB:
3160       //    __kmpc_barrier_simple_spmd()
3161       //    <load escaping values from shared mem>
3162       //    __kmpc_barrier_simple_spmd()
3163       //    // the second barrier is omitted if there are no escaping values.
3164       //    goto RegionExitBB
3165       // RegionExitBB:
3166       //    <execute rest of instructions>
3167 
3168       BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3169                                            DT, LI, MSU, "region.guarded.end");
3170       BasicBlock *RegionBarrierBB =
3171           SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3172                      MSU, "region.barrier");
3173       BasicBlock *RegionExitBB =
3174           SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3175                      DT, LI, MSU, "region.exit");
3176       BasicBlock *RegionStartBB =
3177           SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3178 
3179       assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3180              "Expected a different CFG");
3181 
3182       BasicBlock *RegionCheckTidBB = SplitBlock(
3183           ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3184 
3185       // Register basic blocks with the Attributor.
3186       A.registerManifestAddedBasicBlock(*RegionEndBB);
3187       A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3188       A.registerManifestAddedBasicBlock(*RegionExitBB);
3189       A.registerManifestAddedBasicBlock(*RegionStartBB);
3190       A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3191 
3192       bool HasBroadcastValues = false;
3193       // Find escaping outputs from the guarded region to outside users and
3194       // broadcast their values to them.
3195       for (Instruction &I : *RegionStartBB) {
3196         SmallPtrSet<Instruction *, 4> OutsideUsers;
3197         for (User *Usr : I.users()) {
3198           Instruction &UsrI = *cast<Instruction>(Usr);
3199           if (UsrI.getParent() != RegionStartBB)
3200             OutsideUsers.insert(&UsrI);
3201         }
3202 
3203         if (OutsideUsers.empty())
3204           continue;
3205 
3206         HasBroadcastValues = true;
3207 
3208         // Emit a global variable in shared memory to store the broadcasted
3209         // value.
3210         auto *SharedMem = new GlobalVariable(
3211             M, I.getType(), /* IsConstant */ false,
3212             GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
3213             sanitizeForGlobalName(
3214                 (I.getName() + ".guarded.output.alloc").str()),
3215             nullptr, GlobalValue::NotThreadLocal,
3216             static_cast<unsigned>(AddressSpace::Shared));
3217 
3218         // Emit a store instruction to update the value.
3219         new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
3220 
3221         LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
3222                                        I.getName() + ".guarded.output.load",
3223                                        RegionBarrierBB->getTerminator());
3224 
3225         // Replace all outside uses of the output value with the reloaded value.
3226         for (Instruction *UsrI : OutsideUsers)
3227           UsrI->replaceUsesOfWith(&I, LoadI);
3228       }
3229 
3230       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3231 
3232       // Go to tid check BB in ParentBB.
3233       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
3234       ParentBB->getTerminator()->eraseFromParent();
3235       OpenMPIRBuilder::LocationDescription Loc(
3236           InsertPointTy(ParentBB, ParentBB->end()), DL);
3237       OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3238       uint32_t SrcLocStrSize;
3239       auto *SrcLocStr =
3240           OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3241       Value *Ident =
3242           OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3243       BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
3244 
3245       // Add check for Tid in RegionCheckTidBB
3246       RegionCheckTidBB->getTerminator()->eraseFromParent();
3247       OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
3248           InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3249       OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3250       FunctionCallee HardwareTidFn =
3251           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3252               M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3253       CallInst *Tid =
3254           OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3255       Tid->setDebugLoc(DL);
3256       OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
3257       Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3258       OMPInfoCache.OMPBuilder.Builder
3259           .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3260           ->setDebugLoc(DL);
3261 
3262       // First barrier for synchronization, ensures main thread has updated
3263       // values.
3264       FunctionCallee BarrierFn =
3265           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3266               M, OMPRTL___kmpc_barrier_simple_spmd);
3267       OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3268           RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
3269       CallInst *Barrier =
3270           OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
3271       Barrier->setDebugLoc(DL);
3272       OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3273 
3274       // Second barrier ensures workers have read broadcast values.
3275       if (HasBroadcastValues) {
3276         CallInst *Barrier = CallInst::Create(BarrierFn, {Ident, Tid}, "",
3277                                              RegionBarrierBB->getTerminator());
3278         Barrier->setDebugLoc(DL);
3279         OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3280       }
3281     };
3282 
3283     auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3284     SmallPtrSet<BasicBlock *, 8> Visited;
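    // Sink guarded, user-free side-effect instructions towards the next guarded
    // side effect below them in their block so that consecutive guarded
    // instructions end up adjacent and can share a single guarded region.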
3285     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3286       BasicBlock *BB = GuardedI->getParent();
3287       if (!Visited.insert(BB).second)
3288         continue;
3289 
3290       SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
3291       Instruction *LastEffect = nullptr;
3292       BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
3293       while (++IP != IPEnd) {
3294         if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
3295           continue;
3296         Instruction *I = &*IP;
3297         if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
3298           continue;
3299         if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
3300           LastEffect = nullptr;
3301           continue;
3302         }
3303         if (LastEffect)
3304           Reorders.push_back({I, LastEffect});
3305         LastEffect = &*IP;
3306       }
3307       for (auto &Reorder : Reorders)
3308         Reorder.first->moveBefore(Reorder.second);
3309     }
3310 
3311     SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
3312 
3313     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3314       BasicBlock *BB = GuardedI->getParent();
3315       auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3316           IRPosition::function(*GuardedI->getFunction()), nullptr,
3317           DepClassTy::NONE);
3318       assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3319       auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3320       // Continue if instruction is already guarded.
3321       if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3322         continue;
3323 
3324       Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3325       for (Instruction &I : *BB) {
3326         // If instruction I needs to be guarded update the guarded region
3327         // bounds.
3328         if (SPMDCompatibilityTracker.contains(&I)) {
3329           CalleeAAFunction.getGuardedInstructions().insert(&I);
3330           if (GuardedRegionStart)
3331             GuardedRegionEnd = &I;
3332           else
3333             GuardedRegionStart = GuardedRegionEnd = &I;
3334 
3335           continue;
3336         }
3337 
3338         // Instruction I does not need guarding; store
3339         // any region found so far and reset the bounds.
3340         if (GuardedRegionStart) {
3341           GuardedRegions.push_back(
3342               std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3343           GuardedRegionStart = nullptr;
3344           GuardedRegionEnd = nullptr;
3345         }
3346       }
3347     }
3348 
3349     for (auto &GR : GuardedRegions)
3350       CreateGuardedRegion(GR.first, GR.second);
3351 
3352     // Adjust the global exec mode flag that tells the runtime what mode this
3353     // kernel is executed in.
3354     assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
3355            "Initially non-SPMD kernel has SPMD exec mode!");
3356     ExecMode->setInitializer(
3357         ConstantInt::get(ExecMode->getInitializer()->getType(),
3358                          ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
3359 
3360     // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
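    // Roughly, a call of the form
    //   __kmpc_target_init(%ident, /*Mode=*/GENERIC, /*UseStateMachine=*/true,
    //                      /*RequiresFullRuntime=*/true)
    // becomes
    //   __kmpc_target_init(%ident, /*Mode=*/SPMD, /*UseStateMachine=*/false,
    //                      /*RequiresFullRuntime=*/false)
    // and the deinit call is adjusted accordingly.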
3361     const int InitModeArgNo = 1;
3362     const int DeinitModeArgNo = 1;
3363     const int InitUseStateMachineArgNo = 2;
3364     const int InitRequiresFullRuntimeArgNo = 3;
3365     const int DeinitRequiresFullRuntimeArgNo = 2;
3366 
3367     auto &Ctx = getAnchorValue().getContext();
3368     A.changeUseAfterManifest(
3369         KernelInitCB->getArgOperandUse(InitModeArgNo),
3370         *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
3371                                 OMP_TGT_EXEC_MODE_SPMD));
3372     A.changeUseAfterManifest(
3373         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
3374         *ConstantInt::getBool(Ctx, false));
3375     A.changeUseAfterManifest(
3376         KernelDeinitCB->getArgOperandUse(DeinitModeArgNo),
3377         *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
3378                                 OMP_TGT_EXEC_MODE_SPMD));
3379     A.changeUseAfterManifest(
3380         KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
3381         *ConstantInt::getBool(Ctx, false));
3382     A.changeUseAfterManifest(
3383         KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
3384         *ConstantInt::getBool(Ctx, false));
3385 
3386     ++NumOpenMPTargetRegionKernelsSPMD;
3387 
3388     auto Remark = [&](OptimizationRemark OR) {
3389       return OR << "Transformed generic-mode kernel to SPMD-mode.";
3390     };
3391     A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
3392     return true;
3393   };
3394 
3395   ChangeStatus buildCustomStateMachine(Attributor &A) {
3396     // If we have disabled state machine rewrites, don't make a custom one
3397     if (DisableOpenMPOptStateMachineRewrite)
3398       return ChangeStatus::UNCHANGED;
3399 
3400     // Don't rewrite the state machine if we are not in a valid state.
3401     if (!ReachedKnownParallelRegions.isValidState())
3402       return ChangeStatus::UNCHANGED;
3403 
3404     const int InitModeArgNo = 1;
3405     const int InitUseStateMachineArgNo = 2;
3406 
3407     // Check if the current configuration is non-SPMD with a generic state machine.
3408     // If we already have SPMD mode or a custom state machine we do not need to
3409     // go any further. If it is anything but a constant something is weird and
3410     // we give up.
3411     ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
3412         KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
3413     ConstantInt *Mode =
3414         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3415 
3416     // If we are stuck with generic mode, try to create a custom device (=GPU)
3417     // state machine which is specialized for the parallel regions that are
3418     // reachable by the kernel.
3419     if (!UseStateMachine || UseStateMachine->isZero() || !Mode ||
3420         (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
3421       return ChangeStatus::UNCHANGED;
3422 
3423     // If not SPMD mode, indicate we use a custom state machine now.
3424     auto &Ctx = getAnchorValue().getContext();
3425     auto *FalseVal = ConstantInt::getBool(Ctx, false);
3426     A.changeUseAfterManifest(
3427         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
3428 
3429     // If we don't actually need a state machine we are done here. This can
3430     // happen if there simply are no parallel regions. In the resulting kernel
3431     // all worker threads will simply exit right away, leaving the main thread
3432     // to do the work alone.
3433     if (!mayContainParallelRegion()) {
3434       ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
3435 
3436       auto Remark = [&](OptimizationRemark OR) {
3437         return OR << "Removing unused state machine from generic-mode kernel.";
3438       };
3439       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
3440 
3441       return ChangeStatus::CHANGED;
3442     }
3443 
3444     // Keep track in the statistics of our new shiny custom state machine.
3445     if (ReachedUnknownParallelRegions.empty()) {
3446       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
3447 
3448       auto Remark = [&](OptimizationRemark OR) {
3449         return OR << "Rewriting generic-mode kernel with a customized state "
3450                      "machine.";
3451       };
3452       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
3453     } else {
3454       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
3455 
3456       auto Remark = [&](OptimizationRemarkAnalysis OR) {
3457         return OR << "Generic-mode kernel is executed with a customized state "
3458                      "machine that requires a fallback.";
3459       };
3460       A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
3461 
3462       // Tell the user why we ended up with a fallback.
3463       for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
3464         if (!UnknownParallelRegionCB)
3465           continue;
3466         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3467           return ORA << "Call may contain unknown parallel regions. Use "
3468                      << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
3469                         "override.";
3470         };
3471         A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
3472                                                  "OMP133", Remark);
3473       }
3474     }
3475 
3476     // Create all the blocks:
3477     //
3478     //                       InitCB = __kmpc_target_init(...)
3479     //                       BlockHwSize =
3480     //                         __kmpc_get_hardware_num_threads_in_block();
3481     //                       WarpSize = __kmpc_get_warp_size();
3482     //                       BlockSize = BlockHwSize - WarpSize;
3483     //                       if (InitCB >= BlockSize) return;
3484     // IsWorkerCheckBB:      bool IsWorker = InitCB >= 0;
3485     //                       if (IsWorker) {
3486     // SMBeginBB:               __kmpc_barrier_simple_generic(...);
3487     //                         void *WorkFn;
3488     //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
3489     //                         if (!WorkFn) return;
3490     // SMIsActiveCheckBB:       if (Active) {
3491     // SMIfCascadeCurrentBB:      if      (WorkFn == <ParFn0>)
3492     //                              ParFn0(...);
3493     // SMIfCascadeCurrentBB:      else if (WorkFn == <ParFn1>)
3494     //                              ParFn1(...);
3495     //                            ...
3496     // SMIfCascadeCurrentBB:      else
3497     //                              ((WorkFnTy*)WorkFn)(...);
3498     // SMEndParallelBB:           __kmpc_kernel_end_parallel(...);
3499     //                          }
3500     // SMDoneBB:                __kmpc_barrier_simple_generic(...);
3501     //                          goto SMBeginBB;
3502     //                       }
3503     // UserCodeEntryBB:      // user code
3504     //                       __kmpc_target_deinit(...)
3505     //
3506     Function *Kernel = getAssociatedFunction();
3507     assert(Kernel && "Expected an associated function!");
3508 
3509     BasicBlock *InitBB = KernelInitCB->getParent();
3510     BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3511         KernelInitCB->getNextNode(), "thread.user_code.check");
3512     BasicBlock *IsWorkerCheckBB =
3513         BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
3514     BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3515         Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3516     BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3517         Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3518     BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3519         Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3520     BasicBlock *StateMachineIfCascadeCurrentBB =
3521         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3522                            Kernel, UserCodeEntryBB);
3523     BasicBlock *StateMachineEndParallelBB =
3524         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3525                            Kernel, UserCodeEntryBB);
3526     BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3527         Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
3528     A.registerManifestAddedBasicBlock(*InitBB);
3529     A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
3530     A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
3531     A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
3532     A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
3533     A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
3534     A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
3535     A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
3536     A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
3537 
3538     const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3539     ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3540     InitBB->getTerminator()->eraseFromParent();
3541 
3542     Module &M = *Kernel->getParent();
3543     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3544     FunctionCallee BlockHwSizeFn =
3545         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3546             M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
3547     FunctionCallee WarpSizeFn =
3548         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3549             M, OMPRTL___kmpc_get_warp_size);
3550     CallInst *BlockHwSize =
3551         CallInst::Create(BlockHwSizeFn, "block.hw_size", InitBB);
3552     OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
3553     BlockHwSize->setDebugLoc(DLoc);
3554     CallInst *WarpSize = CallInst::Create(WarpSizeFn, "warp.size", InitBB);
3555     OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
3556     WarpSize->setDebugLoc(DLoc);
3557     Instruction *BlockSize =
3558         BinaryOperator::CreateSub(BlockHwSize, WarpSize, "block.size", InitBB);
3559     BlockSize->setDebugLoc(DLoc);
3560     Instruction *IsMainOrWorker =
3561         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB,
3562                          BlockSize, "thread.is_main_or_worker", InitBB);
3563     IsMainOrWorker->setDebugLoc(DLoc);
3564     BranchInst::Create(IsWorkerCheckBB, StateMachineFinishedBB, IsMainOrWorker,
3565                        InitBB);
3566 
3567     Instruction *IsWorker =
3568         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3569                          ConstantInt::get(KernelInitCB->getType(), -1),
3570                          "thread.is_worker", IsWorkerCheckBB);
3571     IsWorker->setDebugLoc(DLoc);
3572     BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker,
3573                        IsWorkerCheckBB);
3574 
3575     // Create local storage for the work function pointer.
3576     const DataLayout &DL = M.getDataLayout();
3577     Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3578     Instruction *WorkFnAI =
3579         new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
3580                        "worker.work_fn.addr", &Kernel->getEntryBlock().front());
3581     WorkFnAI->setDebugLoc(DLoc);
3582 
3583     OMPInfoCache.OMPBuilder.updateToLocation(
3584         OpenMPIRBuilder::LocationDescription(
3585             IRBuilder<>::InsertPoint(StateMachineBeginBB,
3586                                      StateMachineBeginBB->end()),
3587             DLoc));
3588 
3589     Value *Ident = KernelInitCB->getArgOperand(0);
3590     Value *GTid = KernelInitCB;
3591 
3592     FunctionCallee BarrierFn =
3593         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3594             M, OMPRTL___kmpc_barrier_simple_generic);
3595     CallInst *Barrier =
3596         CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
3597     OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3598     Barrier->setDebugLoc(DLoc);
3599 
3600     if (WorkFnAI->getType()->getPointerAddressSpace() !=
3601         (unsigned int)AddressSpace::Generic) {
3602       WorkFnAI = new AddrSpaceCastInst(
3603           WorkFnAI,
3604           PointerType::getWithSamePointeeType(
3605               cast<PointerType>(WorkFnAI->getType()),
3606               (unsigned int)AddressSpace::Generic),
3607           WorkFnAI->getName() + ".generic", StateMachineBeginBB);
3608       WorkFnAI->setDebugLoc(DLoc);
3609     }
3610 
3611     FunctionCallee KernelParallelFn =
3612         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3613             M, OMPRTL___kmpc_kernel_parallel);
3614     CallInst *IsActiveWorker = CallInst::Create(
3615         KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3616     OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
3617     IsActiveWorker->setDebugLoc(DLoc);
3618     Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3619                                        StateMachineBeginBB);
3620     WorkFn->setDebugLoc(DLoc);
3621 
3622     FunctionType *ParallelRegionFnTy = FunctionType::get(
3623         Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3624         false);
3625     Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3626         WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3627         StateMachineBeginBB);
3628 
3629     Instruction *IsDone =
3630         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3631                          Constant::getNullValue(VoidPtrTy), "worker.is_done",
3632                          StateMachineBeginBB);
3633     IsDone->setDebugLoc(DLoc);
3634     BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3635                        IsDone, StateMachineBeginBB)
3636         ->setDebugLoc(DLoc);
3637 
3638     BranchInst::Create(StateMachineIfCascadeCurrentBB,
3639                        StateMachineDoneBarrierBB, IsActiveWorker,
3640                        StateMachineIsActiveCheckBB)
3641         ->setDebugLoc(DLoc);
3642 
3643     Value *ZeroArg =
3644         Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
3645 
3646     // Now that we have most of the CFG skeleton it is time for the if-cascade
3647     // that checks the function pointer we got from the runtime against the
3648     // parallel regions we expect, if there are any.
3649     for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
3650       auto *ParallelRegion = ReachedKnownParallelRegions[I];
3651       BasicBlock *PRExecuteBB = BasicBlock::Create(
3652           Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3653           StateMachineEndParallelBB);
3654       CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3655           ->setDebugLoc(DLoc);
3656       BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3657           ->setDebugLoc(DLoc);
3658 
3659       BasicBlock *PRNextBB =
3660           BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3661                              Kernel, StateMachineEndParallelBB);
3662 
3663       // Check if we need to compare the pointer at all or if we can just
3664       // call the parallel region function.
3665       Value *IsPR;
3666       if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
3667         Instruction *CmpI = ICmpInst::Create(
3668             ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3669             "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
3670         CmpI->setDebugLoc(DLoc);
3671         IsPR = CmpI;
3672       } else {
3673         IsPR = ConstantInt::getTrue(Ctx);
3674       }
3675 
3676       BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
3677                          StateMachineIfCascadeCurrentBB)
3678           ->setDebugLoc(DLoc);
3679       StateMachineIfCascadeCurrentBB = PRNextBB;
3680     }
3681 
3682     // At the end of the if-cascade we place the indirect function pointer call
3683     // in case we might need it, that is if there can be parallel regions we
3684     // have not handled in the if-cascade above.
3685     if (!ReachedUnknownParallelRegions.empty()) {
3686       StateMachineIfCascadeCurrentBB->setName(
3687           "worker_state_machine.parallel_region.fallback.execute");
3688       CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
3689                        StateMachineIfCascadeCurrentBB)
3690           ->setDebugLoc(DLoc);
3691     }
3692     BranchInst::Create(StateMachineEndParallelBB,
3693                        StateMachineIfCascadeCurrentBB)
3694         ->setDebugLoc(DLoc);
3695 
3696     FunctionCallee EndParallelFn =
3697         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3698             M, OMPRTL___kmpc_kernel_end_parallel);
3699     CallInst *EndParallel =
3700         CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
3701     OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
3702     EndParallel->setDebugLoc(DLoc);
3703     BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
3704         ->setDebugLoc(DLoc);
3705 
3706     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
3707         ->setDebugLoc(DLoc);
3708     BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
3709         ->setDebugLoc(DLoc);
3710 
3711     return ChangeStatus::CHANGED;
3712   }
3713 
3714   /// Fixpoint iteration update function. Will be called every time a dependence
3715   /// changed its state (and in the beginning).
3716   ChangeStatus updateImpl(Attributor &A) override {
3717     KernelInfoState StateBefore = getState();
3718 
3719     // Callback to check a read/write instruction.
3720     auto CheckRWInst = [&](Instruction &I) {
3721       // We handle calls later.
3722       if (isa<CallBase>(I))
3723         return true;
3724       // We only care about write effects.
3725       if (!I.mayWriteToMemory())
3726         return true;
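      // Stores whose underlying objects are all thread-private (allocas, or
      // allocations that AAHeapToStack will turn into allocas) cannot be
      // observed by other threads and therefore do not need guarding.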
3727       if (auto *SI = dyn_cast<StoreInst>(&I)) {
3728         SmallVector<const Value *> Objects;
3729         getUnderlyingObjects(SI->getPointerOperand(), Objects);
3730         if (llvm::all_of(Objects,
3731                          [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
3732           return true;
3733         // Check for AAHeapToStack moved objects which must not be guarded.
3734         auto &HS = A.getAAFor<AAHeapToStack>(
3735             *this, IRPosition::function(*I.getFunction()),
3736             DepClassTy::OPTIONAL);
3737         if (llvm::all_of(Objects, [&HS](const Value *Obj) {
3738               auto *CB = dyn_cast<CallBase>(Obj);
3739               if (!CB)
3740                 return false;
3741               return HS.isAssumedHeapToStack(*CB);
3742             })) {
3743           return true;
3744         }
3745       }
3746 
3747       // Insert instruction that needs guarding.
3748       SPMDCompatibilityTracker.insert(&I);
3749       return true;
3750     };
3751 
3752     bool UsedAssumedInformationInCheckRWInst = false;
3753     if (!SPMDCompatibilityTracker.isAtFixpoint())
3754       if (!A.checkForAllReadWriteInstructions(
3755               CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
3756         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3757 
3758     bool UsedAssumedInformationFromReachingKernels = false;
3759     if (!IsKernelEntry) {
3760       updateParallelLevels(A);
3761 
3762       bool AllReachingKernelsKnown = true;
3763       updateReachingKernelEntries(A, AllReachingKernelsKnown);
3764       UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
3765 
3766       if (!ParallelLevels.isValidState())
3767         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3768       else if (!ReachingKernelEntries.isValidState())
3769         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3770       else if (!SPMDCompatibilityTracker.empty()) {
3771         // Check if all reaching kernels agree on the mode as we can otherwise
3772         // not guard instructions. We might not be sure about the mode, so we
3773         // cannot fix the internal SPMD-ization state either.
3774         int SPMD = 0, Generic = 0;
3775         for (auto *Kernel : ReachingKernelEntries) {
3776           auto &CBAA = A.getAAFor<AAKernelInfo>(
3777               *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
3778           if (CBAA.SPMDCompatibilityTracker.isValidState() &&
3779               CBAA.SPMDCompatibilityTracker.isAssumed())
3780             ++SPMD;
3781           else
3782             ++Generic;
3783           if (!CBAA.SPMDCompatibilityTracker.isAtFixpoint())
3784             UsedAssumedInformationFromReachingKernels = true;
3785         }
3786         if (SPMD != 0 && Generic != 0)
3787           SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3788       }
3789     }
3790 
3791     // Callback to check a call instruction.
3792     bool AllParallelRegionStatesWereFixed = true;
3793     bool AllSPMDStatesWereFixed = true;
3794     auto CheckCallInst = [&](Instruction &I) {
3795       auto &CB = cast<CallBase>(I);
3796       auto &CBAA = A.getAAFor<AAKernelInfo>(
3797           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
3798       getState() ^= CBAA.getState();
3799       AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
3800       AllParallelRegionStatesWereFixed &=
3801           CBAA.ReachedKnownParallelRegions.isAtFixpoint();
3802       AllParallelRegionStatesWereFixed &=
3803           CBAA.ReachedUnknownParallelRegions.isAtFixpoint();
3804       return true;
3805     };
3806 
3807     bool UsedAssumedInformationInCheckCallInst = false;
3808     if (!A.checkForAllCallLikeInstructions(
3809             CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
3810       LLVM_DEBUG(dbgs() << TAG
3811                         << "Failed to visit all call-like instructions!\n";);
3812       return indicatePessimisticFixpoint();
3813     }
3814 
3815     // If we haven't used any assumed information for the reached parallel
3816     // region states we can fix it.
3817     if (!UsedAssumedInformationInCheckCallInst &&
3818         AllParallelRegionStatesWereFixed) {
3819       ReachedKnownParallelRegions.indicateOptimisticFixpoint();
3820       ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
3821     }
3822 
3823     // If we are sure there are no parallel regions in the kernel we do not
3824     // want SPMD mode.
3825     if (IsKernelEntry && ReachedUnknownParallelRegions.isAtFixpoint() &&
3826         ReachedKnownParallelRegions.isAtFixpoint() &&
3827         ReachedUnknownParallelRegions.isValidState() &&
3828         ReachedKnownParallelRegions.isValidState() &&
3829         !mayContainParallelRegion())
3830       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3831 
3832     // If we haven't used any assumed information for the SPMD state we can fix
3833     // it.
3834     if (!UsedAssumedInformationInCheckRWInst &&
3835         !UsedAssumedInformationInCheckCallInst &&
3836         !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
3837       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3838 
3839     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3840                                      : ChangeStatus::CHANGED;
3841   }
3842 
3843 private:
3844   /// Update info regarding reaching kernels.
3845   void updateReachingKernelEntries(Attributor &A,
3846                                    bool &AllReachingKernelsKnown) {
3847     auto PredCallSite = [&](AbstractCallSite ACS) {
3848       Function *Caller = ACS.getInstruction()->getFunction();
3849 
3850       assert(Caller && "Caller is nullptr");
3851 
3852       auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
3853           IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
3854       if (CAA.ReachingKernelEntries.isValidState()) {
3855         ReachingKernelEntries ^= CAA.ReachingKernelEntries;
3856         return true;
3857       }
3858 
3859       // We lost track of the caller of the associated function; any kernel
3860       // could reach it now.
3861       ReachingKernelEntries.indicatePessimisticFixpoint();
3862 
3863       return true;
3864     };
3865 
3866     if (!A.checkForAllCallSites(PredCallSite, *this,
3867                                 true /* RequireAllCallSites */,
3868                                 AllReachingKernelsKnown))
3869       ReachingKernelEntries.indicatePessimisticFixpoint();
3870   }
3871 
3872   /// Update info regarding parallel levels.
3873   void updateParallelLevels(Attributor &A) {
3874     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3875     OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
3876         OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
3877 
3878     auto PredCallSite = [&](AbstractCallSite ACS) {
3879       Function *Caller = ACS.getInstruction()->getFunction();
3880 
3881       assert(Caller && "Caller is nullptr");
3882 
3883       auto &CAA =
3884           A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
3885       if (CAA.ParallelLevels.isValidState()) {
3886         // Any function that is called by `__kmpc_parallel_51` will not be
3887         // folded because the parallel level in that function is updated. Getting
3888         // this right would make the analysis depend on the runtime
3889         // implementation, and any future change to that implementation could
3890         // silently invalidate the result, so we are conservative here.
3891         if (Caller == Parallel51RFI.Declaration) {
3892           ParallelLevels.indicatePessimisticFixpoint();
3893           return true;
3894         }
3895 
3896         ParallelLevels ^= CAA.ParallelLevels;
3897 
3898         return true;
3899       }
3900 
3901       // We lost track of the caller of the associated function; any kernel
3902       // could reach it now.
3903       ParallelLevels.indicatePessimisticFixpoint();
3904 
3905       return true;
3906     };
3907 
3908     bool AllCallSitesKnown = true;
3909     if (!A.checkForAllCallSites(PredCallSite, *this,
3910                                 true /* RequireAllCallSites */,
3911                                 AllCallSitesKnown))
3912       ParallelLevels.indicatePessimisticFixpoint();
3913   }
3914 };
3915 
3916 /// The call site kernel info abstract attribute, basically, what can we say
3917 /// about a call site with regard to the KernelInfoState. For now this simply
3918 /// forwards the information from the callee.
3919 struct AAKernelInfoCallSite : AAKernelInfo {
3920   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
3921       : AAKernelInfo(IRP, A) {}
3922 
3923   /// See AbstractAttribute::initialize(...).
3924   void initialize(Attributor &A) override {
3925     AAKernelInfo::initialize(A);
3926 
3927     CallBase &CB = cast<CallBase>(getAssociatedValue());
3928     Function *Callee = getAssociatedFunction();
3929 
3930     auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>(
3931         *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
3932 
3933     // Check for SPMD-mode assumptions.
3934     if (AssumptionAA.hasAssumption("ompx_spmd_amenable")) {
3935       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3936       indicateOptimisticFixpoint();
3937     }
3938 
3939     // First weed out calls we do not care about, that is readonly/readnone
3940     // calls, intrinsics, and "no_openmp" calls. None of these can reach a
3941     // parallel region or anything else we are looking for.
3942     if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
3943       indicateOptimisticFixpoint();
3944       return;
3945     }
3946 
3947     // Next we check if we know the callee. If it is a known OpenMP function
3948     // we will handle it explicitly in the switch below. If it is not, we
3949     // will use an AAKernelInfo object on the callee to gather information and
3950     // merge that into the current state. The latter happens in updateImpl.
3951     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3952     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3953     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3954       // Unknown callees or declarations are not analyzable; we give up.
3955       if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
3956 
3957         // Unknown callees might contain parallel regions, except if they have
3958         // an appropriate assumption attached.
3959         if (!(AssumptionAA.hasAssumption("omp_no_openmp") ||
3960               AssumptionAA.hasAssumption("omp_no_parallelism")))
3961           ReachedUnknownParallelRegions.insert(&CB);
3962 
3963         // If SPMDCompatibilityTracker is not fixed, we need to give up on the
3964         // idea we can run something unknown in SPMD-mode.
3965         if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3966           SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3967           SPMDCompatibilityTracker.insert(&CB);
3968         }
3969 
3970         // We have updated the state for this unknown call properly; there won't
3971         // be any further change, so we indicate a fixpoint.
3972         indicateOptimisticFixpoint();
3973       }
3974       // If the callee is known and can be used in IPO, we will update the state
3975       // based on the callee state in updateImpl.
3976       return;
3977     }
3978 
3979     const unsigned int WrapperFunctionArgNo = 6;
3980     RuntimeFunction RF = It->getSecond();
3981     switch (RF) {
3982     // All the functions we know are compatible with SPMD mode.
3983     case OMPRTL___kmpc_is_spmd_exec_mode:
3984     case OMPRTL___kmpc_distribute_static_fini:
3985     case OMPRTL___kmpc_for_static_fini:
3986     case OMPRTL___kmpc_global_thread_num:
3987     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
3988     case OMPRTL___kmpc_get_hardware_num_blocks:
3989     case OMPRTL___kmpc_single:
3990     case OMPRTL___kmpc_end_single:
3991     case OMPRTL___kmpc_master:
3992     case OMPRTL___kmpc_end_master:
3993     case OMPRTL___kmpc_barrier:
3994     case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
3995     case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
3996     case OMPRTL___kmpc_nvptx_end_reduce_nowait:
3997       break;
3998     case OMPRTL___kmpc_distribute_static_init_4:
3999     case OMPRTL___kmpc_distribute_static_init_4u:
4000     case OMPRTL___kmpc_distribute_static_init_8:
4001     case OMPRTL___kmpc_distribute_static_init_8u:
4002     case OMPRTL___kmpc_for_static_init_4:
4003     case OMPRTL___kmpc_for_static_init_4u:
4004     case OMPRTL___kmpc_for_static_init_8:
4005     case OMPRTL___kmpc_for_static_init_8u: {
4006       // Check the schedule and allow static schedule in SPMD mode.
4007       unsigned ScheduleArgOpNo = 2;
4008       auto *ScheduleTypeCI =
4009           dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4010       unsigned ScheduleTypeVal =
4011           ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4012       switch (OMPScheduleType(ScheduleTypeVal)) {
4013       case OMPScheduleType::Static:
4014       case OMPScheduleType::StaticChunked:
4015       case OMPScheduleType::Distribute:
4016       case OMPScheduleType::DistributeChunked:
4017         break;
4018       default:
4019         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4020         SPMDCompatibilityTracker.insert(&CB);
4021         break;
4022       };
4023     } break;
4024     case OMPRTL___kmpc_target_init:
4025       KernelInitCB = &CB;
4026       break;
4027     case OMPRTL___kmpc_target_deinit:
4028       KernelDeinitCB = &CB;
4029       break;
4030     case OMPRTL___kmpc_parallel_51:
4031       if (auto *ParallelRegion = dyn_cast<Function>(
4032               CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
4033         ReachedKnownParallelRegions.insert(ParallelRegion);
4034         break;
4035       }
4036       // The condition above should usually get the parallel region function
4037       // pointer and record it. On the off chance it doesn't, we assume the
4038       // worst.
4039       ReachedUnknownParallelRegions.insert(&CB);
4040       break;
4041     case OMPRTL___kmpc_omp_task:
4042       // We do not look into tasks right now, just give up.
4043       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4044       SPMDCompatibilityTracker.insert(&CB);
4045       ReachedUnknownParallelRegions.insert(&CB);
4046       break;
4047     case OMPRTL___kmpc_alloc_shared:
4048     case OMPRTL___kmpc_free_shared:
4049       // Return without setting a fixpoint, to be resolved in updateImpl.
4050       return;
4051     default:
4052       // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
4053       // generally. However, they do not hide parallel regions.
4054       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4055       SPMDCompatibilityTracker.insert(&CB);
4056       break;
4057     }
4058     // All other OpenMP runtime calls will not reach parallel regions so they
4059     // can be safely ignored for now. Since it is a known OpenMP runtime call we
4060     // have now modeled all effects and there is no need for any update.
4061     indicateOptimisticFixpoint();
4062   }
4063 
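  /// See AbstractAttribute::updateImpl(...).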
4064   ChangeStatus updateImpl(Attributor &A) override {
4065     // TODO: Once we have call site specific value information we can provide
4066     //       call site specific liveness information and then it makes
4067     //       sense to specialize attributes for call sites arguments instead of
4068     //       redirecting requests to the callee argument.
4069     Function *F = getAssociatedFunction();
4070 
4071     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4072     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
4073 
4074     // If F is not a runtime function, propagate the AAKernelInfo of the callee.
4075     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4076       const IRPosition &FnPos = IRPosition::function(*F);
4077       auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
4078       if (getState() == FnAA.getState())
4079         return ChangeStatus::UNCHANGED;
4080       getState() = FnAA.getState();
4081       return ChangeStatus::CHANGED;
4082     }
4083 
4084     // F is a runtime function that allocates or frees memory, check
4085     // AAHeapToStack and AAHeapToShared.
4086     KernelInfoState StateBefore = getState();
4087     assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
4088             It->getSecond() == OMPRTL___kmpc_free_shared) &&
4089            "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
4090 
4091     CallBase &CB = cast<CallBase>(getAssociatedValue());
4092 
4093     auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
4094         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
4095     auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
4096         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
4097 
4098     RuntimeFunction RF = It->getSecond();
4099 
4100     switch (RF) {
4101     // If neither HeapToStack nor HeapToShared assumes the call is removed,
4102     // assume SPMD incompatibility.
4103     case OMPRTL___kmpc_alloc_shared:
4104       if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
4105           !HeapToSharedAA.isAssumedHeapToShared(CB))
4106         SPMDCompatibilityTracker.insert(&CB);
4107       break;
4108     case OMPRTL___kmpc_free_shared:
4109       if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
4110           !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
4111         SPMDCompatibilityTracker.insert(&CB);
4112       break;
4113     default:
4114       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4115       SPMDCompatibilityTracker.insert(&CB);
4116     }
4117 
4118     return StateBefore == getState() ? ChangeStatus::UNCHANGED
4119                                      : ChangeStatus::CHANGED;
4120   }
4121 };
4122 
4123 struct AAFoldRuntimeCall
4124     : public StateWrapper<BooleanState, AbstractAttribute> {
4125   using Base = StateWrapper<BooleanState, AbstractAttribute>;
4126 
4127   AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
4128 
4129   /// Statistics are tracked as part of manifest for now.
4130   void trackStatistics() const override {}
4131 
4132   /// Create an abstract attribute view for the position \p IRP.
4133   static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
4134                                               Attributor &A);
4135 
4136   /// See AbstractAttribute::getName()
4137   const std::string getName() const override { return "AAFoldRuntimeCall"; }
4138 
4139   /// See AbstractAttribute::getIdAddr()
4140   const char *getIdAddr() const override { return &ID; }
4141 
4142   /// This function should return true if the type of the \p AA is
4143   /// AAFoldRuntimeCall
4144   static bool classof(const AbstractAttribute *AA) {
4145     return (AA->getIdAddr() == &ID);
4146   }
4147 
4148   static const char ID;
4149 };
4150 
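/// Folding of an OpenMP runtime call at a call site returned position. The
/// call is replaced by a constant, via a value simplification callback, when
/// the kernel information reaching this point allows it.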
4151 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
4152   AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
4153       : AAFoldRuntimeCall(IRP, A) {}
4154 
4155   /// See AbstractAttribute::getAsStr()
4156   const std::string getAsStr() const override {
4157     if (!isValidState())
4158       return "<invalid>";
4159 
4160     std::string Str("simplified value: ");
4161 
4162     if (!SimplifiedValue.hasValue())
4163       return Str + std::string("none");
4164 
4165     if (!SimplifiedValue.getValue())
4166       return Str + std::string("nullptr");
4167 
4168     if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue()))
4169       return Str + std::to_string(CI->getSExtValue());
4170 
4171     return Str + std::string("unknown");
4172   }
4173 
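  /// See AbstractAttribute::initialize(...).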
4174   void initialize(Attributor &A) override {
4175     if (DisableOpenMPOptFolding)
4176       indicatePessimisticFixpoint();
4177 
4178     Function *Callee = getAssociatedFunction();
4179 
4180     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4181     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4182     assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
4183            "Expected a known OpenMP runtime function");
4184 
4185     RFKind = It->getSecond();
4186 
4187     CallBase &CB = cast<CallBase>(getAssociatedValue());
4188     A.registerSimplificationCallback(
4189         IRPosition::callsite_returned(CB),
4190         [&](const IRPosition &IRP, const AbstractAttribute *AA,
4191             bool &UsedAssumedInformation) -> Optional<Value *> {
4192           assert((isValidState() || (SimplifiedValue.hasValue() &&
4193                                      SimplifiedValue.getValue() == nullptr)) &&
4194                  "Unexpected invalid state!");
4195 
4196           if (!isAtFixpoint()) {
4197             UsedAssumedInformation = true;
4198             if (AA)
4199               A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
4200           }
4201           return SimplifiedValue;
4202         });
4203   }
4204 
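  /// See AbstractAttribute::updateImpl(...).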
4205   ChangeStatus updateImpl(Attributor &A) override {
4206     ChangeStatus Changed = ChangeStatus::UNCHANGED;
4207     switch (RFKind) {
4208     case OMPRTL___kmpc_is_spmd_exec_mode:
4209       Changed |= foldIsSPMDExecMode(A);
4210       break;
4211     case OMPRTL___kmpc_is_generic_main_thread_id:
4212       Changed |= foldIsGenericMainThread(A);
4213       break;
4214     case OMPRTL___kmpc_parallel_level:
4215       Changed |= foldParallelLevel(A);
4216       break;
4217     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4218       Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
4219       break;
4220     case OMPRTL___kmpc_get_hardware_num_blocks:
4221       Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
4222       break;
4223     default:
4224       llvm_unreachable("Unhandled OpenMP runtime function!");
4225     }
4226 
4227     return Changed;
4228   }
4229 
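  /// See AbstractAttribute::manifest(...).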
4230   ChangeStatus manifest(Attributor &A) override {
4231     ChangeStatus Changed = ChangeStatus::UNCHANGED;
4232 
4233     if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) {
4234       Instruction &I = *getCtxI();
4235       A.changeValueAfterManifest(I, **SimplifiedValue);
4236       A.deleteAfterManifest(I);
4237 
4238       CallBase *CB = dyn_cast<CallBase>(&I);
4239       auto Remark = [&](OptimizationRemark OR) {
4240         if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
4241           return OR << "Replacing OpenMP runtime call "
4242                     << CB->getCalledFunction()->getName() << " with "
4243                     << ore::NV("FoldedValue", C->getZExtValue()) << ".";
4244         return OR << "Replacing OpenMP runtime call "
4245                   << CB->getCalledFunction()->getName() << ".";
4246       };
4247 
4248       if (CB && EnableVerboseRemarks)
4249         A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
4250 
4251       LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
4252                         << **SimplifiedValue << "\n");
4253 
4254       Changed = ChangeStatus::CHANGED;
4255     }
4256 
4257     return Changed;
4258   }
4259 
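  /// See AbstractState::indicatePessimisticFixpoint(...).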
4260   ChangeStatus indicatePessimisticFixpoint() override {
4261     SimplifiedValue = nullptr;
4262     return AAFoldRuntimeCall::indicatePessimisticFixpoint();
4263   }
4264 
4265 private:
4266   /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
4267   ChangeStatus foldIsSPMDExecMode(Attributor &A) {
4268     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4269 
4270     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4271     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4272     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4273         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4274 
4275     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4276       return indicatePessimisticFixpoint();
4277 
4278     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4279       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4280                                           DepClassTy::REQUIRED);
4281 
4282       if (!AA.isValidState()) {
4283         SimplifiedValue = nullptr;
4284         return indicatePessimisticFixpoint();
4285       }
4286 
4287       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4288         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4289           ++KnownSPMDCount;
4290         else
4291           ++AssumedSPMDCount;
4292       } else {
4293         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4294           ++KnownNonSPMDCount;
4295         else
4296           ++AssumedNonSPMDCount;
4297       }
4298     }
4299 
4300     if ((AssumedSPMDCount + KnownSPMDCount) &&
4301         (AssumedNonSPMDCount + KnownNonSPMDCount))
4302       return indicatePessimisticFixpoint();
4303 
4304     auto &Ctx = getAnchorValue().getContext();
4305     if (KnownSPMDCount || AssumedSPMDCount) {
4306       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4307              "Expected only SPMD kernels!");
4308       // All reaching kernels are in SPMD mode. Update all function calls to
4309       // __kmpc_is_spmd_exec_mode to 1.
4310       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4311     } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
4312       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4313              "Expected only non-SPMD kernels!");
4314       // All reaching kernels are in non-SPMD mode. Update all function
4315       // calls to __kmpc_is_spmd_exec_mode to 0.
4316       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
4317     } else {
4318       // We have empty reaching kernels, therefore we cannot tell if the
4319       // associated call site can be folded. At this moment, SimplifiedValue
4320       // must be none.
4321       assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none");
4322     }
4323 
4324     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4325                                                     : ChangeStatus::CHANGED;
4326   }
4327 
4328   /// Fold __kmpc_is_generic_main_thread_id into a constant if possible.
4329   ChangeStatus foldIsGenericMainThread(Attributor &A) {
4330     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4331 
4332     CallBase &CB = cast<CallBase>(getAssociatedValue());
4333     Function *F = CB.getFunction();
4334     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
4335         *this, IRPosition::function(*F), DepClassTy::REQUIRED);
4336 
4337     if (!ExecutionDomainAA.isValidState())
4338       return indicatePessimisticFixpoint();
4339 
4340     auto &Ctx = getAnchorValue().getContext();
4341     if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB))
4342       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4343     else
4344       return indicatePessimisticFixpoint();
4345 
4346     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4347                                                     : ChangeStatus::CHANGED;
4348   }
4349 
4350   /// Fold __kmpc_parallel_level into a constant if possible.
4351   ChangeStatus foldParallelLevel(Attributor &A) {
4352     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4353 
4354     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4355         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4356 
4357     if (!CallerKernelInfoAA.ParallelLevels.isValidState())
4358       return indicatePessimisticFixpoint();
4359 
4360     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4361       return indicatePessimisticFixpoint();
4362 
4363     if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
4364       assert(!SimplifiedValue.hasValue() &&
4365              "SimplifiedValue should keep none at this point");
4366       return ChangeStatus::UNCHANGED;
4367     }
4368 
4369     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4370     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4371     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4372       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4373                                           DepClassTy::REQUIRED);
4374       if (!AA.SPMDCompatibilityTracker.isValidState())
4375         return indicatePessimisticFixpoint();
4376 
4377       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4378         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4379           ++KnownSPMDCount;
4380         else
4381           ++AssumedSPMDCount;
4382       } else {
4383         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4384           ++KnownNonSPMDCount;
4385         else
4386           ++AssumedNonSPMDCount;
4387       }
4388     }
4389 
4390     if ((AssumedSPMDCount + KnownSPMDCount) &&
4391         (AssumedNonSPMDCount + KnownNonSPMDCount))
4392       return indicatePessimisticFixpoint();
4393 
4394     auto &Ctx = getAnchorValue().getContext();
4395     // If the caller can only be reached by SPMD kernel entries, the parallel
4396     // level is 1. Similarly, if the caller can only be reached by non-SPMD
4397     // kernel entries, it is 0.
4398     if (AssumedSPMDCount || KnownSPMDCount) {
4399       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4400              "Expected only SPMD kernels!");
4401       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
4402     } else {
4403       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4404              "Expected only non-SPMD kernels!");
4405       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
4406     }
4407     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4408                                                     : ChangeStatus::CHANGED;
4409   }
4410 
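  /// Fold __kmpc_get_hardware_num_threads_in_block and
  /// __kmpc_get_hardware_num_blocks into a constant if all kernels reaching
  /// this function agree on the value of the kernel attribute \p Attr.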
4411   ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
4412     // Specialize only if all reaching kernels agree on the attribute value.
4413     int32_t CurrentAttrValue = -1;
4414     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4415 
4416     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4417         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4418 
4419     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4420       return indicatePessimisticFixpoint();
4421 
4422     // Iterate over the kernels that reach this function
4423     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4424       int32_t NextAttrVal = -1;
4425       if (K->hasFnAttribute(Attr))
4426         NextAttrVal =
4427             std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
4428 
4429       if (NextAttrVal == -1 ||
4430           (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
4431         return indicatePessimisticFixpoint();
4432       CurrentAttrValue = NextAttrVal;
4433     }
4434 
4435     if (CurrentAttrValue != -1) {
4436       auto &Ctx = getAnchorValue().getContext();
4437       SimplifiedValue =
4438           ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
4439     }
4440     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4441                                                     : ChangeStatus::CHANGED;
4442   }
4443 
4444   /// An optional value the associated value is assumed to fold to. That is, we
4445   /// assume the associated value (which is a call) can be replaced by this
4446   /// simplified value.
4447   Optional<Value *> SimplifiedValue;
4448 
4449   /// The runtime function kind of the callee of the associated call site.
4450   RuntimeFunction RFKind;
4451 };
4452 
4453 } // namespace
4454 
4455 /// Register folding AAs for all call sites of the runtime function \p RF.
4456 void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
4457   auto &RFI = OMPInfoCache.RFIs[RF];
4458   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
4459     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
4460     if (!CI)
4461       return false;
4462     A.getOrCreateAAFor<AAFoldRuntimeCall>(
4463         IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
4464         DepClassTy::NONE, /* ForceUpdate */ false,
4465         /* UpdateAfterInit */ false);
4466     return false;
4467   });
4468 }
4469 
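/// Seed the Attributor with the abstract attributes OpenMPOpt relies on for
/// the functions in the current SCC.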
4470 void OpenMPOpt::registerAAs(bool IsModulePass) {
4471   if (SCC.empty())
4472     return;
4473
4474   if (IsModulePass) {
4475     // Ensure we create the AAKernelInfo AAs first and without triggering an
4476     // update. This will make sure we register all value simplification
4477     // callbacks before any other AA has the chance to create an AAValueSimplify
4478     // or similar.
4479     for (Function *Kernel : OMPInfoCache.Kernels)
4480       A.getOrCreateAAFor<AAKernelInfo>(
4481           IRPosition::function(*Kernel), /* QueryingAA */ nullptr,
4482           DepClassTy::NONE, /* ForceUpdate */ false,
4483           /* UpdateAfterInit */ false);
4484 
4485     registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
4486     registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
4487     registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
4488     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
4489     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
4490   }
4491 
4492   // Create CallSite AA for all Getters.
4493   for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
4494     auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
4495 
4496     auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
4497 
4498     auto CreateAA = [&](Use &U, Function &Caller) {
4499       CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
4500       if (!CI)
4501         return false;
4502 
4503       auto &CB = cast<CallBase>(*CI);
4504 
4505       IRPosition CBPos = IRPosition::callsite_function(CB);
4506       A.getOrCreateAAFor<AAICVTracker>(CBPos);
4507       return false;
4508     };
4509 
4510     GetterRFI.foreachUse(SCC, CreateAA);
4511   }
4512   auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4513   auto CreateAA = [&](Use &U, Function &F) {
4514     A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
4515     return false;
4516   };
4517   if (!DisableOpenMPOptDeglobalization)
4518     GlobalizationRFI.foreachUse(SCC, CreateAA);
4519 
4520   // Create an ExecutionDomain AA for every function and a HeapToStack AA for
4521   // every function if we are optimizing an OpenMP device module.
4522   if (!isOpenMPDevice(M))
4523     return;
4524 
4525   for (auto *F : SCC) {
4526     if (F->isDeclaration())
4527       continue;
4528 
4529     A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
4530     if (!DisableOpenMPOptDeglobalization)
4531       A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
4532 
4533     for (auto &I : instructions(*F)) {
4534       if (auto *LI = dyn_cast<LoadInst>(&I)) {
4535         bool UsedAssumedInformation = false;
4536         A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
4537                                UsedAssumedInformation);
4538       } else if (auto *SI = dyn_cast<StoreInst>(&I)) {
4539         A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
4540       }
4541     }
4542   }
4543 }
4544 
4545 const char AAICVTracker::ID = 0;
4546 const char AAKernelInfo::ID = 0;
4547 const char AAExecutionDomain::ID = 0;
4548 const char AAHeapToShared::ID = 0;
4549 const char AAFoldRuntimeCall::ID = 0;
4550 
4551 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
4552                                               Attributor &A) {
4553   AAICVTracker *AA = nullptr;
4554   switch (IRP.getPositionKind()) {
4555   case IRPosition::IRP_INVALID:
4556   case IRPosition::IRP_FLOAT:
4557   case IRPosition::IRP_ARGUMENT:
4558   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4559     llvm_unreachable("ICVTracker can only be created for function position!");
4560   case IRPosition::IRP_RETURNED:
4561     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
4562     break;
4563   case IRPosition::IRP_CALL_SITE_RETURNED:
4564     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
4565     break;
4566   case IRPosition::IRP_CALL_SITE:
4567     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
4568     break;
4569   case IRPosition::IRP_FUNCTION:
4570     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
4571     break;
4572   }
4573 
4574   return *AA;
4575 }
4576 
4577 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
4578                                                         Attributor &A) {
4579   AAExecutionDomainFunction *AA = nullptr;
4580   switch (IRP.getPositionKind()) {
4581   case IRPosition::IRP_INVALID:
4582   case IRPosition::IRP_FLOAT:
4583   case IRPosition::IRP_ARGUMENT:
4584   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4585   case IRPosition::IRP_RETURNED:
4586   case IRPosition::IRP_CALL_SITE_RETURNED:
4587   case IRPosition::IRP_CALL_SITE:
4588     llvm_unreachable(
4589         "AAExecutionDomain can only be created for function position!");
4590   case IRPosition::IRP_FUNCTION:
4591     AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
4592     break;
4593   }
4594 
4595   return *AA;
4596 }
4597 
4598 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
4599                                                   Attributor &A) {
4600   AAHeapToSharedFunction *AA = nullptr;
4601   switch (IRP.getPositionKind()) {
4602   case IRPosition::IRP_INVALID:
4603   case IRPosition::IRP_FLOAT:
4604   case IRPosition::IRP_ARGUMENT:
4605   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4606   case IRPosition::IRP_RETURNED:
4607   case IRPosition::IRP_CALL_SITE_RETURNED:
4608   case IRPosition::IRP_CALL_SITE:
4609     llvm_unreachable(
4610         "AAHeapToShared can only be created for function position!");
4611   case IRPosition::IRP_FUNCTION:
4612     AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
4613     break;
4614   }
4615 
4616   return *AA;
4617 }
4618 
4619 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
4620                                               Attributor &A) {
4621   AAKernelInfo *AA = nullptr;
4622   switch (IRP.getPositionKind()) {
4623   case IRPosition::IRP_INVALID:
4624   case IRPosition::IRP_FLOAT:
4625   case IRPosition::IRP_ARGUMENT:
4626   case IRPosition::IRP_RETURNED:
4627   case IRPosition::IRP_CALL_SITE_RETURNED:
4628   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4629     llvm_unreachable("KernelInfo can only be created for function position!");
4630   case IRPosition::IRP_CALL_SITE:
4631     AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
4632     break;
4633   case IRPosition::IRP_FUNCTION:
4634     AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
4635     break;
4636   }
4637 
4638   return *AA;
4639 }
4640 
4641 AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
4642                                                         Attributor &A) {
4643   AAFoldRuntimeCall *AA = nullptr;
4644   switch (IRP.getPositionKind()) {
4645   case IRPosition::IRP_INVALID:
4646   case IRPosition::IRP_FLOAT:
4647   case IRPosition::IRP_ARGUMENT:
4648   case IRPosition::IRP_RETURNED:
4649   case IRPosition::IRP_FUNCTION:
4650   case IRPosition::IRP_CALL_SITE:
4651   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4652     llvm_unreachable("AAFoldRuntimeCall can only be created for call sites!");
4653   case IRPosition::IRP_CALL_SITE_RETURNED:
4654     AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
4655     break;
4656   }
4657 
4658   return *AA;
4659 }
4660 
4661 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
4662   if (!containsOpenMP(M))
4663     return PreservedAnalyses::all();
4664   if (DisableOpenMPOptimizations)
4665     return PreservedAnalyses::all();
4666 
4667   FunctionAnalysisManager &FAM =
4668       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
4669   KernelSet Kernels = getDeviceKernels(M);
4670 
4671   auto IsCalled = [&](Function &F) {
4672     if (Kernels.contains(&F))
4673       return true;
4674     for (const User *U : F.users())
4675       if (!isa<BlockAddress>(U))
4676         return true;
4677     return false;
4678   };
4679 
4680   auto EmitRemark = [&](Function &F) {
4681     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
4682     ORE.emit([&]() {
4683       OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
4684       return ORA << "Could not internalize function. "
4685                  << "Some optimizations may not be possible. [OMP140]";
4686     });
4687   };
4688 
4689   // Create internal copies of each function if this is a kernel Module. This
4690   // allows interprocedural passes to see every call edge.
4691   DenseMap<Function *, Function *> InternalizedMap;
4692   if (isOpenMPDevice(M)) {
4693     SmallPtrSet<Function *, 16> InternalizeFns;
4694     for (Function &F : M)
4695       if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
4696           !DisableInternalization) {
4697         if (Attributor::isInternalizable(F)) {
4698           InternalizeFns.insert(&F);
4699         } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
4700           EmitRemark(F);
4701         }
4702       }
4703 
4704     Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
4705   }
4706 
4707   // Look at every function in the Module unless it was internalized.
4708   SmallVector<Function *, 16> SCC;
4709   for (Function &F : M)
4710     if (!F.isDeclaration() && !InternalizedMap.lookup(&F))
4711       SCC.push_back(&F);
4712 
4713   if (SCC.empty())
4714     return PreservedAnalyses::all();
4715 
4716   AnalysisGetter AG(FAM);
4717 
4718   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4719     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4720   };
4721 
4722   BumpPtrAllocator Allocator;
4723   CallGraphUpdater CGUpdater;
4724 
4725   SetVector<Function *> Functions(SCC.begin(), SCC.end());
4726   OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels);
4727 
4728   unsigned MaxFixpointIterations =
4729       (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
4730   Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false,
4731                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4732 
4733   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4734   bool Changed = OMPOpt.run(true);
4735 
4736   // Optionally inline device functions for potentially better performance.
4737   if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
4738     for (Function &F : M)
4739       if (!F.isDeclaration() && !Kernels.contains(&F) &&
4740           !F.hasFnAttribute(Attribute::NoInline))
4741         F.addFnAttr(Attribute::AlwaysInline);
4742 
4743   if (PrintModuleAfterOptimizations)
4744     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
4745 
4746   if (Changed)
4747     return PreservedAnalyses::none();
4748 
4749   return PreservedAnalyses::all();
4750 }
4751 
4752 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
4753                                           CGSCCAnalysisManager &AM,
4754                                           LazyCallGraph &CG,
4755                                           CGSCCUpdateResult &UR) {
4756   if (!containsOpenMP(*C.begin()->getFunction().getParent()))
4757     return PreservedAnalyses::all();
4758   if (DisableOpenMPOptimizations)
4759     return PreservedAnalyses::all();
4760 
4761   SmallVector<Function *, 16> SCC;
4762   // If there are kernels in the module, we have to run on all SCCs.
4763   for (LazyCallGraph::Node &N : C) {
4764     Function *Fn = &N.getFunction();
4765     SCC.push_back(Fn);
4766   }
4767 
4768   if (SCC.empty())
4769     return PreservedAnalyses::all();
4770 
4771   Module &M = *C.begin()->getFunction().getParent();
4772 
4773   KernelSet Kernels = getDeviceKernels(M);
4774 
4775   FunctionAnalysisManager &FAM =
4776       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
4777 
4778   AnalysisGetter AG(FAM);
4779 
4780   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4781     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4782   };
4783 
4784   BumpPtrAllocator Allocator;
4785   CallGraphUpdater CGUpdater;
4786   CGUpdater.initialize(CG, C, AM, UR);
4787 
4788   SetVector<Function *> Functions(SCC.begin(), SCC.end());
4789   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
4790                                 /*CGSCC*/ Functions, Kernels);
4791 
4792   unsigned MaxFixpointIterations =
4793       (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
4794   Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4795                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4796 
4797   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4798   bool Changed = OMPOpt.run(false);
4799 
4800   if (PrintModuleAfterOptimizations)
4801     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4802 
4803   if (Changed)
4804     return PreservedAnalyses::none();
4805 
4806   return PreservedAnalyses::all();
4807 }
4808 
4809 namespace {
4810 
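/// Legacy pass manager wrapper around the OpenMP-specific CGSCC
/// optimizations.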
4811 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
4812   CallGraphUpdater CGUpdater;
4813   static char ID;
4814 
4815   OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
4816     initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
4817   }
4818 
4819   void getAnalysisUsage(AnalysisUsage &AU) const override {
4820     CallGraphSCCPass::getAnalysisUsage(AU);
4821   }
4822 
4823   bool runOnSCC(CallGraphSCC &CGSCC) override {
4824     if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
4825       return false;
4826     if (DisableOpenMPOptimizations || skipSCC(CGSCC))
4827       return false;
4828 
4829     SmallVector<Function *, 16> SCC;
4830     // If there are kernels in the module, we have to run on all SCCs.
4831     for (CallGraphNode *CGN : CGSCC) {
4832       Function *Fn = CGN->getFunction();
4833       if (!Fn || Fn->isDeclaration())
4834         continue;
4835       SCC.push_back(Fn);
4836     }
4837 
4838     if (SCC.empty())
4839       return false;
4840 
4841     Module &M = CGSCC.getCallGraph().getModule();
4842     KernelSet Kernels = getDeviceKernels(M);
4843 
4844     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
4845     CGUpdater.initialize(CG, CGSCC);
4846 
4847     // Maintain a map of functions to avoid rebuilding the ORE
4848     DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
4849     auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
4850       std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
4851       if (!ORE)
4852         ORE = std::make_unique<OptimizationRemarkEmitter>(F);
4853       return *ORE;
4854     };
4855 
4856     AnalysisGetter AG;
4857     SetVector<Function *> Functions(SCC.begin(), SCC.end());
4858     BumpPtrAllocator Allocator;
4859     OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
4860                                   Allocator,
4861                                   /*CGSCC*/ Functions, Kernels);
4862 
4863     unsigned MaxFixpointIterations =
4864         (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
4865     Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4866                  MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4867 
4868     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4869     bool Result = OMPOpt.run(false);
4870 
4871     if (PrintModuleAfterOptimizations)
4872       LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4873 
4874     return Result;
4875   }
4876 
4877   bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
4878 };
4879 
4880 } // end anonymous namespace
4881 
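/// Collect the device kernels in \p M; kernels are currently identified via
/// the "nvvm.annotations" named metadata.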
4882 KernelSet llvm::omp::getDeviceKernels(Module &M) {
4883   // TODO: Create a more cross-platform way of determining device kernels.
4884   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
4885   KernelSet Kernels;
4886 
4887   if (!MD)
4888     return Kernels;
4889 
4890   for (auto *Op : MD->operands()) {
4891     if (Op->getNumOperands() < 2)
4892       continue;
4893     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
4894     if (!KindID || KindID->getString() != "kernel")
4895       continue;
4896 
4897     Function *KernelFn =
4898         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
4899     if (!KernelFn)
4900       continue;
4901 
4902     ++NumOpenMPTargetRegionKernels;
4903 
4904     Kernels.insert(KernelFn);
4905   }
4906 
4907   return Kernels;
4908 }
4909 
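/// Return true if \p M contains OpenMP code, indicated by the "openmp"
/// module flag.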
4910 bool llvm::omp::containsOpenMP(Module &M) {
4911   Metadata *MD = M.getModuleFlag("openmp");
4912   if (!MD)
4913     return false;
4914 
4915   return true;
4916 }
4917 
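/// Return true if \p M is a device (offloading) module, indicated by the
/// "openmp-device" module flag.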
4918 bool llvm::omp::isOpenMPDevice(Module &M) {
4919   Metadata *MD = M.getModuleFlag("openmp-device");
4920   if (!MD)
4921     return false;
4922 
4923   return true;
4924 }
4925 
4926 char OpenMPOptCGSCCLegacyPass::ID = 0;
4927 
4928 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4929                       "OpenMP specific optimizations", false, false)
4930 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
4931 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4932                     "OpenMP specific optimizations", false, false)
4933 
4934 Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
4935   return new OpenMPOptCGSCCLegacyPass();
4936 }
4937