19548b74aSJohannes Doerfert //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
29548b74aSJohannes Doerfert //
39548b74aSJohannes Doerfert // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49548b74aSJohannes Doerfert // See https://llvm.org/LICENSE.txt for license information.
59548b74aSJohannes Doerfert // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69548b74aSJohannes Doerfert //
79548b74aSJohannes Doerfert //===----------------------------------------------------------------------===//
89548b74aSJohannes Doerfert //
99548b74aSJohannes Doerfert // OpenMP specific optimizations:
109548b74aSJohannes Doerfert //
119548b74aSJohannes Doerfert // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12ca1560daSJoseph Huber // - Replacing globalized device memory with stack memory.
13ca1560daSJoseph Huber // - Replacing globalized device memory with shared memory.
14b910a109SJoseph Huber // - Parallel region merging.
15b910a109SJoseph Huber // - Transforming generic-mode device kernels to SPMD mode.
16b910a109SJoseph Huber // - Specializing the state machine for generic-mode device kernels.
179548b74aSJohannes Doerfert //
189548b74aSJohannes Doerfert //===----------------------------------------------------------------------===//
199548b74aSJohannes Doerfert 
209548b74aSJohannes Doerfert #include "llvm/Transforms/IPO/OpenMPOpt.h"
219548b74aSJohannes Doerfert 
229548b74aSJohannes Doerfert #include "llvm/ADT/EnumeratedArray.h"
2318283125SJoseph Huber #include "llvm/ADT/PostOrderIterator.h"
249548b74aSJohannes Doerfert #include "llvm/ADT/Statistic.h"
259548b74aSJohannes Doerfert #include "llvm/Analysis/CallGraph.h"
269548b74aSJohannes Doerfert #include "llvm/Analysis/CallGraphSCCPass.h"
274d4ea9acSHuber, Joseph #include "llvm/Analysis/OptimizationRemarkEmitter.h"
283a6bfcf2SGiorgis Georgakoudis #include "llvm/Analysis/ValueTracking.h"
299548b74aSJohannes Doerfert #include "llvm/Frontend/OpenMP/OMPConstants.h"
30e28936f6SJohannes Doerfert #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
31d9659bf6SJohannes Doerfert #include "llvm/IR/Assumptions.h"
32d9659bf6SJohannes Doerfert #include "llvm/IR/DiagnosticInfo.h"
33514c033dSJohannes Doerfert #include "llvm/IR/GlobalValue.h"
34d9659bf6SJohannes Doerfert #include "llvm/IR/Instruction.h"
3568abc3d2SJoseph Huber #include "llvm/IR/IntrinsicInst.h"
369548b74aSJohannes Doerfert #include "llvm/InitializePasses.h"
379548b74aSJohannes Doerfert #include "llvm/Support/CommandLine.h"
389548b74aSJohannes Doerfert #include "llvm/Transforms/IPO.h"
397cfd267cSsstefan1 #include "llvm/Transforms/IPO/Attributor.h"
403a6bfcf2SGiorgis Georgakoudis #include "llvm/Transforms/Utils/BasicBlockUtils.h"
419548b74aSJohannes Doerfert #include "llvm/Transforms/Utils/CallGraphUpdater.h"
4297517055SGiorgis Georgakoudis #include "llvm/Transforms/Utils/CodeExtractor.h"
439548b74aSJohannes Doerfert 
449548b74aSJohannes Doerfert using namespace llvm;
459548b74aSJohannes Doerfert using namespace omp;
469548b74aSJohannes Doerfert 
479548b74aSJohannes Doerfert #define DEBUG_TYPE "openmp-opt"
489548b74aSJohannes Doerfert 
499548b74aSJohannes Doerfert static cl::opt<bool> DisableOpenMPOptimizations(
509548b74aSJohannes Doerfert     "openmp-opt-disable", cl::ZeroOrMore,
519548b74aSJohannes Doerfert     cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
529548b74aSJohannes Doerfert     cl::init(false));
539548b74aSJohannes Doerfert 
543a6bfcf2SGiorgis Georgakoudis static cl::opt<bool> EnableParallelRegionMerging(
553a6bfcf2SGiorgis Georgakoudis     "openmp-opt-enable-merging", cl::ZeroOrMore,
563a6bfcf2SGiorgis Georgakoudis     cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
573a6bfcf2SGiorgis Georgakoudis     cl::init(false));
583a6bfcf2SGiorgis Georgakoudis 
594a668604SJoseph Huber static cl::opt<bool>
604a668604SJoseph Huber     DisableInternalization("openmp-opt-disable-internalization", cl::ZeroOrMore,
614a668604SJoseph Huber                            cl::desc("Disable function internalization."),
624a668604SJoseph Huber                            cl::Hidden, cl::init(false));
634a668604SJoseph Huber 
640f426935Ssstefan1 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
650f426935Ssstefan1                                     cl::Hidden);
66e8039ad4SJohannes Doerfert static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
67e8039ad4SJohannes Doerfert                                         cl::init(false), cl::Hidden);
680f426935Ssstefan1 
69496f8e5bSHamilton Tobon Mosquera static cl::opt<bool> HideMemoryTransferLatency(
70496f8e5bSHamilton Tobon Mosquera     "openmp-hide-memory-transfer-latency",
71496f8e5bSHamilton Tobon Mosquera     cl::desc("[WIP] Tries to hide the latency of host to device memory"
72496f8e5bSHamilton Tobon Mosquera              " transfers"),
73496f8e5bSHamilton Tobon Mosquera     cl::Hidden, cl::init(false));
74496f8e5bSHamilton Tobon Mosquera 
75cd0dd8ecSJoseph Huber static cl::opt<bool> DisableOpenMPOptDeglobalization(
76cd0dd8ecSJoseph Huber     "openmp-opt-disable-deglobalization", cl::ZeroOrMore,
77cd0dd8ecSJoseph Huber     cl::desc("Disable OpenMP optimizations involving deglobalization."),
78cd0dd8ecSJoseph Huber     cl::Hidden, cl::init(false));
79cd0dd8ecSJoseph Huber 
80cd0dd8ecSJoseph Huber static cl::opt<bool> DisableOpenMPOptSPMDization(
81cd0dd8ecSJoseph Huber     "openmp-opt-disable-spmdization", cl::ZeroOrMore,
82cd0dd8ecSJoseph Huber     cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
83cd0dd8ecSJoseph Huber     cl::Hidden, cl::init(false));
84cd0dd8ecSJoseph Huber 
85cd0dd8ecSJoseph Huber static cl::opt<bool> DisableOpenMPOptFolding(
86cd0dd8ecSJoseph Huber     "openmp-opt-disable-folding", cl::ZeroOrMore,
87cd0dd8ecSJoseph Huber     cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
88cd0dd8ecSJoseph Huber     cl::init(false));
89cd0dd8ecSJoseph Huber 
90cd0dd8ecSJoseph Huber static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
91cd0dd8ecSJoseph Huber     "openmp-opt-disable-state-machine-rewrite", cl::ZeroOrMore,
92cd0dd8ecSJoseph Huber     cl::desc("Disable OpenMP optimizations that replace the state machine."),
93cd0dd8ecSJoseph Huber     cl::Hidden, cl::init(false));
94cd0dd8ecSJoseph Huber 
// Counters maintained by this pass. Each STATISTIC pairs a counter variable
// with the human readable description that is reported for it.
STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");
STATISTIC(NumOpenMPTargetRegionKernels,
          "Number of OpenMP target region entry points (=kernels) identified");
STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "SPMD-mode instead of generic-mode");
STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode without a state machines");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines with fallback");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines without fallback");
STATISTIC(
    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
STATISTIC(NumOpenMPParallelRegionsMerged,
          "Number of OpenMP parallel regions merged");
STATISTIC(NumBytesMovedToSharedMemory,
          "Amount of memory pushed to shared memory");
1249548b74aSJohannes Doerfert 
125263c4a3cSrathod-sahaab #if !defined(NDEBUG)
1269548b74aSJohannes Doerfert static constexpr auto TAG = "[" DEBUG_TYPE "]";
127a50c0b0dSMikael Holmen #endif
1289548b74aSJohannes Doerfert 
1299548b74aSJohannes Doerfert namespace {
1309548b74aSJohannes Doerfert 
/// Named address space numbers.
/// NOTE(review): the numeric values presumably follow the address space
/// numbering of the GPU targets handled by this pass -- confirm against the
/// target definitions before relying on them elsewhere.
enum class AddressSpace : unsigned {
  Generic = 0,
  Global = 1,
  Shared = 3,
  Constant = 4,
  Local = 5,
};
1386fc51c9fSJoseph Huber 
1396fc51c9fSJoseph Huber struct AAHeapToShared;
1406fc51c9fSJoseph Huber 
141b8235d2bSsstefan1 struct AAICVTracker;
142b8235d2bSsstefan1 
1437cfd267cSsstefan1 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
1447cfd267cSsstefan1 /// Attributor runs.
1457cfd267cSsstefan1 struct OMPInformationCache : public InformationCache {
  /// Create the cache for module \p M. \p M, \p AG, \p Allocator and the
  /// address of \p CGSCC are forwarded to the InformationCache base class.
  /// \p Kernels is only bound by reference, so it has to outlive this object.
  /// Afterwards the runtime function table and the ICV table are populated.
  OMPInformationCache(Module &M, AnalysisGetter &AG,
                      BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
                      SmallPtrSetImpl<Kernel> &Kernels)
      : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
        Kernels(Kernels) {

    // The builder is initialized before the two helpers run; both of them read
    // types through OMPBuilder.
    OMPBuilder.initialize();
    initializeRuntimeFunctions();
    initializeInternalControlVars();
  }
1569548b74aSJohannes Doerfert 
1570f426935Ssstefan1   /// Generic information that describes an internal control variable.
1580f426935Ssstefan1   struct InternalControlVarInfo {
1590f426935Ssstefan1     /// The kind, as described by InternalControlVar enum.
1600f426935Ssstefan1     InternalControlVar Kind;
1610f426935Ssstefan1 
1620f426935Ssstefan1     /// The name of the ICV.
1630f426935Ssstefan1     StringRef Name;
1640f426935Ssstefan1 
1650f426935Ssstefan1     /// Environment variable associated with this ICV.
1660f426935Ssstefan1     StringRef EnvVarName;
1670f426935Ssstefan1 
1680f426935Ssstefan1     /// Initial value kind.
1690f426935Ssstefan1     ICVInitValue InitKind;
1700f426935Ssstefan1 
1710f426935Ssstefan1     /// Initial value.
1720f426935Ssstefan1     ConstantInt *InitValue;
1730f426935Ssstefan1 
1740f426935Ssstefan1     /// Setter RTL function associated with this ICV.
1750f426935Ssstefan1     RuntimeFunction Setter;
1760f426935Ssstefan1 
1770f426935Ssstefan1     /// Getter RTL function associated with this ICV.
1780f426935Ssstefan1     RuntimeFunction Getter;
1790f426935Ssstefan1 
1800f426935Ssstefan1     /// RTL Function corresponding to the override clause of this ICV
1810f426935Ssstefan1     RuntimeFunction Clause;
1820f426935Ssstefan1   };
1830f426935Ssstefan1 
1849548b74aSJohannes Doerfert   /// Generic information that describes a runtime function
1859548b74aSJohannes Doerfert   struct RuntimeFunctionInfo {
1868855fec3SJohannes Doerfert 
1879548b74aSJohannes Doerfert     /// The kind, as described by the RuntimeFunction enum.
1889548b74aSJohannes Doerfert     RuntimeFunction Kind;
1899548b74aSJohannes Doerfert 
1909548b74aSJohannes Doerfert     /// The name of the function.
1919548b74aSJohannes Doerfert     StringRef Name;
1929548b74aSJohannes Doerfert 
1939548b74aSJohannes Doerfert     /// Flag to indicate a variadic function.
1949548b74aSJohannes Doerfert     bool IsVarArg;
1959548b74aSJohannes Doerfert 
1969548b74aSJohannes Doerfert     /// The return type of the function.
1979548b74aSJohannes Doerfert     Type *ReturnType;
1989548b74aSJohannes Doerfert 
1999548b74aSJohannes Doerfert     /// The argument types of the function.
2009548b74aSJohannes Doerfert     SmallVector<Type *, 8> ArgumentTypes;
2019548b74aSJohannes Doerfert 
2029548b74aSJohannes Doerfert     /// The declaration if available.
203f09f4b26SJohannes Doerfert     Function *Declaration = nullptr;
2049548b74aSJohannes Doerfert 
2059548b74aSJohannes Doerfert     /// Uses of this runtime function per function containing the use.
2068855fec3SJohannes Doerfert     using UseVector = SmallVector<Use *, 16>;
2078855fec3SJohannes Doerfert 
208b8235d2bSsstefan1     /// Clear UsesMap for runtime function.
209b8235d2bSsstefan1     void clearUsesMap() { UsesMap.clear(); }
210b8235d2bSsstefan1 
21154bd3751SJohannes Doerfert     /// Boolean conversion that is true if the runtime function was found.
21254bd3751SJohannes Doerfert     operator bool() const { return Declaration; }
21354bd3751SJohannes Doerfert 
2148855fec3SJohannes Doerfert     /// Return the vector of uses in function \p F.
2158855fec3SJohannes Doerfert     UseVector &getOrCreateUseVector(Function *F) {
216b8235d2bSsstefan1       std::shared_ptr<UseVector> &UV = UsesMap[F];
2178855fec3SJohannes Doerfert       if (!UV)
218b8235d2bSsstefan1         UV = std::make_shared<UseVector>();
2198855fec3SJohannes Doerfert       return *UV;
2208855fec3SJohannes Doerfert     }
2218855fec3SJohannes Doerfert 
2228855fec3SJohannes Doerfert     /// Return the vector of uses in function \p F or `nullptr` if there are
2238855fec3SJohannes Doerfert     /// none.
2248855fec3SJohannes Doerfert     const UseVector *getUseVector(Function &F) const {
22595e57072SDavid Blaikie       auto I = UsesMap.find(&F);
22695e57072SDavid Blaikie       if (I != UsesMap.end())
22795e57072SDavid Blaikie         return I->second.get();
22895e57072SDavid Blaikie       return nullptr;
2298855fec3SJohannes Doerfert     }
2308855fec3SJohannes Doerfert 
2318855fec3SJohannes Doerfert     /// Return how many functions contain uses of this runtime function.
2328855fec3SJohannes Doerfert     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
2339548b74aSJohannes Doerfert 
2349548b74aSJohannes Doerfert     /// Return the number of arguments (or the minimal number for variadic
2359548b74aSJohannes Doerfert     /// functions).
2369548b74aSJohannes Doerfert     size_t getNumArgs() const { return ArgumentTypes.size(); }
2379548b74aSJohannes Doerfert 
2389548b74aSJohannes Doerfert     /// Run the callback \p CB on each use and forget the use if the result is
2399548b74aSJohannes Doerfert     /// true. The callback will be fed the function in which the use was
2409548b74aSJohannes Doerfert     /// encountered as second argument.
241624d34afSJohannes Doerfert     void foreachUse(SmallVectorImpl<Function *> &SCC,
242624d34afSJohannes Doerfert                     function_ref<bool(Use &, Function &)> CB) {
243624d34afSJohannes Doerfert       for (Function *F : SCC)
244624d34afSJohannes Doerfert         foreachUse(CB, F);
245e099c7b6Ssstefan1     }
246e099c7b6Ssstefan1 
247e099c7b6Ssstefan1     /// Run the callback \p CB on each use within the function \p F and forget
248e099c7b6Ssstefan1     /// the use if the result is true.
249624d34afSJohannes Doerfert     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
2508855fec3SJohannes Doerfert       SmallVector<unsigned, 8> ToBeDeleted;
2519548b74aSJohannes Doerfert       ToBeDeleted.clear();
252e099c7b6Ssstefan1 
2538855fec3SJohannes Doerfert       unsigned Idx = 0;
254624d34afSJohannes Doerfert       UseVector &UV = getOrCreateUseVector(F);
255e099c7b6Ssstefan1 
2568855fec3SJohannes Doerfert       for (Use *U : UV) {
257e099c7b6Ssstefan1         if (CB(*U, *F))
2588855fec3SJohannes Doerfert           ToBeDeleted.push_back(Idx);
2598855fec3SJohannes Doerfert         ++Idx;
2608855fec3SJohannes Doerfert       }
2618855fec3SJohannes Doerfert 
2628855fec3SJohannes Doerfert       // Remove the to-be-deleted indices in reverse order as prior
263b726c557SJohannes Doerfert       // modifications will not modify the smaller indices.
2648855fec3SJohannes Doerfert       while (!ToBeDeleted.empty()) {
2658855fec3SJohannes Doerfert         unsigned Idx = ToBeDeleted.pop_back_val();
2668855fec3SJohannes Doerfert         UV[Idx] = UV.back();
2678855fec3SJohannes Doerfert         UV.pop_back();
2689548b74aSJohannes Doerfert       }
2699548b74aSJohannes Doerfert     }
2708855fec3SJohannes Doerfert 
2718855fec3SJohannes Doerfert   private:
2728855fec3SJohannes Doerfert     /// Map from functions to all uses of this runtime function contained in
2738855fec3SJohannes Doerfert     /// them.
274b8235d2bSsstefan1     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
275d9659bf6SJohannes Doerfert 
276d9659bf6SJohannes Doerfert   public:
277d9659bf6SJohannes Doerfert     /// Iterators for the uses of this runtime function.
278d9659bf6SJohannes Doerfert     decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
279d9659bf6SJohannes Doerfert     decltype(UsesMap)::iterator end() { return UsesMap.end(); }
2809548b74aSJohannes Doerfert   };
2819548b74aSJohannes Doerfert 
2827cfd267cSsstefan1   /// An OpenMP-IR-Builder instance
2837cfd267cSsstefan1   OpenMPIRBuilder OMPBuilder;
2847cfd267cSsstefan1 
2857cfd267cSsstefan1   /// Map from runtime function kind to the runtime function description.
2867cfd267cSsstefan1   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
2877cfd267cSsstefan1                   RuntimeFunction::OMPRTL___last>
2887cfd267cSsstefan1       RFIs;
2897cfd267cSsstefan1 
290d9659bf6SJohannes Doerfert   /// Map from function declarations/definitions to their runtime enum type.
291d9659bf6SJohannes Doerfert   DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
292d9659bf6SJohannes Doerfert 
2930f426935Ssstefan1   /// Map from ICV kind to the ICV description.
2940f426935Ssstefan1   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
2950f426935Ssstefan1                   InternalControlVar::ICV___last>
2960f426935Ssstefan1       ICVs;
2970f426935Ssstefan1 
2980f426935Ssstefan1   /// Helper to initialize all internal control variable information for those
2990f426935Ssstefan1   /// defined in OMPKinds.def.
3000f426935Ssstefan1   void initializeInternalControlVars() {
3010f426935Ssstefan1 #define ICV_RT_SET(_Name, RTL)                                                 \
3020f426935Ssstefan1   {                                                                            \
3030f426935Ssstefan1     auto &ICV = ICVs[_Name];                                                   \
3040f426935Ssstefan1     ICV.Setter = RTL;                                                          \
3050f426935Ssstefan1   }
3060f426935Ssstefan1 #define ICV_RT_GET(Name, RTL)                                                  \
3070f426935Ssstefan1   {                                                                            \
3080f426935Ssstefan1     auto &ICV = ICVs[Name];                                                    \
3090f426935Ssstefan1     ICV.Getter = RTL;                                                          \
3100f426935Ssstefan1   }
3110f426935Ssstefan1 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
3120f426935Ssstefan1   {                                                                            \
3130f426935Ssstefan1     auto &ICV = ICVs[Enum];                                                    \
3140f426935Ssstefan1     ICV.Name = _Name;                                                          \
3150f426935Ssstefan1     ICV.Kind = Enum;                                                           \
3160f426935Ssstefan1     ICV.InitKind = Init;                                                       \
3170f426935Ssstefan1     ICV.EnvVarName = _EnvVarName;                                              \
3180f426935Ssstefan1     switch (ICV.InitKind) {                                                    \
319951e43f3Ssstefan1     case ICV_IMPLEMENTATION_DEFINED:                                           \
3200f426935Ssstefan1       ICV.InitValue = nullptr;                                                 \
3210f426935Ssstefan1       break;                                                                   \
322951e43f3Ssstefan1     case ICV_ZERO:                                                             \
3236aab27baSsstefan1       ICV.InitValue = ConstantInt::get(                                        \
3246aab27baSsstefan1           Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
3250f426935Ssstefan1       break;                                                                   \
326951e43f3Ssstefan1     case ICV_FALSE:                                                            \
3276aab27baSsstefan1       ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
3280f426935Ssstefan1       break;                                                                   \
329951e43f3Ssstefan1     case ICV_LAST:                                                             \
3300f426935Ssstefan1       break;                                                                   \
3310f426935Ssstefan1     }                                                                          \
3320f426935Ssstefan1   }
3330f426935Ssstefan1 #include "llvm/Frontend/OpenMP/OMPKinds.def"
3340f426935Ssstefan1   }
3350f426935Ssstefan1 
3367cfd267cSsstefan1   /// Returns true if the function declaration \p F matches the runtime
3377cfd267cSsstefan1   /// function types, that is, return type \p RTFRetType, and argument types
3387cfd267cSsstefan1   /// \p RTFArgTypes.
3397cfd267cSsstefan1   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
3407cfd267cSsstefan1                                   SmallVector<Type *, 8> &RTFArgTypes) {
3417cfd267cSsstefan1     // TODO: We should output information to the user (under debug output
3427cfd267cSsstefan1     //       and via remarks).
3437cfd267cSsstefan1 
3447cfd267cSsstefan1     if (!F)
3457cfd267cSsstefan1       return false;
3467cfd267cSsstefan1     if (F->getReturnType() != RTFRetType)
3477cfd267cSsstefan1       return false;
3487cfd267cSsstefan1     if (F->arg_size() != RTFArgTypes.size())
3497cfd267cSsstefan1       return false;
3507cfd267cSsstefan1 
3517cfd267cSsstefan1     auto RTFTyIt = RTFArgTypes.begin();
3527cfd267cSsstefan1     for (Argument &Arg : F->args()) {
3537cfd267cSsstefan1       if (Arg.getType() != *RTFTyIt)
3547cfd267cSsstefan1         return false;
3557cfd267cSsstefan1 
3567cfd267cSsstefan1       ++RTFTyIt;
3577cfd267cSsstefan1     }
3587cfd267cSsstefan1 
3597cfd267cSsstefan1     return true;
3607cfd267cSsstefan1   }
3617cfd267cSsstefan1 
362b726c557SJohannes Doerfert   // Helper to collect all uses of the declaration in the UsesMap.
363b8235d2bSsstefan1   unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
3647cfd267cSsstefan1     unsigned NumUses = 0;
3657cfd267cSsstefan1     if (!RFI.Declaration)
3667cfd267cSsstefan1       return NumUses;
3677cfd267cSsstefan1     OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
3687cfd267cSsstefan1 
369b8235d2bSsstefan1     if (CollectStats) {
3707cfd267cSsstefan1       NumOpenMPRuntimeFunctionsIdentified += 1;
3717cfd267cSsstefan1       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
372b8235d2bSsstefan1     }
3737cfd267cSsstefan1 
3747cfd267cSsstefan1     // TODO: We directly convert uses into proper calls and unknown uses.
3757cfd267cSsstefan1     for (Use &U : RFI.Declaration->uses()) {
3767cfd267cSsstefan1       if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
3777cfd267cSsstefan1         if (ModuleSlice.count(UserI->getFunction())) {
3787cfd267cSsstefan1           RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
3797cfd267cSsstefan1           ++NumUses;
3807cfd267cSsstefan1         }
3817cfd267cSsstefan1       } else {
3827cfd267cSsstefan1         RFI.getOrCreateUseVector(nullptr).push_back(&U);
3837cfd267cSsstefan1         ++NumUses;
3847cfd267cSsstefan1       }
3857cfd267cSsstefan1     }
3867cfd267cSsstefan1     return NumUses;
387b8235d2bSsstefan1   }
3887cfd267cSsstefan1 
38997517055SGiorgis Georgakoudis   // Helper function to recollect uses of a runtime function.
39097517055SGiorgis Georgakoudis   void recollectUsesForFunction(RuntimeFunction RTF) {
39197517055SGiorgis Georgakoudis     auto &RFI = RFIs[RTF];
392b8235d2bSsstefan1     RFI.clearUsesMap();
393b8235d2bSsstefan1     collectUses(RFI, /*CollectStats*/ false);
394b8235d2bSsstefan1   }
39597517055SGiorgis Georgakoudis 
39697517055SGiorgis Georgakoudis   // Helper function to recollect uses of all runtime functions.
39797517055SGiorgis Georgakoudis   void recollectUses() {
39897517055SGiorgis Georgakoudis     for (int Idx = 0; Idx < RFIs.size(); ++Idx)
39997517055SGiorgis Georgakoudis       recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
400b8235d2bSsstefan1   }
401b8235d2bSsstefan1 
402b8235d2bSsstefan1   /// Helper to initialize all runtime function information for those defined
403b8235d2bSsstefan1   /// in OpenMPKinds.def.
404b8235d2bSsstefan1   void initializeRuntimeFunctions() {
4057cfd267cSsstefan1     Module &M = *((*ModuleSlice.begin())->getParent());
4067cfd267cSsstefan1 
4076aab27baSsstefan1     // Helper macros for handling __VA_ARGS__ in OMP_RTL
4086aab27baSsstefan1 #define OMP_TYPE(VarName, ...)                                                 \
4096aab27baSsstefan1   Type *VarName = OMPBuilder.VarName;                                          \
4106aab27baSsstefan1   (void)VarName;
4116aab27baSsstefan1 
4126aab27baSsstefan1 #define OMP_ARRAY_TYPE(VarName, ...)                                           \
4136aab27baSsstefan1   ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
4146aab27baSsstefan1   (void)VarName##Ty;                                                           \
4156aab27baSsstefan1   PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
4166aab27baSsstefan1   (void)VarName##PtrTy;
4176aab27baSsstefan1 
4186aab27baSsstefan1 #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
4196aab27baSsstefan1   FunctionType *VarName = OMPBuilder.VarName;                                  \
4206aab27baSsstefan1   (void)VarName;                                                               \
4216aab27baSsstefan1   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
4226aab27baSsstefan1   (void)VarName##Ptr;
4236aab27baSsstefan1 
4246aab27baSsstefan1 #define OMP_STRUCT_TYPE(VarName, ...)                                          \
4256aab27baSsstefan1   StructType *VarName = OMPBuilder.VarName;                                    \
4266aab27baSsstefan1   (void)VarName;                                                               \
4276aab27baSsstefan1   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
4286aab27baSsstefan1   (void)VarName##Ptr;
4296aab27baSsstefan1 
4307cfd267cSsstefan1 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
4317cfd267cSsstefan1   {                                                                            \
4327cfd267cSsstefan1     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
4337cfd267cSsstefan1     Function *F = M.getFunction(_Name);                                        \
434eef6601bSJoseph Huber     RTLFunctions.insert(F);                                                    \
4356aab27baSsstefan1     if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
436d9659bf6SJohannes Doerfert       RuntimeFunctionIDMap[F] = _Enum;                                         \
43716206d17SJoseph Huber       F->removeFnAttr(Attribute::NoInline);                                    \
4387cfd267cSsstefan1       auto &RFI = RFIs[_Enum];                                                 \
4397cfd267cSsstefan1       RFI.Kind = _Enum;                                                        \
4407cfd267cSsstefan1       RFI.Name = _Name;                                                        \
4417cfd267cSsstefan1       RFI.IsVarArg = _IsVarArg;                                                \
4426aab27baSsstefan1       RFI.ReturnType = OMPBuilder._ReturnType;                                 \
4437cfd267cSsstefan1       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
4447cfd267cSsstefan1       RFI.Declaration = F;                                                     \
445b8235d2bSsstefan1       unsigned NumUses = collectUses(RFI);                                     \
4467cfd267cSsstefan1       (void)NumUses;                                                           \
4477cfd267cSsstefan1       LLVM_DEBUG({                                                             \
4487cfd267cSsstefan1         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
4497cfd267cSsstefan1                << " found\n";                                                  \
4507cfd267cSsstefan1         if (RFI.Declaration)                                                   \
4517cfd267cSsstefan1           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
4527cfd267cSsstefan1                  << RFI.getNumFunctionsWithUses()                              \
4537cfd267cSsstefan1                  << " different functions.\n";                                 \
4547cfd267cSsstefan1       });                                                                      \
4557cfd267cSsstefan1     }                                                                          \
4567cfd267cSsstefan1   }
4577cfd267cSsstefan1 #include "llvm/Frontend/OpenMP/OMPKinds.def"
4587cfd267cSsstefan1 
4597cfd267cSsstefan1     // TODO: We should attach the attributes defined in OMPKinds.def.
4607cfd267cSsstefan1   }
461e8039ad4SJohannes Doerfert 
462e8039ad4SJohannes Doerfert   /// Collection of known kernels (\see Kernel) in the module.
463e8039ad4SJohannes Doerfert   SmallPtrSetImpl<Kernel> &Kernels;
464eef6601bSJoseph Huber 
  /// Collection of known OpenMP runtime functions.
466eef6601bSJoseph Huber   DenseSet<const Function *> RTLFunctions;
4677cfd267cSsstefan1 };
4687cfd267cSsstefan1 
469d9659bf6SJohannes Doerfert template <typename Ty, bool InsertInvalidates = true>
4701a7f7790SShilei Tian struct BooleanStateWithSetVector : public BooleanState {
4711a7f7790SShilei Tian   bool contains(const Ty &Elem) const { return Set.contains(Elem); }
4721a7f7790SShilei Tian   bool insert(const Ty &Elem) {
473d9659bf6SJohannes Doerfert     if (InsertInvalidates)
474d9659bf6SJohannes Doerfert       BooleanState::indicatePessimisticFixpoint();
475d9659bf6SJohannes Doerfert     return Set.insert(Elem);
476d9659bf6SJohannes Doerfert   }
477d9659bf6SJohannes Doerfert 
4781a7f7790SShilei Tian   const Ty &operator[](int Idx) const { return Set[Idx]; }
4791a7f7790SShilei Tian   bool operator==(const BooleanStateWithSetVector &RHS) const {
480d9659bf6SJohannes Doerfert     return BooleanState::operator==(RHS) && Set == RHS.Set;
481d9659bf6SJohannes Doerfert   }
4821a7f7790SShilei Tian   bool operator!=(const BooleanStateWithSetVector &RHS) const {
483d9659bf6SJohannes Doerfert     return !(*this == RHS);
484d9659bf6SJohannes Doerfert   }
485d9659bf6SJohannes Doerfert 
486d9659bf6SJohannes Doerfert   bool empty() const { return Set.empty(); }
487d9659bf6SJohannes Doerfert   size_t size() const { return Set.size(); }
488d9659bf6SJohannes Doerfert 
489d9659bf6SJohannes Doerfert   /// "Clamp" this state with \p RHS.
4901a7f7790SShilei Tian   BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
491d9659bf6SJohannes Doerfert     BooleanState::operator^=(RHS);
492d9659bf6SJohannes Doerfert     Set.insert(RHS.Set.begin(), RHS.Set.end());
493d9659bf6SJohannes Doerfert     return *this;
494d9659bf6SJohannes Doerfert   }
495d9659bf6SJohannes Doerfert 
496d9659bf6SJohannes Doerfert private:
497d9659bf6SJohannes Doerfert   /// A set to keep track of elements.
4981a7f7790SShilei Tian   SetVector<Ty> Set;
499d9659bf6SJohannes Doerfert 
500d9659bf6SJohannes Doerfert public:
501d9659bf6SJohannes Doerfert   typename decltype(Set)::iterator begin() { return Set.begin(); }
502d9659bf6SJohannes Doerfert   typename decltype(Set)::iterator end() { return Set.end(); }
503d9659bf6SJohannes Doerfert   typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
504d9659bf6SJohannes Doerfert   typename decltype(Set)::const_iterator end() const { return Set.end(); }
505d9659bf6SJohannes Doerfert };
506d9659bf6SJohannes Doerfert 
5071a7f7790SShilei Tian template <typename Ty, bool InsertInvalidates = true>
5081a7f7790SShilei Tian using BooleanStateWithPtrSetVector =
5091a7f7790SShilei Tian     BooleanStateWithSetVector<Ty *, InsertInvalidates>;
5101a7f7790SShilei Tian 
511d9659bf6SJohannes Doerfert struct KernelInfoState : AbstractState {
512d9659bf6SJohannes Doerfert   /// Flag to track if we reached a fixpoint.
513d9659bf6SJohannes Doerfert   bool IsAtFixpoint = false;
514d9659bf6SJohannes Doerfert 
515d9659bf6SJohannes Doerfert   /// The parallel regions (identified by the outlined parallel functions) that
516d9659bf6SJohannes Doerfert   /// can be reached from the associated function.
517d9659bf6SJohannes Doerfert   BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
518d9659bf6SJohannes Doerfert       ReachedKnownParallelRegions;
519d9659bf6SJohannes Doerfert 
520d9659bf6SJohannes Doerfert   /// State to track what parallel region we might reach.
521d9659bf6SJohannes Doerfert   BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
522d9659bf6SJohannes Doerfert 
523514c033dSJohannes Doerfert   /// State to track if we are in SPMD-mode, assumed or know, and why we decided
524e8439ec8SGiorgis Georgakoudis   /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
525e8439ec8SGiorgis Georgakoudis   /// false.
52629a3e3ddSGiorgis Georgakoudis   BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
527514c033dSJohannes Doerfert 
528d9659bf6SJohannes Doerfert   /// The __kmpc_target_init call in this kernel, if any. If we find more than
529d9659bf6SJohannes Doerfert   /// one we abort as the kernel is malformed.
530d9659bf6SJohannes Doerfert   CallBase *KernelInitCB = nullptr;
531d9659bf6SJohannes Doerfert 
532d9659bf6SJohannes Doerfert   /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
533d9659bf6SJohannes Doerfert   /// one we abort as the kernel is malformed.
534d9659bf6SJohannes Doerfert   CallBase *KernelDeinitCB = nullptr;
535d9659bf6SJohannes Doerfert 
536ca662297SShilei Tian   /// Flag to indicate if the associated function is a kernel entry.
537ca662297SShilei Tian   bool IsKernelEntry = false;
538ca662297SShilei Tian 
539ca662297SShilei Tian   /// State to track what kernel entries can reach the associated function.
540ca662297SShilei Tian   BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
541ca662297SShilei Tian 
542e97e0a4fSShilei Tian   /// State to indicate if we can track parallel level of the associated
543e97e0a4fSShilei Tian   /// function. We will give up tracking if we encounter unknown caller or the
544e97e0a4fSShilei Tian   /// caller is __kmpc_parallel_51.
545e97e0a4fSShilei Tian   BooleanStateWithSetVector<uint8_t> ParallelLevels;
546e97e0a4fSShilei Tian 
547d9659bf6SJohannes Doerfert   /// Abstract State interface
548d9659bf6SJohannes Doerfert   ///{
549d9659bf6SJohannes Doerfert 
550d9659bf6SJohannes Doerfert   KernelInfoState() {}
551d9659bf6SJohannes Doerfert   KernelInfoState(bool BestState) {
552d9659bf6SJohannes Doerfert     if (!BestState)
553d9659bf6SJohannes Doerfert       indicatePessimisticFixpoint();
554d9659bf6SJohannes Doerfert   }
555d9659bf6SJohannes Doerfert 
556d9659bf6SJohannes Doerfert   /// See AbstractState::isValidState(...)
557d9659bf6SJohannes Doerfert   bool isValidState() const override { return true; }
558d9659bf6SJohannes Doerfert 
559d9659bf6SJohannes Doerfert   /// See AbstractState::isAtFixpoint(...)
560d9659bf6SJohannes Doerfert   bool isAtFixpoint() const override { return IsAtFixpoint; }
561d9659bf6SJohannes Doerfert 
562d9659bf6SJohannes Doerfert   /// See AbstractState::indicatePessimisticFixpoint(...)
563d9659bf6SJohannes Doerfert   ChangeStatus indicatePessimisticFixpoint() override {
564d9659bf6SJohannes Doerfert     IsAtFixpoint = true;
565514c033dSJohannes Doerfert     SPMDCompatibilityTracker.indicatePessimisticFixpoint();
566d9659bf6SJohannes Doerfert     ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
567d9659bf6SJohannes Doerfert     return ChangeStatus::CHANGED;
568d9659bf6SJohannes Doerfert   }
569d9659bf6SJohannes Doerfert 
570d9659bf6SJohannes Doerfert   /// See AbstractState::indicateOptimisticFixpoint(...)
571d9659bf6SJohannes Doerfert   ChangeStatus indicateOptimisticFixpoint() override {
572d9659bf6SJohannes Doerfert     IsAtFixpoint = true;
573d9659bf6SJohannes Doerfert     return ChangeStatus::UNCHANGED;
574d9659bf6SJohannes Doerfert   }
575d9659bf6SJohannes Doerfert 
576d9659bf6SJohannes Doerfert   /// Return the assumed state
577d9659bf6SJohannes Doerfert   KernelInfoState &getAssumed() { return *this; }
578d9659bf6SJohannes Doerfert   const KernelInfoState &getAssumed() const { return *this; }
579d9659bf6SJohannes Doerfert 
580d9659bf6SJohannes Doerfert   bool operator==(const KernelInfoState &RHS) const {
581514c033dSJohannes Doerfert     if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
582514c033dSJohannes Doerfert       return false;
583d9659bf6SJohannes Doerfert     if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
584d9659bf6SJohannes Doerfert       return false;
585d9659bf6SJohannes Doerfert     if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
586d9659bf6SJohannes Doerfert       return false;
587ca662297SShilei Tian     if (ReachingKernelEntries != RHS.ReachingKernelEntries)
588ca662297SShilei Tian       return false;
589d9659bf6SJohannes Doerfert     return true;
590d9659bf6SJohannes Doerfert   }
591d9659bf6SJohannes Doerfert 
592d9659bf6SJohannes Doerfert   /// Return empty set as the best state of potential values.
593d9659bf6SJohannes Doerfert   static KernelInfoState getBestState() { return KernelInfoState(true); }
594d9659bf6SJohannes Doerfert 
595d9659bf6SJohannes Doerfert   static KernelInfoState getBestState(KernelInfoState &KIS) {
596d9659bf6SJohannes Doerfert     return getBestState();
597d9659bf6SJohannes Doerfert   }
598d9659bf6SJohannes Doerfert 
599d9659bf6SJohannes Doerfert   /// Return full set as the worst state of potential values.
600d9659bf6SJohannes Doerfert   static KernelInfoState getWorstState() { return KernelInfoState(false); }
601d9659bf6SJohannes Doerfert 
602d9659bf6SJohannes Doerfert   /// "Clamp" this state with \p KIS.
603d9659bf6SJohannes Doerfert   KernelInfoState operator^=(const KernelInfoState &KIS) {
604d9659bf6SJohannes Doerfert     // Do not merge two different _init and _deinit call sites.
605d9659bf6SJohannes Doerfert     if (KIS.KernelInitCB) {
606d9659bf6SJohannes Doerfert       if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
607d9659bf6SJohannes Doerfert         indicatePessimisticFixpoint();
608d9659bf6SJohannes Doerfert       KernelInitCB = KIS.KernelInitCB;
609d9659bf6SJohannes Doerfert     }
610d9659bf6SJohannes Doerfert     if (KIS.KernelDeinitCB) {
611d9659bf6SJohannes Doerfert       if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
612d9659bf6SJohannes Doerfert         indicatePessimisticFixpoint();
613d9659bf6SJohannes Doerfert       KernelDeinitCB = KIS.KernelDeinitCB;
614d9659bf6SJohannes Doerfert     }
615514c033dSJohannes Doerfert     SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
616d9659bf6SJohannes Doerfert     ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
617d9659bf6SJohannes Doerfert     ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
618d9659bf6SJohannes Doerfert     return *this;
619d9659bf6SJohannes Doerfert   }
620d9659bf6SJohannes Doerfert 
621d9659bf6SJohannes Doerfert   KernelInfoState operator&=(const KernelInfoState &KIS) {
622d9659bf6SJohannes Doerfert     return (*this ^= KIS);
623d9659bf6SJohannes Doerfert   }
624d9659bf6SJohannes Doerfert 
625d9659bf6SJohannes Doerfert   ///}
626d9659bf6SJohannes Doerfert };
627d9659bf6SJohannes Doerfert 
6288931add6SHamilton Tobon Mosquera /// Used to map the values physically (in the IR) stored in an offload
6298931add6SHamilton Tobon Mosquera /// array, to a vector in memory.
6308931add6SHamilton Tobon Mosquera struct OffloadArray {
6318931add6SHamilton Tobon Mosquera   /// Physical array (in the IR).
6328931add6SHamilton Tobon Mosquera   AllocaInst *Array = nullptr;
6338931add6SHamilton Tobon Mosquera   /// Mapped values.
6348931add6SHamilton Tobon Mosquera   SmallVector<Value *, 8> StoredValues;
6358931add6SHamilton Tobon Mosquera   /// Last stores made in the offload array.
6368931add6SHamilton Tobon Mosquera   SmallVector<StoreInst *, 8> LastAccesses;
6378931add6SHamilton Tobon Mosquera 
6388931add6SHamilton Tobon Mosquera   OffloadArray() = default;
6398931add6SHamilton Tobon Mosquera 
6408931add6SHamilton Tobon Mosquera   /// Initializes the OffloadArray with the values stored in \p Array before
6418931add6SHamilton Tobon Mosquera   /// instruction \p Before is reached. Returns false if the initialization
6428931add6SHamilton Tobon Mosquera   /// fails.
6438931add6SHamilton Tobon Mosquera   /// This MUST be used immediately after the construction of the object.
6448931add6SHamilton Tobon Mosquera   bool initialize(AllocaInst &Array, Instruction &Before) {
6458931add6SHamilton Tobon Mosquera     if (!Array.getAllocatedType()->isArrayTy())
6468931add6SHamilton Tobon Mosquera       return false;
6478931add6SHamilton Tobon Mosquera 
6488931add6SHamilton Tobon Mosquera     if (!getValues(Array, Before))
6498931add6SHamilton Tobon Mosquera       return false;
6508931add6SHamilton Tobon Mosquera 
6518931add6SHamilton Tobon Mosquera     this->Array = &Array;
6528931add6SHamilton Tobon Mosquera     return true;
6538931add6SHamilton Tobon Mosquera   }
6548931add6SHamilton Tobon Mosquera 
655da8bec47SJoseph Huber   static const unsigned DeviceIDArgNum = 1;
656da8bec47SJoseph Huber   static const unsigned BasePtrsArgNum = 3;
657da8bec47SJoseph Huber   static const unsigned PtrsArgNum = 4;
658da8bec47SJoseph Huber   static const unsigned SizesArgNum = 5;
6591d3d9b9cSHamilton Tobon Mosquera 
6608931add6SHamilton Tobon Mosquera private:
6618931add6SHamilton Tobon Mosquera   /// Traverses the BasicBlock where \p Array is, collecting the stores made to
6628931add6SHamilton Tobon Mosquera   /// \p Array, leaving StoredValues with the values stored before the
6638931add6SHamilton Tobon Mosquera   /// instruction \p Before is reached.
6648931add6SHamilton Tobon Mosquera   bool getValues(AllocaInst &Array, Instruction &Before) {
6658931add6SHamilton Tobon Mosquera     // Initialize container.
666d08d490aSJohannes Doerfert     const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
6678931add6SHamilton Tobon Mosquera     StoredValues.assign(NumValues, nullptr);
6688931add6SHamilton Tobon Mosquera     LastAccesses.assign(NumValues, nullptr);
6698931add6SHamilton Tobon Mosquera 
6708931add6SHamilton Tobon Mosquera     // TODO: This assumes the instruction \p Before is in the same
6718931add6SHamilton Tobon Mosquera     //  BasicBlock as Array. Make it general, for any control flow graph.
6728931add6SHamilton Tobon Mosquera     BasicBlock *BB = Array.getParent();
6738931add6SHamilton Tobon Mosquera     if (BB != Before.getParent())
6748931add6SHamilton Tobon Mosquera       return false;
6758931add6SHamilton Tobon Mosquera 
6768931add6SHamilton Tobon Mosquera     const DataLayout &DL = Array.getModule()->getDataLayout();
6778931add6SHamilton Tobon Mosquera     const unsigned int PointerSize = DL.getPointerSize();
6788931add6SHamilton Tobon Mosquera 
6798931add6SHamilton Tobon Mosquera     for (Instruction &I : *BB) {
6808931add6SHamilton Tobon Mosquera       if (&I == &Before)
6818931add6SHamilton Tobon Mosquera         break;
6828931add6SHamilton Tobon Mosquera 
6838931add6SHamilton Tobon Mosquera       if (!isa<StoreInst>(&I))
6848931add6SHamilton Tobon Mosquera         continue;
6858931add6SHamilton Tobon Mosquera 
6868931add6SHamilton Tobon Mosquera       auto *S = cast<StoreInst>(&I);
6878931add6SHamilton Tobon Mosquera       int64_t Offset = -1;
688d08d490aSJohannes Doerfert       auto *Dst =
689d08d490aSJohannes Doerfert           GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
6908931add6SHamilton Tobon Mosquera       if (Dst == &Array) {
6918931add6SHamilton Tobon Mosquera         int64_t Idx = Offset / PointerSize;
6928931add6SHamilton Tobon Mosquera         StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
6938931add6SHamilton Tobon Mosquera         LastAccesses[Idx] = S;
6948931add6SHamilton Tobon Mosquera       }
6958931add6SHamilton Tobon Mosquera     }
6968931add6SHamilton Tobon Mosquera 
6978931add6SHamilton Tobon Mosquera     return isFilled();
6988931add6SHamilton Tobon Mosquera   }
6998931add6SHamilton Tobon Mosquera 
7008931add6SHamilton Tobon Mosquera   /// Returns true if all values in StoredValues and
7018931add6SHamilton Tobon Mosquera   /// LastAccesses are not nullptrs.
7028931add6SHamilton Tobon Mosquera   bool isFilled() {
7038931add6SHamilton Tobon Mosquera     const unsigned NumValues = StoredValues.size();
7048931add6SHamilton Tobon Mosquera     for (unsigned I = 0; I < NumValues; ++I) {
7058931add6SHamilton Tobon Mosquera       if (!StoredValues[I] || !LastAccesses[I])
7068931add6SHamilton Tobon Mosquera         return false;
7078931add6SHamilton Tobon Mosquera     }
7088931add6SHamilton Tobon Mosquera 
7098931add6SHamilton Tobon Mosquera     return true;
7108931add6SHamilton Tobon Mosquera   }
7118931add6SHamilton Tobon Mosquera };
7128931add6SHamilton Tobon Mosquera 
7137cfd267cSsstefan1 struct OpenMPOpt {
7147cfd267cSsstefan1 
7157cfd267cSsstefan1   using OptimizationRemarkGetter =
7167cfd267cSsstefan1       function_ref<OptimizationRemarkEmitter &(Function *)>;
7177cfd267cSsstefan1 
7187cfd267cSsstefan1   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
7197cfd267cSsstefan1             OptimizationRemarkGetter OREGetter,
720b8235d2bSsstefan1             OMPInformationCache &OMPInfoCache, Attributor &A)
72177b79d79SMehdi Amini       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
722b8235d2bSsstefan1         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
7237cfd267cSsstefan1 
724a2281419SJoseph Huber   /// Check if any remarks are enabled for openmp-opt
725a2281419SJoseph Huber   bool remarksEnabled() {
726a2281419SJoseph Huber     auto &Ctx = M.getContext();
727a2281419SJoseph Huber     return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
728a2281419SJoseph Huber   }
729a2281419SJoseph Huber 
7309548b74aSJohannes Doerfert   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
731b2ad63d3SJoseph Huber   bool run(bool IsModulePass) {
73254bd3751SJohannes Doerfert     if (SCC.empty())
73354bd3751SJohannes Doerfert       return false;
73454bd3751SJohannes Doerfert 
7359548b74aSJohannes Doerfert     bool Changed = false;
7369548b74aSJohannes Doerfert 
7379548b74aSJohannes Doerfert     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
73877b79d79SMehdi Amini                       << " functions in a slice with "
73977b79d79SMehdi Amini                       << OMPInfoCache.ModuleSlice.size() << " functions\n");
7409548b74aSJohannes Doerfert 
741b2ad63d3SJoseph Huber     if (IsModulePass) {
742d9659bf6SJohannes Doerfert       Changed |= runAttributor(IsModulePass);
74318283125SJoseph Huber 
7446fc51c9fSJoseph Huber       // Recollect uses, in case Attributor deleted any.
7456fc51c9fSJoseph Huber       OMPInfoCache.recollectUses();
7466fc51c9fSJoseph Huber 
747be2b5696SJohannes Doerfert       // TODO: This should be folded into buildCustomStateMachine.
748be2b5696SJohannes Doerfert       Changed |= rewriteDeviceCodeStateMachine();
749be2b5696SJohannes Doerfert 
750b2ad63d3SJoseph Huber       if (remarksEnabled())
751b2ad63d3SJoseph Huber         analysisGlobalization();
752b2ad63d3SJoseph Huber     } else {
753e8039ad4SJohannes Doerfert       if (PrintICVValues)
754e8039ad4SJohannes Doerfert         printICVs();
755e8039ad4SJohannes Doerfert       if (PrintOpenMPKernels)
756e8039ad4SJohannes Doerfert         printKernels();
757e8039ad4SJohannes Doerfert 
758d9659bf6SJohannes Doerfert       Changed |= runAttributor(IsModulePass);
759e8039ad4SJohannes Doerfert 
760e8039ad4SJohannes Doerfert       // Recollect uses, in case Attributor deleted any.
761e8039ad4SJohannes Doerfert       OMPInfoCache.recollectUses();
762e8039ad4SJohannes Doerfert 
763e8039ad4SJohannes Doerfert       Changed |= deleteParallelRegions();
764d9659bf6SJohannes Doerfert 
765496f8e5bSHamilton Tobon Mosquera       if (HideMemoryTransferLatency)
766496f8e5bSHamilton Tobon Mosquera         Changed |= hideMemTransfersLatency();
7673a6bfcf2SGiorgis Georgakoudis       Changed |= deduplicateRuntimeCalls();
7683a6bfcf2SGiorgis Georgakoudis       if (EnableParallelRegionMerging) {
7693a6bfcf2SGiorgis Georgakoudis         if (mergeParallelRegions()) {
7703a6bfcf2SGiorgis Georgakoudis           deduplicateRuntimeCalls();
7713a6bfcf2SGiorgis Georgakoudis           Changed = true;
7723a6bfcf2SGiorgis Georgakoudis         }
7733a6bfcf2SGiorgis Georgakoudis       }
774b2ad63d3SJoseph Huber     }
775e8039ad4SJohannes Doerfert 
776e8039ad4SJohannes Doerfert     return Changed;
777e8039ad4SJohannes Doerfert   }
778e8039ad4SJohannes Doerfert 
7790f426935Ssstefan1   /// Print initial ICV values for testing.
7800f426935Ssstefan1   /// FIXME: This should be done from the Attributor once it is added.
781e8039ad4SJohannes Doerfert   void printICVs() const {
782cb9cfa0dSsstefan1     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
783cb9cfa0dSsstefan1                                  ICV_proc_bind};
7840f426935Ssstefan1 
7850f426935Ssstefan1     for (Function *F : OMPInfoCache.ModuleSlice) {
7860f426935Ssstefan1       for (auto ICV : ICVs) {
7870f426935Ssstefan1         auto ICVInfo = OMPInfoCache.ICVs[ICV];
7882db182ffSJoseph Huber         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
7892db182ffSJoseph Huber           return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
7900f426935Ssstefan1                      << " Value: "
7910f426935Ssstefan1                      << (ICVInfo.InitValue
79261cdaf66SSimon Pilgrim                              ? toString(ICVInfo.InitValue->getValue(), 10, true)
7930f426935Ssstefan1                              : "IMPLEMENTATION_DEFINED");
7940f426935Ssstefan1         };
7950f426935Ssstefan1 
7962db182ffSJoseph Huber         emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
7970f426935Ssstefan1       }
7980f426935Ssstefan1     }
7990f426935Ssstefan1   }
8000f426935Ssstefan1 
801e8039ad4SJohannes Doerfert   /// Print OpenMP GPU kernels for testing.
802e8039ad4SJohannes Doerfert   void printKernels() const {
803e8039ad4SJohannes Doerfert     for (Function *F : SCC) {
804e8039ad4SJohannes Doerfert       if (!OMPInfoCache.Kernels.count(F))
805e8039ad4SJohannes Doerfert         continue;
806b8235d2bSsstefan1 
8072db182ffSJoseph Huber       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
8082db182ffSJoseph Huber         return ORA << "OpenMP GPU kernel "
809e8039ad4SJohannes Doerfert                    << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
810e8039ad4SJohannes Doerfert       };
811b8235d2bSsstefan1 
8122db182ffSJoseph Huber       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
813e8039ad4SJohannes Doerfert     }
8149548b74aSJohannes Doerfert   }
8159548b74aSJohannes Doerfert 
8167cfd267cSsstefan1   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
8177cfd267cSsstefan1   /// given it has to be the callee or a nullptr is returned.
8187cfd267cSsstefan1   static CallInst *getCallIfRegularCall(
8197cfd267cSsstefan1       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
8207cfd267cSsstefan1     CallInst *CI = dyn_cast<CallInst>(U.getUser());
8217cfd267cSsstefan1     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
822c4b1fe05SJohannes Doerfert         (!RFI ||
823c4b1fe05SJohannes Doerfert          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
8247cfd267cSsstefan1       return CI;
8257cfd267cSsstefan1     return nullptr;
8267cfd267cSsstefan1   }
8277cfd267cSsstefan1 
8287cfd267cSsstefan1   /// Return the call if \p V is a regular call. If \p RFI is given it has to be
8297cfd267cSsstefan1   /// the callee or a nullptr is returned.
8307cfd267cSsstefan1   static CallInst *getCallIfRegularCall(
8317cfd267cSsstefan1       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
8327cfd267cSsstefan1     CallInst *CI = dyn_cast<CallInst>(&V);
8337cfd267cSsstefan1     if (CI && !CI->hasOperandBundles() &&
834c4b1fe05SJohannes Doerfert         (!RFI ||
835c4b1fe05SJohannes Doerfert          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
8367cfd267cSsstefan1       return CI;
8377cfd267cSsstefan1     return nullptr;
8387cfd267cSsstefan1   }
8397cfd267cSsstefan1 
8409548b74aSJohannes Doerfert private:
8413a6bfcf2SGiorgis Georgakoudis   /// Merge parallel regions when it is safe.
8423a6bfcf2SGiorgis Georgakoudis   bool mergeParallelRegions() {
8433a6bfcf2SGiorgis Georgakoudis     const unsigned CallbackCalleeOperand = 2;
8443a6bfcf2SGiorgis Georgakoudis     const unsigned CallbackFirstArgOperand = 3;
8453a6bfcf2SGiorgis Georgakoudis     using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
8463a6bfcf2SGiorgis Georgakoudis 
8473a6bfcf2SGiorgis Georgakoudis     // Check if there are any __kmpc_fork_call calls to merge.
8483a6bfcf2SGiorgis Georgakoudis     OMPInformationCache::RuntimeFunctionInfo &RFI =
8493a6bfcf2SGiorgis Georgakoudis         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
8503a6bfcf2SGiorgis Georgakoudis 
8513a6bfcf2SGiorgis Georgakoudis     if (!RFI.Declaration)
8523a6bfcf2SGiorgis Georgakoudis       return false;
8533a6bfcf2SGiorgis Georgakoudis 
85497517055SGiorgis Georgakoudis     // Unmergable calls that prevent merging a parallel region.
85597517055SGiorgis Georgakoudis     OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
85697517055SGiorgis Georgakoudis         OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
85797517055SGiorgis Georgakoudis         OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
85897517055SGiorgis Georgakoudis     };
8593a6bfcf2SGiorgis Georgakoudis 
8603a6bfcf2SGiorgis Georgakoudis     bool Changed = false;
8613a6bfcf2SGiorgis Georgakoudis     LoopInfo *LI = nullptr;
8623a6bfcf2SGiorgis Georgakoudis     DominatorTree *DT = nullptr;
8633a6bfcf2SGiorgis Georgakoudis 
8643a6bfcf2SGiorgis Georgakoudis     SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
8653a6bfcf2SGiorgis Georgakoudis 
8663a6bfcf2SGiorgis Georgakoudis     BasicBlock *StartBB = nullptr, *EndBB = nullptr;
8673a6bfcf2SGiorgis Georgakoudis     auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
8683a6bfcf2SGiorgis Georgakoudis                          BasicBlock &ContinuationIP) {
8693a6bfcf2SGiorgis Georgakoudis       BasicBlock *CGStartBB = CodeGenIP.getBlock();
8703a6bfcf2SGiorgis Georgakoudis       BasicBlock *CGEndBB =
8713a6bfcf2SGiorgis Georgakoudis           SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
8723a6bfcf2SGiorgis Georgakoudis       assert(StartBB != nullptr && "StartBB should not be null");
8733a6bfcf2SGiorgis Georgakoudis       CGStartBB->getTerminator()->setSuccessor(0, StartBB);
8743a6bfcf2SGiorgis Georgakoudis       assert(EndBB != nullptr && "EndBB should not be null");
8753a6bfcf2SGiorgis Georgakoudis       EndBB->getTerminator()->setSuccessor(0, CGEndBB);
8763a6bfcf2SGiorgis Georgakoudis     };
8773a6bfcf2SGiorgis Georgakoudis 
878240dd924SAlex Zinenko     auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
879240dd924SAlex Zinenko                       Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
880240dd924SAlex Zinenko       ReplacementValue = &Inner;
8813a6bfcf2SGiorgis Georgakoudis       return CodeGenIP;
8823a6bfcf2SGiorgis Georgakoudis     };
8833a6bfcf2SGiorgis Georgakoudis 
8843a6bfcf2SGiorgis Georgakoudis     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
8853a6bfcf2SGiorgis Georgakoudis 
88697517055SGiorgis Georgakoudis     /// Create a sequential execution region within a merged parallel region,
88797517055SGiorgis Georgakoudis     /// encapsulated in a master construct with a barrier for synchronization.
88897517055SGiorgis Georgakoudis     auto CreateSequentialRegion = [&](Function *OuterFn,
88997517055SGiorgis Georgakoudis                                       BasicBlock *OuterPredBB,
89097517055SGiorgis Georgakoudis                                       Instruction *SeqStartI,
89197517055SGiorgis Georgakoudis                                       Instruction *SeqEndI) {
89297517055SGiorgis Georgakoudis       // Isolate the instructions of the sequential region to a separate
89397517055SGiorgis Georgakoudis       // block.
89497517055SGiorgis Georgakoudis       BasicBlock *ParentBB = SeqStartI->getParent();
89597517055SGiorgis Georgakoudis       BasicBlock *SeqEndBB =
89697517055SGiorgis Georgakoudis           SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
89797517055SGiorgis Georgakoudis       BasicBlock *SeqAfterBB =
89897517055SGiorgis Georgakoudis           SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
89997517055SGiorgis Georgakoudis       BasicBlock *SeqStartBB =
90097517055SGiorgis Georgakoudis           SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
90197517055SGiorgis Georgakoudis 
90297517055SGiorgis Georgakoudis       assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
90397517055SGiorgis Georgakoudis              "Expected a different CFG");
90497517055SGiorgis Georgakoudis       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
90597517055SGiorgis Georgakoudis       ParentBB->getTerminator()->eraseFromParent();
90697517055SGiorgis Georgakoudis 
90797517055SGiorgis Georgakoudis       auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
90897517055SGiorgis Georgakoudis                            BasicBlock &ContinuationIP) {
90997517055SGiorgis Georgakoudis         BasicBlock *CGStartBB = CodeGenIP.getBlock();
91097517055SGiorgis Georgakoudis         BasicBlock *CGEndBB =
91197517055SGiorgis Georgakoudis             SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
91297517055SGiorgis Georgakoudis         assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
91397517055SGiorgis Georgakoudis         CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
91497517055SGiorgis Georgakoudis         assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
91597517055SGiorgis Georgakoudis         SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
91697517055SGiorgis Georgakoudis       };
91797517055SGiorgis Georgakoudis       auto FiniCB = [&](InsertPointTy CodeGenIP) {};
91897517055SGiorgis Georgakoudis 
91997517055SGiorgis Georgakoudis       // Find outputs from the sequential region to outside users and
92097517055SGiorgis Georgakoudis       // broadcast their values to them.
92197517055SGiorgis Georgakoudis       for (Instruction &I : *SeqStartBB) {
92297517055SGiorgis Georgakoudis         SmallPtrSet<Instruction *, 4> OutsideUsers;
92397517055SGiorgis Georgakoudis         for (User *Usr : I.users()) {
92497517055SGiorgis Georgakoudis           Instruction &UsrI = *cast<Instruction>(Usr);
92597517055SGiorgis Georgakoudis           // Ignore outputs to LT intrinsics, code extraction for the merged
92697517055SGiorgis Georgakoudis           // parallel region will fix them.
92797517055SGiorgis Georgakoudis           if (UsrI.isLifetimeStartOrEnd())
92897517055SGiorgis Georgakoudis             continue;
92997517055SGiorgis Georgakoudis 
93097517055SGiorgis Georgakoudis           if (UsrI.getParent() != SeqStartBB)
93197517055SGiorgis Georgakoudis             OutsideUsers.insert(&UsrI);
93297517055SGiorgis Georgakoudis         }
93397517055SGiorgis Georgakoudis 
93497517055SGiorgis Georgakoudis         if (OutsideUsers.empty())
93597517055SGiorgis Georgakoudis           continue;
93697517055SGiorgis Georgakoudis 
93797517055SGiorgis Georgakoudis         // Emit an alloca in the outer region to store the broadcasted
93897517055SGiorgis Georgakoudis         // value.
93997517055SGiorgis Georgakoudis         const DataLayout &DL = M.getDataLayout();
94097517055SGiorgis Georgakoudis         AllocaInst *AllocaI = new AllocaInst(
94197517055SGiorgis Georgakoudis             I.getType(), DL.getAllocaAddrSpace(), nullptr,
94297517055SGiorgis Georgakoudis             I.getName() + ".seq.output.alloc", &OuterFn->front().front());
94397517055SGiorgis Georgakoudis 
94497517055SGiorgis Georgakoudis         // Emit a store instruction in the sequential BB to update the
94597517055SGiorgis Georgakoudis         // value.
94697517055SGiorgis Georgakoudis         new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
94797517055SGiorgis Georgakoudis 
94897517055SGiorgis Georgakoudis         // Emit a load instruction and replace the use of the output value
94997517055SGiorgis Georgakoudis         // with it.
95097517055SGiorgis Georgakoudis         for (Instruction *UsrI : OutsideUsers) {
9515b70c12fSJohannes Doerfert           LoadInst *LoadI = new LoadInst(
9525b70c12fSJohannes Doerfert               I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
95397517055SGiorgis Georgakoudis           UsrI->replaceUsesOfWith(&I, LoadI);
95497517055SGiorgis Georgakoudis         }
95597517055SGiorgis Georgakoudis       }
95697517055SGiorgis Georgakoudis 
95797517055SGiorgis Georgakoudis       OpenMPIRBuilder::LocationDescription Loc(
95897517055SGiorgis Georgakoudis           InsertPointTy(ParentBB, ParentBB->end()), DL);
95997517055SGiorgis Georgakoudis       InsertPointTy SeqAfterIP =
96097517055SGiorgis Georgakoudis           OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
96197517055SGiorgis Georgakoudis 
96297517055SGiorgis Georgakoudis       OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
96397517055SGiorgis Georgakoudis 
96497517055SGiorgis Georgakoudis       BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
96597517055SGiorgis Georgakoudis 
96697517055SGiorgis Georgakoudis       LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
96797517055SGiorgis Georgakoudis                         << "\n");
96897517055SGiorgis Georgakoudis     };
96997517055SGiorgis Georgakoudis 
9703a6bfcf2SGiorgis Georgakoudis     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
9713a6bfcf2SGiorgis Georgakoudis     // contained in BB and only separated by instructions that can be
9723a6bfcf2SGiorgis Georgakoudis     // redundantly executed in parallel. The block BB is split before the first
9733a6bfcf2SGiorgis Georgakoudis     // call (in MergableCIs) and after the last so the entire region we merge
9743a6bfcf2SGiorgis Georgakoudis     // into a single parallel region is contained in a single basic block
9753a6bfcf2SGiorgis Georgakoudis     // without any other instructions. We use the OpenMPIRBuilder to outline
9763a6bfcf2SGiorgis Georgakoudis     // that block and call the resulting function via __kmpc_fork_call.
9773a6bfcf2SGiorgis Georgakoudis     auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
9783a6bfcf2SGiorgis Georgakoudis       // TODO: Change the interface to allow single CIs expanded, e.g, to
9793a6bfcf2SGiorgis Georgakoudis       // include an outer loop.
9803a6bfcf2SGiorgis Georgakoudis       assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
9813a6bfcf2SGiorgis Georgakoudis 
9823a6bfcf2SGiorgis Georgakoudis       auto Remark = [&](OptimizationRemark OR) {
983eef6601bSJoseph Huber         OR << "Parallel region merged with parallel region"
984eef6601bSJoseph Huber            << (MergableCIs.size() > 2 ? "s" : "") << " at ";
98523b0ab2aSKazu Hirata         for (auto *CI : llvm::drop_begin(MergableCIs)) {
9863a6bfcf2SGiorgis Georgakoudis           OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
9873a6bfcf2SGiorgis Georgakoudis           if (CI != MergableCIs.back())
9883a6bfcf2SGiorgis Georgakoudis             OR << ", ";
9893a6bfcf2SGiorgis Georgakoudis         }
990eef6601bSJoseph Huber         return OR << ".";
9913a6bfcf2SGiorgis Georgakoudis       };
9923a6bfcf2SGiorgis Georgakoudis 
9932c31d5ebSJoseph Huber       emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
9943a6bfcf2SGiorgis Georgakoudis 
9953a6bfcf2SGiorgis Georgakoudis       Function *OriginalFn = BB->getParent();
9963a6bfcf2SGiorgis Georgakoudis       LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
9973a6bfcf2SGiorgis Georgakoudis                         << " parallel regions in " << OriginalFn->getName()
9983a6bfcf2SGiorgis Georgakoudis                         << "\n");
9993a6bfcf2SGiorgis Georgakoudis 
10003a6bfcf2SGiorgis Georgakoudis       // Isolate the calls to merge in a separate block.
10013a6bfcf2SGiorgis Georgakoudis       EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
10023a6bfcf2SGiorgis Georgakoudis       BasicBlock *AfterBB =
10033a6bfcf2SGiorgis Georgakoudis           SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
10043a6bfcf2SGiorgis Georgakoudis       StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
10053a6bfcf2SGiorgis Georgakoudis                            "omp.par.merged");
10063a6bfcf2SGiorgis Georgakoudis 
10073a6bfcf2SGiorgis Georgakoudis       assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
10083a6bfcf2SGiorgis Georgakoudis       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
10093a6bfcf2SGiorgis Georgakoudis       BB->getTerminator()->eraseFromParent();
10103a6bfcf2SGiorgis Georgakoudis 
101197517055SGiorgis Georgakoudis       // Create sequential regions for sequential instructions that are
101297517055SGiorgis Georgakoudis       // in-between mergable parallel regions.
101397517055SGiorgis Georgakoudis       for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
101497517055SGiorgis Georgakoudis            It != End; ++It) {
101597517055SGiorgis Georgakoudis         Instruction *ForkCI = *It;
101697517055SGiorgis Georgakoudis         Instruction *NextForkCI = *(It + 1);
101797517055SGiorgis Georgakoudis 
101897517055SGiorgis Georgakoudis         // Continue if there are no in-between instructions.
101997517055SGiorgis Georgakoudis         if (ForkCI->getNextNode() == NextForkCI)
102097517055SGiorgis Georgakoudis           continue;
102197517055SGiorgis Georgakoudis 
102297517055SGiorgis Georgakoudis         CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
102397517055SGiorgis Georgakoudis                                NextForkCI->getPrevNode());
102497517055SGiorgis Georgakoudis       }
102597517055SGiorgis Georgakoudis 
10263a6bfcf2SGiorgis Georgakoudis       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
10273a6bfcf2SGiorgis Georgakoudis                                                DL);
10283a6bfcf2SGiorgis Georgakoudis       IRBuilder<>::InsertPoint AllocaIP(
10293a6bfcf2SGiorgis Georgakoudis           &OriginalFn->getEntryBlock(),
10303a6bfcf2SGiorgis Georgakoudis           OriginalFn->getEntryBlock().getFirstInsertionPt());
10313a6bfcf2SGiorgis Georgakoudis       // Create the merged parallel region with default proc binding, to
10323a6bfcf2SGiorgis Georgakoudis       // avoid overriding binding settings, and without explicit cancellation.
1033e5dba2d7SMichael Kruse       InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
10343a6bfcf2SGiorgis Georgakoudis           Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
10353a6bfcf2SGiorgis Georgakoudis           OMP_PROC_BIND_default, /* IsCancellable */ false);
10363a6bfcf2SGiorgis Georgakoudis       BranchInst::Create(AfterBB, AfterIP.getBlock());
10373a6bfcf2SGiorgis Georgakoudis 
10383a6bfcf2SGiorgis Georgakoudis       // Perform the actual outlining.
1039b1191206SMichael Kruse       OMPInfoCache.OMPBuilder.finalize(OriginalFn,
1040b1191206SMichael Kruse                                        /* AllowExtractorSinking */ true);
10413a6bfcf2SGiorgis Georgakoudis 
10423a6bfcf2SGiorgis Georgakoudis       Function *OutlinedFn = MergableCIs.front()->getCaller();
10433a6bfcf2SGiorgis Georgakoudis 
10443a6bfcf2SGiorgis Georgakoudis       // Replace the __kmpc_fork_call calls with direct calls to the outlined
10453a6bfcf2SGiorgis Georgakoudis       // callbacks.
10463a6bfcf2SGiorgis Georgakoudis       SmallVector<Value *, 8> Args;
10473a6bfcf2SGiorgis Georgakoudis       for (auto *CI : MergableCIs) {
10483a6bfcf2SGiorgis Georgakoudis         Value *Callee =
10493a6bfcf2SGiorgis Georgakoudis             CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
10503a6bfcf2SGiorgis Georgakoudis         FunctionType *FT =
10513a6bfcf2SGiorgis Georgakoudis             cast<FunctionType>(Callee->getType()->getPointerElementType());
10523a6bfcf2SGiorgis Georgakoudis         Args.clear();
10533a6bfcf2SGiorgis Georgakoudis         Args.push_back(OutlinedFn->getArg(0));
10543a6bfcf2SGiorgis Georgakoudis         Args.push_back(OutlinedFn->getArg(1));
10553a6bfcf2SGiorgis Georgakoudis         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
10563a6bfcf2SGiorgis Georgakoudis              U < E; ++U)
10573a6bfcf2SGiorgis Georgakoudis           Args.push_back(CI->getArgOperand(U));
10583a6bfcf2SGiorgis Georgakoudis 
10593a6bfcf2SGiorgis Georgakoudis         CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
10603a6bfcf2SGiorgis Georgakoudis         if (CI->getDebugLoc())
10613a6bfcf2SGiorgis Georgakoudis           NewCI->setDebugLoc(CI->getDebugLoc());
10623a6bfcf2SGiorgis Georgakoudis 
10633a6bfcf2SGiorgis Georgakoudis         // Forward parameter attributes from the callback to the callee.
10643a6bfcf2SGiorgis Georgakoudis         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
10653a6bfcf2SGiorgis Georgakoudis              U < E; ++U)
1066*80ea2bb5SArthur Eubanks           for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
10673a6bfcf2SGiorgis Georgakoudis             NewCI->addParamAttr(
10683a6bfcf2SGiorgis Georgakoudis                 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
10693a6bfcf2SGiorgis Georgakoudis 
10703a6bfcf2SGiorgis Georgakoudis         // Emit an explicit barrier to replace the implicit fork-join barrier.
10713a6bfcf2SGiorgis Georgakoudis         if (CI != MergableCIs.back()) {
10723a6bfcf2SGiorgis Georgakoudis           // TODO: Remove barrier if the merged parallel region includes the
10733a6bfcf2SGiorgis Georgakoudis           // 'nowait' clause.
1074e5dba2d7SMichael Kruse           OMPInfoCache.OMPBuilder.createBarrier(
10753a6bfcf2SGiorgis Georgakoudis               InsertPointTy(NewCI->getParent(),
10763a6bfcf2SGiorgis Georgakoudis                             NewCI->getNextNode()->getIterator()),
10773a6bfcf2SGiorgis Georgakoudis               OMPD_parallel);
10783a6bfcf2SGiorgis Georgakoudis         }
10793a6bfcf2SGiorgis Georgakoudis 
10803a6bfcf2SGiorgis Georgakoudis         CI->eraseFromParent();
10813a6bfcf2SGiorgis Georgakoudis       }
10823a6bfcf2SGiorgis Georgakoudis 
10833a6bfcf2SGiorgis Georgakoudis       assert(OutlinedFn != OriginalFn && "Outlining failed");
10847fea561eSArthur Eubanks       CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
10853a6bfcf2SGiorgis Georgakoudis       CGUpdater.reanalyzeFunction(*OriginalFn);
10863a6bfcf2SGiorgis Georgakoudis 
10873a6bfcf2SGiorgis Georgakoudis       NumOpenMPParallelRegionsMerged += MergableCIs.size();
10883a6bfcf2SGiorgis Georgakoudis 
10893a6bfcf2SGiorgis Georgakoudis       return true;
10903a6bfcf2SGiorgis Georgakoudis     };
10913a6bfcf2SGiorgis Georgakoudis 
10923a6bfcf2SGiorgis Georgakoudis     // Helper function that identifies sequences of
10933a6bfcf2SGiorgis Georgakoudis     // __kmpc_fork_call uses in a basic block.
10943a6bfcf2SGiorgis Georgakoudis     auto DetectPRsCB = [&](Use &U, Function &F) {
10953a6bfcf2SGiorgis Georgakoudis       CallInst *CI = getCallIfRegularCall(U, &RFI);
10963a6bfcf2SGiorgis Georgakoudis       BB2PRMap[CI->getParent()].insert(CI);
10973a6bfcf2SGiorgis Georgakoudis 
10983a6bfcf2SGiorgis Georgakoudis       return false;
10993a6bfcf2SGiorgis Georgakoudis     };
11003a6bfcf2SGiorgis Georgakoudis 
11013a6bfcf2SGiorgis Georgakoudis     BB2PRMap.clear();
11023a6bfcf2SGiorgis Georgakoudis     RFI.foreachUse(SCC, DetectPRsCB);
11033a6bfcf2SGiorgis Georgakoudis     SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
11043a6bfcf2SGiorgis Georgakoudis     // Find mergable parallel regions within a basic block that are
11053a6bfcf2SGiorgis Georgakoudis     // safe to merge, that is any in-between instructions can safely
11063a6bfcf2SGiorgis Georgakoudis     // execute in parallel after merging.
11073a6bfcf2SGiorgis Georgakoudis     // TODO: support merging across basic-blocks.
11083a6bfcf2SGiorgis Georgakoudis     for (auto &It : BB2PRMap) {
11093a6bfcf2SGiorgis Georgakoudis       auto &CIs = It.getSecond();
11103a6bfcf2SGiorgis Georgakoudis       if (CIs.size() < 2)
11113a6bfcf2SGiorgis Georgakoudis         continue;
11123a6bfcf2SGiorgis Georgakoudis 
11133a6bfcf2SGiorgis Georgakoudis       BasicBlock *BB = It.getFirst();
11143a6bfcf2SGiorgis Georgakoudis       SmallVector<CallInst *, 4> MergableCIs;
11153a6bfcf2SGiorgis Georgakoudis 
111697517055SGiorgis Georgakoudis       /// Returns true if the instruction is mergable, false otherwise.
111797517055SGiorgis Georgakoudis       /// A terminator instruction is unmergable by definition since merging
111897517055SGiorgis Georgakoudis       /// works within a BB. Instructions before the mergable region are
111997517055SGiorgis Georgakoudis       /// mergable if they are not calls to OpenMP runtime functions that may
112097517055SGiorgis Georgakoudis       /// set different execution parameters for subsequent parallel regions.
112197517055SGiorgis Georgakoudis       /// Instructions in-between parallel regions are mergable if they are not
112297517055SGiorgis Georgakoudis       /// calls to any non-intrinsic function since that may call a non-mergable
112397517055SGiorgis Georgakoudis       /// OpenMP runtime function.
112497517055SGiorgis Georgakoudis       auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
112597517055SGiorgis Georgakoudis         // We do not merge across BBs, hence return false (unmergable) if the
112697517055SGiorgis Georgakoudis         // instruction is a terminator.
112797517055SGiorgis Georgakoudis         if (I.isTerminator())
112897517055SGiorgis Georgakoudis           return false;
112997517055SGiorgis Georgakoudis 
113097517055SGiorgis Georgakoudis         if (!isa<CallInst>(&I))
113197517055SGiorgis Georgakoudis           return true;
113297517055SGiorgis Georgakoudis 
113397517055SGiorgis Georgakoudis         CallInst *CI = cast<CallInst>(&I);
113497517055SGiorgis Georgakoudis         if (IsBeforeMergableRegion) {
113597517055SGiorgis Georgakoudis           Function *CalledFunction = CI->getCalledFunction();
113697517055SGiorgis Georgakoudis           if (!CalledFunction)
113797517055SGiorgis Georgakoudis             return false;
113897517055SGiorgis Georgakoudis           // Return false (unmergable) if the call before the parallel
113997517055SGiorgis Georgakoudis           // region calls an explicit affinity (proc_bind) or number of
114097517055SGiorgis Georgakoudis           // threads (num_threads) compiler-generated function. Those settings
114197517055SGiorgis Georgakoudis           // may be incompatible with following parallel regions.
114297517055SGiorgis Georgakoudis           // TODO: ICV tracking to detect compatibility.
114397517055SGiorgis Georgakoudis           for (const auto &RFI : UnmergableCallsInfo) {
114497517055SGiorgis Georgakoudis             if (CalledFunction == RFI.Declaration)
114597517055SGiorgis Georgakoudis               return false;
114697517055SGiorgis Georgakoudis           }
114797517055SGiorgis Georgakoudis         } else {
114897517055SGiorgis Georgakoudis           // Return false (unmergable) if there is a call instruction
114997517055SGiorgis Georgakoudis           // in-between parallel regions when it is not an intrinsic. It
115097517055SGiorgis Georgakoudis           // may call an unmergable OpenMP runtime function in its callpath.
115197517055SGiorgis Georgakoudis           // TODO: Keep track of possible OpenMP calls in the callpath.
115297517055SGiorgis Georgakoudis           if (!isa<IntrinsicInst>(CI))
115397517055SGiorgis Georgakoudis             return false;
115497517055SGiorgis Georgakoudis         }
115597517055SGiorgis Georgakoudis 
115697517055SGiorgis Georgakoudis         return true;
115797517055SGiorgis Georgakoudis       };
11583a6bfcf2SGiorgis Georgakoudis       // Find maximal number of parallel region CIs that are safe to merge.
115997517055SGiorgis Georgakoudis       for (auto It = BB->begin(), End = BB->end(); It != End;) {
116097517055SGiorgis Georgakoudis         Instruction &I = *It;
116197517055SGiorgis Georgakoudis         ++It;
116297517055SGiorgis Georgakoudis 
11633a6bfcf2SGiorgis Georgakoudis         if (CIs.count(&I)) {
11643a6bfcf2SGiorgis Georgakoudis           MergableCIs.push_back(cast<CallInst>(&I));
11653a6bfcf2SGiorgis Georgakoudis           continue;
11663a6bfcf2SGiorgis Georgakoudis         }
11673a6bfcf2SGiorgis Georgakoudis 
116897517055SGiorgis Georgakoudis         // Continue expanding if the instruction is mergable.
116997517055SGiorgis Georgakoudis         if (IsMergable(I, MergableCIs.empty()))
11703a6bfcf2SGiorgis Georgakoudis           continue;
11713a6bfcf2SGiorgis Georgakoudis 
117297517055SGiorgis Georgakoudis         // Forward the instruction iterator to skip the next parallel region
117397517055SGiorgis Georgakoudis         // since there is an unmergable instruction which can affect it.
117497517055SGiorgis Georgakoudis         for (; It != End; ++It) {
117597517055SGiorgis Georgakoudis           Instruction &SkipI = *It;
117697517055SGiorgis Georgakoudis           if (CIs.count(&SkipI)) {
117797517055SGiorgis Georgakoudis             LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
117897517055SGiorgis Georgakoudis                               << " due to " << I << "\n");
117997517055SGiorgis Georgakoudis             ++It;
118097517055SGiorgis Georgakoudis             break;
118197517055SGiorgis Georgakoudis           }
118297517055SGiorgis Georgakoudis         }
118397517055SGiorgis Georgakoudis 
118497517055SGiorgis Georgakoudis         // Store mergable regions found.
11853a6bfcf2SGiorgis Georgakoudis         if (MergableCIs.size() > 1) {
11863a6bfcf2SGiorgis Georgakoudis           MergableCIsVector.push_back(MergableCIs);
11873a6bfcf2SGiorgis Georgakoudis           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
11883a6bfcf2SGiorgis Georgakoudis                             << " parallel regions in block " << BB->getName()
11893a6bfcf2SGiorgis Georgakoudis                             << " of function " << BB->getParent()->getName()
11903a6bfcf2SGiorgis Georgakoudis                             << "\n";);
11913a6bfcf2SGiorgis Georgakoudis         }
11923a6bfcf2SGiorgis Georgakoudis 
11933a6bfcf2SGiorgis Georgakoudis         MergableCIs.clear();
11943a6bfcf2SGiorgis Georgakoudis       }
11953a6bfcf2SGiorgis Georgakoudis 
11963a6bfcf2SGiorgis Georgakoudis       if (!MergableCIsVector.empty()) {
11973a6bfcf2SGiorgis Georgakoudis         Changed = true;
11983a6bfcf2SGiorgis Georgakoudis 
11993a6bfcf2SGiorgis Georgakoudis         for (auto &MergableCIs : MergableCIsVector)
12003a6bfcf2SGiorgis Georgakoudis           Merge(MergableCIs, BB);
1201b2ad63d3SJoseph Huber         MergableCIsVector.clear();
12023a6bfcf2SGiorgis Georgakoudis       }
12033a6bfcf2SGiorgis Georgakoudis     }
12043a6bfcf2SGiorgis Georgakoudis 
12053a6bfcf2SGiorgis Georgakoudis     if (Changed) {
120697517055SGiorgis Georgakoudis       /// Re-collect use for fork calls, emitted barrier calls, and
120797517055SGiorgis Georgakoudis       /// any emitted master/end_master calls.
120897517055SGiorgis Georgakoudis       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
120997517055SGiorgis Georgakoudis       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
121097517055SGiorgis Georgakoudis       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
121197517055SGiorgis Georgakoudis       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
12123a6bfcf2SGiorgis Georgakoudis     }
12133a6bfcf2SGiorgis Georgakoudis 
12143a6bfcf2SGiorgis Georgakoudis     return Changed;
12153a6bfcf2SGiorgis Georgakoudis   }
12163a6bfcf2SGiorgis Georgakoudis 
12179d38f98dSJohannes Doerfert   /// Try to delete parallel regions if possible.
1218e565db49SJohannes Doerfert   bool deleteParallelRegions() {
1219e565db49SJohannes Doerfert     const unsigned CallbackCalleeOperand = 2;
1220e565db49SJohannes Doerfert 
12217cfd267cSsstefan1     OMPInformationCache::RuntimeFunctionInfo &RFI =
12227cfd267cSsstefan1         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
12237cfd267cSsstefan1 
1224e565db49SJohannes Doerfert     if (!RFI.Declaration)
1225e565db49SJohannes Doerfert       return false;
1226e565db49SJohannes Doerfert 
1227e565db49SJohannes Doerfert     bool Changed = false;
1228e565db49SJohannes Doerfert     auto DeleteCallCB = [&](Use &U, Function &) {
1229e565db49SJohannes Doerfert       CallInst *CI = getCallIfRegularCall(U);
1230e565db49SJohannes Doerfert       if (!CI)
1231e565db49SJohannes Doerfert         return false;
1232e565db49SJohannes Doerfert       auto *Fn = dyn_cast<Function>(
1233e565db49SJohannes Doerfert           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1234e565db49SJohannes Doerfert       if (!Fn)
1235e565db49SJohannes Doerfert         return false;
1236e565db49SJohannes Doerfert       if (!Fn->onlyReadsMemory())
1237e565db49SJohannes Doerfert         return false;
1238e565db49SJohannes Doerfert       if (!Fn->hasFnAttribute(Attribute::WillReturn))
1239e565db49SJohannes Doerfert         return false;
1240e565db49SJohannes Doerfert 
1241e565db49SJohannes Doerfert       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1242e565db49SJohannes Doerfert                         << CI->getCaller()->getName() << "\n");
12434d4ea9acSHuber, Joseph 
12444d4ea9acSHuber, Joseph       auto Remark = [&](OptimizationRemark OR) {
1245eef6601bSJoseph Huber         return OR << "Removing parallel region with no side-effects.";
12464d4ea9acSHuber, Joseph       };
12472c31d5ebSJoseph Huber       emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
12484d4ea9acSHuber, Joseph 
1249e565db49SJohannes Doerfert       CGUpdater.removeCallSite(*CI);
1250e565db49SJohannes Doerfert       CI->eraseFromParent();
1251e565db49SJohannes Doerfert       Changed = true;
125255eb714aSRoman Lebedev       ++NumOpenMPParallelRegionsDeleted;
1253e565db49SJohannes Doerfert       return true;
1254e565db49SJohannes Doerfert     };
1255e565db49SJohannes Doerfert 
1256624d34afSJohannes Doerfert     RFI.foreachUse(SCC, DeleteCallCB);
1257e565db49SJohannes Doerfert 
1258e565db49SJohannes Doerfert     return Changed;
1259e565db49SJohannes Doerfert   }
1260e565db49SJohannes Doerfert 
1261b726c557SJohannes Doerfert   /// Try to eliminate runtime calls by reusing existing ones.
12629548b74aSJohannes Doerfert   bool deduplicateRuntimeCalls() {
12639548b74aSJohannes Doerfert     bool Changed = false;
12649548b74aSJohannes Doerfert 
1265e28936f6SJohannes Doerfert     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1266e28936f6SJohannes Doerfert         OMPRTL_omp_get_num_threads,
1267e28936f6SJohannes Doerfert         OMPRTL_omp_in_parallel,
1268e28936f6SJohannes Doerfert         OMPRTL_omp_get_cancellation,
1269e28936f6SJohannes Doerfert         OMPRTL_omp_get_thread_limit,
1270e28936f6SJohannes Doerfert         OMPRTL_omp_get_supported_active_levels,
1271e28936f6SJohannes Doerfert         OMPRTL_omp_get_level,
1272e28936f6SJohannes Doerfert         OMPRTL_omp_get_ancestor_thread_num,
1273e28936f6SJohannes Doerfert         OMPRTL_omp_get_team_size,
1274e28936f6SJohannes Doerfert         OMPRTL_omp_get_active_level,
1275e28936f6SJohannes Doerfert         OMPRTL_omp_in_final,
1276e28936f6SJohannes Doerfert         OMPRTL_omp_get_proc_bind,
1277e28936f6SJohannes Doerfert         OMPRTL_omp_get_num_places,
1278e28936f6SJohannes Doerfert         OMPRTL_omp_get_num_procs,
1279e28936f6SJohannes Doerfert         OMPRTL_omp_get_place_num,
1280e28936f6SJohannes Doerfert         OMPRTL_omp_get_partition_num_places,
1281e28936f6SJohannes Doerfert         OMPRTL_omp_get_partition_place_nums};
1282e28936f6SJohannes Doerfert 
1283bc93c2d7SMarek Kurdej     // Global-tid is handled separately.
12849548b74aSJohannes Doerfert     SmallSetVector<Value *, 16> GTIdArgs;
12859548b74aSJohannes Doerfert     collectGlobalThreadIdArguments(GTIdArgs);
12869548b74aSJohannes Doerfert     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
12879548b74aSJohannes Doerfert                       << " global thread ID arguments\n");
12889548b74aSJohannes Doerfert 
12899548b74aSJohannes Doerfert     for (Function *F : SCC) {
1290e28936f6SJohannes Doerfert       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
12914e29d256Sserge-sans-paille         Changed |= deduplicateRuntimeCalls(
12924e29d256Sserge-sans-paille             *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1293e28936f6SJohannes Doerfert 
1294e28936f6SJohannes Doerfert       // __kmpc_global_thread_num is special as we can replace it with an
1295e28936f6SJohannes Doerfert       // argument in enough cases to make it worth trying.
12969548b74aSJohannes Doerfert       Value *GTIdArg = nullptr;
12979548b74aSJohannes Doerfert       for (Argument &Arg : F->args())
12989548b74aSJohannes Doerfert         if (GTIdArgs.count(&Arg)) {
12999548b74aSJohannes Doerfert           GTIdArg = &Arg;
13009548b74aSJohannes Doerfert           break;
13019548b74aSJohannes Doerfert         }
13029548b74aSJohannes Doerfert       Changed |= deduplicateRuntimeCalls(
13037cfd267cSsstefan1           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
13049548b74aSJohannes Doerfert     }
13059548b74aSJohannes Doerfert 
13069548b74aSJohannes Doerfert     return Changed;
13079548b74aSJohannes Doerfert   }
13089548b74aSJohannes Doerfert 
1309496f8e5bSHamilton Tobon Mosquera   /// Tries to hide the latency of runtime calls that involve host to
1310496f8e5bSHamilton Tobon Mosquera   /// device memory transfers by splitting them into their "issue" and "wait"
1311496f8e5bSHamilton Tobon Mosquera   /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1312496f8e5bSHamilton Tobon Mosquera   /// moved downards as much as possible. The "issue" issues the memory transfer
1313496f8e5bSHamilton Tobon Mosquera   /// asynchronously, returning a handle. The "wait" waits in the returned
1314496f8e5bSHamilton Tobon Mosquera   /// handle for the memory transfer to finish.
1315496f8e5bSHamilton Tobon Mosquera   bool hideMemTransfersLatency() {
1316496f8e5bSHamilton Tobon Mosquera     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1317496f8e5bSHamilton Tobon Mosquera     bool Changed = false;
1318496f8e5bSHamilton Tobon Mosquera     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1319496f8e5bSHamilton Tobon Mosquera       auto *RTCall = getCallIfRegularCall(U, &RFI);
1320496f8e5bSHamilton Tobon Mosquera       if (!RTCall)
1321496f8e5bSHamilton Tobon Mosquera         return false;
1322496f8e5bSHamilton Tobon Mosquera 
13238931add6SHamilton Tobon Mosquera       OffloadArray OffloadArrays[3];
13248931add6SHamilton Tobon Mosquera       if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
13258931add6SHamilton Tobon Mosquera         return false;
13268931add6SHamilton Tobon Mosquera 
13278931add6SHamilton Tobon Mosquera       LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
13288931add6SHamilton Tobon Mosquera 
1329bd2fa181SHamilton Tobon Mosquera       // TODO: Check if can be moved upwards.
1330bd2fa181SHamilton Tobon Mosquera       bool WasSplit = false;
1331bd2fa181SHamilton Tobon Mosquera       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1332bd2fa181SHamilton Tobon Mosquera       if (WaitMovementPoint)
1333bd2fa181SHamilton Tobon Mosquera         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1334bd2fa181SHamilton Tobon Mosquera 
1335496f8e5bSHamilton Tobon Mosquera       Changed |= WasSplit;
1336496f8e5bSHamilton Tobon Mosquera       return WasSplit;
1337496f8e5bSHamilton Tobon Mosquera     };
1338496f8e5bSHamilton Tobon Mosquera     RFI.foreachUse(SCC, SplitMemTransfers);
1339496f8e5bSHamilton Tobon Mosquera 
1340496f8e5bSHamilton Tobon Mosquera     return Changed;
1341496f8e5bSHamilton Tobon Mosquera   }
1342496f8e5bSHamilton Tobon Mosquera 
1343a2281419SJoseph Huber   void analysisGlobalization() {
13446fc51c9fSJoseph Huber     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
134582453e75SJoseph Huber 
134682453e75SJoseph Huber     auto CheckGlobalization = [&](Use &U, Function &Decl) {
1347a2281419SJoseph Huber       if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
134844feacc7SJoseph Huber         auto Remark = [&](OptimizationRemarkMissed ORM) {
134944feacc7SJoseph Huber           return ORM
1350a2281419SJoseph Huber                  << "Found thread data sharing on the GPU. "
1351a2281419SJoseph Huber                  << "Expect degraded performance due to data globalization.";
1352a2281419SJoseph Huber         };
13532c31d5ebSJoseph Huber         emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1354a2281419SJoseph Huber       }
1355a2281419SJoseph Huber 
1356a2281419SJoseph Huber       return false;
1357a2281419SJoseph Huber     };
1358a2281419SJoseph Huber 
135982453e75SJoseph Huber     RFI.foreachUse(SCC, CheckGlobalization);
136082453e75SJoseph Huber   }
1361a2281419SJoseph Huber 
13628931add6SHamilton Tobon Mosquera   /// Maps the values stored in the offload arrays passed as arguments to
13638931add6SHamilton Tobon Mosquera   /// \p RuntimeCall into the offload arrays in \p OAs.
13648931add6SHamilton Tobon Mosquera   bool getValuesInOffloadArrays(CallInst &RuntimeCall,
13658931add6SHamilton Tobon Mosquera                                 MutableArrayRef<OffloadArray> OAs) {
13668931add6SHamilton Tobon Mosquera     assert(OAs.size() == 3 && "Need space for three offload arrays!");
13678931add6SHamilton Tobon Mosquera 
13688931add6SHamilton Tobon Mosquera     // A runtime call that involves memory offloading looks something like:
13698931add6SHamilton Tobon Mosquera     // call void @__tgt_target_data_begin_mapper(arg0, arg1,
13708931add6SHamilton Tobon Mosquera     //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
13718931add6SHamilton Tobon Mosquera     // ...)
13728931add6SHamilton Tobon Mosquera     // So, the idea is to access the allocas that allocate space for these
13738931add6SHamilton Tobon Mosquera     // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
13748931add6SHamilton Tobon Mosquera     // Therefore:
13758931add6SHamilton Tobon Mosquera     // i8** %offload_baseptrs.
13761d3d9b9cSHamilton Tobon Mosquera     Value *BasePtrsArg =
13771d3d9b9cSHamilton Tobon Mosquera         RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
13788931add6SHamilton Tobon Mosquera     // i8** %offload_ptrs.
13791d3d9b9cSHamilton Tobon Mosquera     Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
13808931add6SHamilton Tobon Mosquera     // i8** %offload_sizes.
13811d3d9b9cSHamilton Tobon Mosquera     Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
13828931add6SHamilton Tobon Mosquera 
13838931add6SHamilton Tobon Mosquera     // Get values stored in **offload_baseptrs.
13848931add6SHamilton Tobon Mosquera     auto *V = getUnderlyingObject(BasePtrsArg);
13858931add6SHamilton Tobon Mosquera     if (!isa<AllocaInst>(V))
13868931add6SHamilton Tobon Mosquera       return false;
13878931add6SHamilton Tobon Mosquera     auto *BasePtrsArray = cast<AllocaInst>(V);
13888931add6SHamilton Tobon Mosquera     if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
13898931add6SHamilton Tobon Mosquera       return false;
13908931add6SHamilton Tobon Mosquera 
13918931add6SHamilton Tobon Mosquera     // Get values stored in **offload_baseptrs.
13928931add6SHamilton Tobon Mosquera     V = getUnderlyingObject(PtrsArg);
13938931add6SHamilton Tobon Mosquera     if (!isa<AllocaInst>(V))
13948931add6SHamilton Tobon Mosquera       return false;
13958931add6SHamilton Tobon Mosquera     auto *PtrsArray = cast<AllocaInst>(V);
13968931add6SHamilton Tobon Mosquera     if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
13978931add6SHamilton Tobon Mosquera       return false;
13988931add6SHamilton Tobon Mosquera 
13998931add6SHamilton Tobon Mosquera     // Get values stored in **offload_sizes.
14008931add6SHamilton Tobon Mosquera     V = getUnderlyingObject(SizesArg);
14018931add6SHamilton Tobon Mosquera     // If it's a [constant] global array don't analyze it.
14028931add6SHamilton Tobon Mosquera     if (isa<GlobalValue>(V))
14038931add6SHamilton Tobon Mosquera       return isa<Constant>(V);
14048931add6SHamilton Tobon Mosquera     if (!isa<AllocaInst>(V))
14058931add6SHamilton Tobon Mosquera       return false;
14068931add6SHamilton Tobon Mosquera 
14078931add6SHamilton Tobon Mosquera     auto *SizesArray = cast<AllocaInst>(V);
14088931add6SHamilton Tobon Mosquera     if (!OAs[2].initialize(*SizesArray, RuntimeCall))
14098931add6SHamilton Tobon Mosquera       return false;
14108931add6SHamilton Tobon Mosquera 
14118931add6SHamilton Tobon Mosquera     return true;
14128931add6SHamilton Tobon Mosquera   }
14138931add6SHamilton Tobon Mosquera 
14148931add6SHamilton Tobon Mosquera   /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
14158931add6SHamilton Tobon Mosquera   /// For now this is a way to test that the function getValuesInOffloadArrays
14168931add6SHamilton Tobon Mosquera   /// is working properly.
14178931add6SHamilton Tobon Mosquera   /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
14188931add6SHamilton Tobon Mosquera   void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
14198931add6SHamilton Tobon Mosquera     assert(OAs.size() == 3 && "There are three offload arrays to debug!");
14208931add6SHamilton Tobon Mosquera 
14218931add6SHamilton Tobon Mosquera     LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
14228931add6SHamilton Tobon Mosquera     std::string ValuesStr;
14238931add6SHamilton Tobon Mosquera     raw_string_ostream Printer(ValuesStr);
14248931add6SHamilton Tobon Mosquera     std::string Separator = " --- ";
14258931add6SHamilton Tobon Mosquera 
14268931add6SHamilton Tobon Mosquera     for (auto *BP : OAs[0].StoredValues) {
14278931add6SHamilton Tobon Mosquera       BP->print(Printer);
14288931add6SHamilton Tobon Mosquera       Printer << Separator;
14298931add6SHamilton Tobon Mosquera     }
14308931add6SHamilton Tobon Mosquera     LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
14318931add6SHamilton Tobon Mosquera     ValuesStr.clear();
14328931add6SHamilton Tobon Mosquera 
14338931add6SHamilton Tobon Mosquera     for (auto *P : OAs[1].StoredValues) {
14348931add6SHamilton Tobon Mosquera       P->print(Printer);
14358931add6SHamilton Tobon Mosquera       Printer << Separator;
14368931add6SHamilton Tobon Mosquera     }
14378931add6SHamilton Tobon Mosquera     LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
14388931add6SHamilton Tobon Mosquera     ValuesStr.clear();
14398931add6SHamilton Tobon Mosquera 
14408931add6SHamilton Tobon Mosquera     for (auto *S : OAs[2].StoredValues) {
14418931add6SHamilton Tobon Mosquera       S->print(Printer);
14428931add6SHamilton Tobon Mosquera       Printer << Separator;
14438931add6SHamilton Tobon Mosquera     }
14448931add6SHamilton Tobon Mosquera     LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
14458931add6SHamilton Tobon Mosquera   }
14468931add6SHamilton Tobon Mosquera 
1447bd2fa181SHamilton Tobon Mosquera   /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1448bd2fa181SHamilton Tobon Mosquera   /// moved. Returns nullptr if the movement is not possible, or not worth it.
1449bd2fa181SHamilton Tobon Mosquera   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1450bd2fa181SHamilton Tobon Mosquera     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1451bd2fa181SHamilton Tobon Mosquera     //  Make it traverse the CFG.
1452bd2fa181SHamilton Tobon Mosquera 
1453bd2fa181SHamilton Tobon Mosquera     Instruction *CurrentI = &RuntimeCall;
1454bd2fa181SHamilton Tobon Mosquera     bool IsWorthIt = false;
1455bd2fa181SHamilton Tobon Mosquera     while ((CurrentI = CurrentI->getNextNode())) {
1456bd2fa181SHamilton Tobon Mosquera 
1457bd2fa181SHamilton Tobon Mosquera       // TODO: Once we detect the regions to be offloaded we should use the
1458bd2fa181SHamilton Tobon Mosquera       //  alias analysis manager to check if CurrentI may modify one of
1459bd2fa181SHamilton Tobon Mosquera       //  the offloaded regions.
1460bd2fa181SHamilton Tobon Mosquera       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1461bd2fa181SHamilton Tobon Mosquera         if (IsWorthIt)
1462bd2fa181SHamilton Tobon Mosquera           return CurrentI;
1463bd2fa181SHamilton Tobon Mosquera 
1464bd2fa181SHamilton Tobon Mosquera         return nullptr;
1465bd2fa181SHamilton Tobon Mosquera       }
1466bd2fa181SHamilton Tobon Mosquera 
1467bd2fa181SHamilton Tobon Mosquera       // FIXME: For now if we move it over anything without side effect
1468bd2fa181SHamilton Tobon Mosquera       //  is worth it.
1469bd2fa181SHamilton Tobon Mosquera       IsWorthIt = true;
1470bd2fa181SHamilton Tobon Mosquera     }
1471bd2fa181SHamilton Tobon Mosquera 
1472bd2fa181SHamilton Tobon Mosquera     // Return end of BasicBlock.
1473bd2fa181SHamilton Tobon Mosquera     return RuntimeCall.getParent()->getTerminator();
1474bd2fa181SHamilton Tobon Mosquera   }
1475bd2fa181SHamilton Tobon Mosquera 
1476496f8e5bSHamilton Tobon Mosquera   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1477bd2fa181SHamilton Tobon Mosquera   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1478bd2fa181SHamilton Tobon Mosquera                                Instruction &WaitMovementPoint) {
1479bd31abc1SHamilton Tobon Mosquera     // Create stack allocated handle (__tgt_async_info) at the beginning of the
1480bd31abc1SHamilton Tobon Mosquera     // function. Used for storing information of the async transfer, allowing to
1481bd31abc1SHamilton Tobon Mosquera     // wait on it later.
1482496f8e5bSHamilton Tobon Mosquera     auto &IRBuilder = OMPInfoCache.OMPBuilder;
1483bd31abc1SHamilton Tobon Mosquera     auto *F = RuntimeCall.getCaller();
1484bd31abc1SHamilton Tobon Mosquera     Instruction *FirstInst = &(F->getEntryBlock().front());
1485bd31abc1SHamilton Tobon Mosquera     AllocaInst *Handle = new AllocaInst(
1486bd31abc1SHamilton Tobon Mosquera         IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1487bd31abc1SHamilton Tobon Mosquera 
1488496f8e5bSHamilton Tobon Mosquera     // Add "issue" runtime call declaration:
1489496f8e5bSHamilton Tobon Mosquera     // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1490496f8e5bSHamilton Tobon Mosquera     //   i8**, i8**, i64*, i64*)
1491496f8e5bSHamilton Tobon Mosquera     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1492496f8e5bSHamilton Tobon Mosquera         M, OMPRTL___tgt_target_data_begin_mapper_issue);
1493496f8e5bSHamilton Tobon Mosquera 
1494496f8e5bSHamilton Tobon Mosquera     // Change RuntimeCall call site for its asynchronous version.
149597e55cfeSJoseph Huber     SmallVector<Value *, 16> Args;
1496bd2fa181SHamilton Tobon Mosquera     for (auto &Arg : RuntimeCall.args())
1497496f8e5bSHamilton Tobon Mosquera       Args.push_back(Arg.get());
1498bd31abc1SHamilton Tobon Mosquera     Args.push_back(Handle);
1499496f8e5bSHamilton Tobon Mosquera 
1500496f8e5bSHamilton Tobon Mosquera     CallInst *IssueCallsite =
1501bd31abc1SHamilton Tobon Mosquera         CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1502bd2fa181SHamilton Tobon Mosquera     RuntimeCall.eraseFromParent();
1503496f8e5bSHamilton Tobon Mosquera 
1504496f8e5bSHamilton Tobon Mosquera     // Add "wait" runtime call declaration:
1505496f8e5bSHamilton Tobon Mosquera     // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1506496f8e5bSHamilton Tobon Mosquera     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1507496f8e5bSHamilton Tobon Mosquera         M, OMPRTL___tgt_target_data_begin_mapper_wait);
1508496f8e5bSHamilton Tobon Mosquera 
1509496f8e5bSHamilton Tobon Mosquera     Value *WaitParams[2] = {
1510da8bec47SJoseph Huber         IssueCallsite->getArgOperand(
1511da8bec47SJoseph Huber             OffloadArray::DeviceIDArgNum), // device_id.
1512bd31abc1SHamilton Tobon Mosquera         Handle                             // handle to wait on.
1513496f8e5bSHamilton Tobon Mosquera     };
1514bd2fa181SHamilton Tobon Mosquera     CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1515496f8e5bSHamilton Tobon Mosquera 
1516496f8e5bSHamilton Tobon Mosquera     return true;
1517496f8e5bSHamilton Tobon Mosquera   }
1518496f8e5bSHamilton Tobon Mosquera 
1519dc3b5b00SJohannes Doerfert   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1520dc3b5b00SJohannes Doerfert                                     bool GlobalOnly, bool &SingleChoice) {
1521dc3b5b00SJohannes Doerfert     if (CurrentIdent == NextIdent)
1522dc3b5b00SJohannes Doerfert       return CurrentIdent;
1523dc3b5b00SJohannes Doerfert 
1524396b7253SJohannes Doerfert     // TODO: Figure out how to actually combine multiple debug locations. For
1525dc3b5b00SJohannes Doerfert     //       now we just keep an existing one if there is a single choice.
1526dc3b5b00SJohannes Doerfert     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1527dc3b5b00SJohannes Doerfert       SingleChoice = !CurrentIdent;
1528dc3b5b00SJohannes Doerfert       return NextIdent;
1529dc3b5b00SJohannes Doerfert     }
1530396b7253SJohannes Doerfert     return nullptr;
1531396b7253SJohannes Doerfert   }
1532396b7253SJohannes Doerfert 
1533396b7253SJohannes Doerfert   /// Return an `struct ident_t*` value that represents the ones used in the
1534396b7253SJohannes Doerfert   /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1535396b7253SJohannes Doerfert   /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1536396b7253SJohannes Doerfert   /// return value we create one from scratch. We also do not yet combine
1537396b7253SJohannes Doerfert   /// information, e.g., the source locations, see combinedIdentStruct.
15387cfd267cSsstefan1   Value *
15397cfd267cSsstefan1   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
15407cfd267cSsstefan1                                  Function &F, bool GlobalOnly) {
1541dc3b5b00SJohannes Doerfert     bool SingleChoice = true;
1542396b7253SJohannes Doerfert     Value *Ident = nullptr;
1543396b7253SJohannes Doerfert     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1544396b7253SJohannes Doerfert       CallInst *CI = getCallIfRegularCall(U, &RFI);
1545396b7253SJohannes Doerfert       if (!CI || &F != &Caller)
1546396b7253SJohannes Doerfert         return false;
1547396b7253SJohannes Doerfert       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1548dc3b5b00SJohannes Doerfert                                   /* GlobalOnly */ true, SingleChoice);
1549396b7253SJohannes Doerfert       return false;
1550396b7253SJohannes Doerfert     };
1551624d34afSJohannes Doerfert     RFI.foreachUse(SCC, CombineIdentStruct);
1552396b7253SJohannes Doerfert 
1553dc3b5b00SJohannes Doerfert     if (!Ident || !SingleChoice) {
1554396b7253SJohannes Doerfert       // The IRBuilder uses the insertion block to get to the module, this is
1555396b7253SJohannes Doerfert       // unfortunate but we work around it for now.
15567cfd267cSsstefan1       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
15577cfd267cSsstefan1         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1558396b7253SJohannes Doerfert             &F.getEntryBlock(), F.getEntryBlock().begin()));
1559396b7253SJohannes Doerfert       // Create a fallback location if non was found.
1560396b7253SJohannes Doerfert       // TODO: Use the debug locations of the calls instead.
15617cfd267cSsstefan1       Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
15627cfd267cSsstefan1       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
1563396b7253SJohannes Doerfert     }
1564396b7253SJohannes Doerfert     return Ident;
1565396b7253SJohannes Doerfert   }
1566396b7253SJohannes Doerfert 
1567b726c557SJohannes Doerfert   /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
15689548b74aSJohannes Doerfert   /// \p ReplVal if given.
15697cfd267cSsstefan1   bool deduplicateRuntimeCalls(Function &F,
15707cfd267cSsstefan1                                OMPInformationCache::RuntimeFunctionInfo &RFI,
15719548b74aSJohannes Doerfert                                Value *ReplVal = nullptr) {
15728855fec3SJohannes Doerfert     auto *UV = RFI.getUseVector(F);
15738855fec3SJohannes Doerfert     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1574b1fbf438SRoman Lebedev       return false;
1575b1fbf438SRoman Lebedev 
15767cfd267cSsstefan1     LLVM_DEBUG(
15777cfd267cSsstefan1         dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
15787cfd267cSsstefan1                << (ReplVal ? " with an existing value\n" : "\n") << "\n");
15797cfd267cSsstefan1 
1580ab3da5ddSMichael Liao     assert((!ReplVal || (isa<Argument>(ReplVal) &&
1581ab3da5ddSMichael Liao                          cast<Argument>(ReplVal)->getParent() == &F)) &&
15829548b74aSJohannes Doerfert            "Unexpected replacement value!");
1583396b7253SJohannes Doerfert 
1584396b7253SJohannes Doerfert     // TODO: Use dominance to find a good position instead.
15856aab27baSsstefan1     auto CanBeMoved = [this](CallBase &CB) {
1586396b7253SJohannes Doerfert       unsigned NumArgs = CB.getNumArgOperands();
1587396b7253SJohannes Doerfert       if (NumArgs == 0)
1588396b7253SJohannes Doerfert         return true;
15896aab27baSsstefan1       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1590396b7253SJohannes Doerfert         return false;
1591396b7253SJohannes Doerfert       for (unsigned u = 1; u < NumArgs; ++u)
1592396b7253SJohannes Doerfert         if (isa<Instruction>(CB.getArgOperand(u)))
1593396b7253SJohannes Doerfert           return false;
1594396b7253SJohannes Doerfert       return true;
1595396b7253SJohannes Doerfert     };
1596396b7253SJohannes Doerfert 
15979548b74aSJohannes Doerfert     if (!ReplVal) {
15988855fec3SJohannes Doerfert       for (Use *U : *UV)
15999548b74aSJohannes Doerfert         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1600396b7253SJohannes Doerfert           if (!CanBeMoved(*CI))
1601396b7253SJohannes Doerfert             continue;
16024d4ea9acSHuber, Joseph 
1603f97de4cbSGiorgis Georgakoudis           // If the function is a kernel, dedup will move
1604f97de4cbSGiorgis Georgakoudis           // the runtime call right after the kernel init callsite. Otherwise,
1605f97de4cbSGiorgis Georgakoudis           // it will move it to the beginning of the caller function.
1606f97de4cbSGiorgis Georgakoudis           if (isKernel(F)) {
1607f97de4cbSGiorgis Georgakoudis             auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
1608f97de4cbSGiorgis Georgakoudis             auto *KernelInitUV = KernelInitRFI.getUseVector(F);
1609f97de4cbSGiorgis Georgakoudis 
1610f97de4cbSGiorgis Georgakoudis             if (KernelInitUV->empty())
1611f97de4cbSGiorgis Georgakoudis               continue;
1612f97de4cbSGiorgis Georgakoudis 
1613f97de4cbSGiorgis Georgakoudis             assert(KernelInitUV->size() == 1 &&
1614f97de4cbSGiorgis Georgakoudis                    "Expected a single __kmpc_target_init in kernel\n");
1615f97de4cbSGiorgis Georgakoudis 
1616f97de4cbSGiorgis Georgakoudis             CallInst *KernelInitCI =
1617f97de4cbSGiorgis Georgakoudis                 getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
1618f97de4cbSGiorgis Georgakoudis             assert(KernelInitCI &&
1619f97de4cbSGiorgis Georgakoudis                    "Expected a call to __kmpc_target_init in kernel\n");
1620f97de4cbSGiorgis Georgakoudis 
1621f97de4cbSGiorgis Georgakoudis             CI->moveAfter(KernelInitCI);
1622f97de4cbSGiorgis Georgakoudis           } else
16239548b74aSJohannes Doerfert             CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
16249548b74aSJohannes Doerfert           ReplVal = CI;
16259548b74aSJohannes Doerfert           break;
16269548b74aSJohannes Doerfert         }
16279548b74aSJohannes Doerfert       if (!ReplVal)
16289548b74aSJohannes Doerfert         return false;
16299548b74aSJohannes Doerfert     }
16309548b74aSJohannes Doerfert 
1631396b7253SJohannes Doerfert     // If we use a call as a replacement value we need to make sure the ident is
1632396b7253SJohannes Doerfert     // valid at the new location. For now we just pick a global one, either
1633396b7253SJohannes Doerfert     // existing and used by one of the calls, or created from scratch.
1634396b7253SJohannes Doerfert     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1635396b7253SJohannes Doerfert       if (CI->getNumArgOperands() > 0 &&
16366aab27baSsstefan1           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1637396b7253SJohannes Doerfert         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1638396b7253SJohannes Doerfert                                                       /* GlobalOnly */ true);
1639396b7253SJohannes Doerfert         CI->setArgOperand(0, Ident);
1640396b7253SJohannes Doerfert       }
1641396b7253SJohannes Doerfert     }
1642396b7253SJohannes Doerfert 
16439548b74aSJohannes Doerfert     bool Changed = false;
16449548b74aSJohannes Doerfert     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
16459548b74aSJohannes Doerfert       CallInst *CI = getCallIfRegularCall(U, &RFI);
16469548b74aSJohannes Doerfert       if (!CI || CI == ReplVal || &F != &Caller)
16479548b74aSJohannes Doerfert         return false;
16489548b74aSJohannes Doerfert       assert(CI->getCaller() == &F && "Unexpected call!");
16494d4ea9acSHuber, Joseph 
16504d4ea9acSHuber, Joseph       auto Remark = [&](OptimizationRemark OR) {
16514d4ea9acSHuber, Joseph         return OR << "OpenMP runtime call "
1652eef6601bSJoseph Huber                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
16534d4ea9acSHuber, Joseph       };
1654eef6601bSJoseph Huber       if (CI->getDebugLoc())
16552c31d5ebSJoseph Huber         emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1656eef6601bSJoseph Huber       else
16572c31d5ebSJoseph Huber         emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
16584d4ea9acSHuber, Joseph 
16599548b74aSJohannes Doerfert       CGUpdater.removeCallSite(*CI);
16609548b74aSJohannes Doerfert       CI->replaceAllUsesWith(ReplVal);
16619548b74aSJohannes Doerfert       CI->eraseFromParent();
16629548b74aSJohannes Doerfert       ++NumOpenMPRuntimeCallsDeduplicated;
16639548b74aSJohannes Doerfert       Changed = true;
16649548b74aSJohannes Doerfert       return true;
16659548b74aSJohannes Doerfert     };
1666624d34afSJohannes Doerfert     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
16679548b74aSJohannes Doerfert 
16689548b74aSJohannes Doerfert     return Changed;
16699548b74aSJohannes Doerfert   }
16709548b74aSJohannes Doerfert 
16719548b74aSJohannes Doerfert   /// Collect arguments that represent the global thread id in \p GTIdArgs.
16729548b74aSJohannes Doerfert   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
16739548b74aSJohannes Doerfert     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
16749548b74aSJohannes Doerfert     //       initialization. We could define an AbstractAttribute instead and
16759548b74aSJohannes Doerfert     //       run the Attributor here once it can be run as an SCC pass.
16769548b74aSJohannes Doerfert 
16779548b74aSJohannes Doerfert     // Helper to check the argument \p ArgNo at all call sites of \p F for
16789548b74aSJohannes Doerfert     // a GTId.
16799548b74aSJohannes Doerfert     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
16809548b74aSJohannes Doerfert       if (!F.hasLocalLinkage())
16819548b74aSJohannes Doerfert         return false;
16829548b74aSJohannes Doerfert       for (Use &U : F.uses()) {
16839548b74aSJohannes Doerfert         if (CallInst *CI = getCallIfRegularCall(U)) {
16849548b74aSJohannes Doerfert           Value *ArgOp = CI->getArgOperand(ArgNo);
16859548b74aSJohannes Doerfert           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
16867cfd267cSsstefan1               getCallIfRegularCall(
16877cfd267cSsstefan1                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
16889548b74aSJohannes Doerfert             continue;
16899548b74aSJohannes Doerfert         }
16909548b74aSJohannes Doerfert         return false;
16919548b74aSJohannes Doerfert       }
16929548b74aSJohannes Doerfert       return true;
16939548b74aSJohannes Doerfert     };
16949548b74aSJohannes Doerfert 
16959548b74aSJohannes Doerfert     // Helper to identify uses of a GTId as GTId arguments.
16969548b74aSJohannes Doerfert     auto AddUserArgs = [&](Value &GTId) {
16979548b74aSJohannes Doerfert       for (Use &U : GTId.uses())
16989548b74aSJohannes Doerfert         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
16999548b74aSJohannes Doerfert           if (CI->isArgOperand(&U))
17009548b74aSJohannes Doerfert             if (Function *Callee = CI->getCalledFunction())
17019548b74aSJohannes Doerfert               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
17029548b74aSJohannes Doerfert                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
17039548b74aSJohannes Doerfert     };
17049548b74aSJohannes Doerfert 
17059548b74aSJohannes Doerfert     // The argument users of __kmpc_global_thread_num calls are GTIds.
17067cfd267cSsstefan1     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
17077cfd267cSsstefan1         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
17087cfd267cSsstefan1 
1709624d34afSJohannes Doerfert     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
17108855fec3SJohannes Doerfert       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
17119548b74aSJohannes Doerfert         AddUserArgs(*CI);
17128855fec3SJohannes Doerfert       return false;
17138855fec3SJohannes Doerfert     });
17149548b74aSJohannes Doerfert 
17159548b74aSJohannes Doerfert     // Transitively search for more arguments by looking at the users of the
17169548b74aSJohannes Doerfert     // ones we know already. During the search the GTIdArgs vector is extended
17179548b74aSJohannes Doerfert     // so we cannot cache the size nor can we use a range based for.
17189548b74aSJohannes Doerfert     for (unsigned u = 0; u < GTIdArgs.size(); ++u)
17199548b74aSJohannes Doerfert       AddUserArgs(*GTIdArgs[u]);
17209548b74aSJohannes Doerfert   }
17219548b74aSJohannes Doerfert 
17225b0581aeSJohannes Doerfert   /// Kernel (=GPU) optimizations and utility functions
17235b0581aeSJohannes Doerfert   ///
17245b0581aeSJohannes Doerfert   ///{{
17255b0581aeSJohannes Doerfert 
17265b0581aeSJohannes Doerfert   /// Check if \p F is a kernel, hence entry point for target offloading.
17275b0581aeSJohannes Doerfert   bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
17285b0581aeSJohannes Doerfert 
17295b0581aeSJohannes Doerfert   /// Cache to remember the unique kernel for a function.
17305b0581aeSJohannes Doerfert   DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
17315b0581aeSJohannes Doerfert 
17325b0581aeSJohannes Doerfert   /// Find the unique kernel that will execute \p F, if any.
17335b0581aeSJohannes Doerfert   Kernel getUniqueKernelFor(Function &F);
17345b0581aeSJohannes Doerfert 
17355b0581aeSJohannes Doerfert   /// Find the unique kernel that will execute \p I, if any.
17365b0581aeSJohannes Doerfert   Kernel getUniqueKernelFor(Instruction &I) {
17375b0581aeSJohannes Doerfert     return getUniqueKernelFor(*I.getFunction());
17385b0581aeSJohannes Doerfert   }
17395b0581aeSJohannes Doerfert 
  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
  /// the cases where we can avoid taking the address of a function.
17425b0581aeSJohannes Doerfert   bool rewriteDeviceCodeStateMachine();
17435b0581aeSJohannes Doerfert 
17445b0581aeSJohannes Doerfert   ///
17455b0581aeSJohannes Doerfert   ///}}
17465b0581aeSJohannes Doerfert 
17474d4ea9acSHuber, Joseph   /// Emit a remark generically
17484d4ea9acSHuber, Joseph   ///
17494d4ea9acSHuber, Joseph   /// This template function can be used to generically emit a remark. The
17504d4ea9acSHuber, Joseph   /// RemarkKind should be one of the following:
17514d4ea9acSHuber, Joseph   ///   - OptimizationRemark to indicate a successful optimization attempt
17524d4ea9acSHuber, Joseph   ///   - OptimizationRemarkMissed to report a failed optimization attempt
17534d4ea9acSHuber, Joseph   ///   - OptimizationRemarkAnalysis to provide additional information about an
17544d4ea9acSHuber, Joseph   ///     optimization attempt
17554d4ea9acSHuber, Joseph   ///
17564d4ea9acSHuber, Joseph   /// The remark is built using a callback function provided by the caller that
17574d4ea9acSHuber, Joseph   /// takes a RemarkKind as input and returns a RemarkKind.
17582db182ffSJoseph Huber   template <typename RemarkKind, typename RemarkCallBack>
17592db182ffSJoseph Huber   void emitRemark(Instruction *I, StringRef RemarkName,
1760e8039ad4SJohannes Doerfert                   RemarkCallBack &&RemarkCB) const {
17612db182ffSJoseph Huber     Function *F = I->getParent()->getParent();
17624d4ea9acSHuber, Joseph     auto &ORE = OREGetter(F);
17634d4ea9acSHuber, Joseph 
17642c31d5ebSJoseph Huber     if (RemarkName.startswith("OMP"))
17652c31d5ebSJoseph Huber       ORE.emit([&]() {
17662c31d5ebSJoseph Huber         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
17672c31d5ebSJoseph Huber                << " [" << RemarkName << "]";
17682c31d5ebSJoseph Huber       });
17692c31d5ebSJoseph Huber     else
17702c31d5ebSJoseph Huber       ORE.emit(
17712c31d5ebSJoseph Huber           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
17724d4ea9acSHuber, Joseph   }
17734d4ea9acSHuber, Joseph 
17742db182ffSJoseph Huber   /// Emit a remark on a function.
17752db182ffSJoseph Huber   template <typename RemarkKind, typename RemarkCallBack>
17762db182ffSJoseph Huber   void emitRemark(Function *F, StringRef RemarkName,
17772db182ffSJoseph Huber                   RemarkCallBack &&RemarkCB) const {
17780f426935Ssstefan1     auto &ORE = OREGetter(F);
17790f426935Ssstefan1 
17802c31d5ebSJoseph Huber     if (RemarkName.startswith("OMP"))
17812c31d5ebSJoseph Huber       ORE.emit([&]() {
17822c31d5ebSJoseph Huber         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
17832c31d5ebSJoseph Huber                << " [" << RemarkName << "]";
17842c31d5ebSJoseph Huber       });
17852c31d5ebSJoseph Huber     else
17862c31d5ebSJoseph Huber       ORE.emit(
17872c31d5ebSJoseph Huber           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
17880f426935Ssstefan1   }
17890f426935Ssstefan1 
179058725c12SJoseph Huber   /// RAII struct to temporarily change an RTL function's linkage to external.
179158725c12SJoseph Huber   /// This prevents it from being mistakenly removed by other optimizations.
179258725c12SJoseph Huber   struct ExternalizationRAII {
179358725c12SJoseph Huber     ExternalizationRAII(OMPInformationCache &OMPInfoCache,
179458725c12SJoseph Huber                         RuntimeFunction RFKind)
1795e757a3b0SJoseph Huber         : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
179658725c12SJoseph Huber       if (!Declaration)
179758725c12SJoseph Huber         return;
179858725c12SJoseph Huber 
179958725c12SJoseph Huber       LinkageType = Declaration->getLinkage();
180058725c12SJoseph Huber       Declaration->setLinkage(GlobalValue::ExternalLinkage);
180158725c12SJoseph Huber     }
180258725c12SJoseph Huber 
180358725c12SJoseph Huber     ~ExternalizationRAII() {
180458725c12SJoseph Huber       if (!Declaration)
180558725c12SJoseph Huber         return;
180658725c12SJoseph Huber 
180758725c12SJoseph Huber       Declaration->setLinkage(LinkageType);
180858725c12SJoseph Huber     }
180958725c12SJoseph Huber 
181058725c12SJoseph Huber     Function *Declaration;
181158725c12SJoseph Huber     GlobalValue::LinkageTypes LinkageType;
181258725c12SJoseph Huber   };
181358725c12SJoseph Huber 
1814b726c557SJohannes Doerfert   /// The underlying module.
18159548b74aSJohannes Doerfert   Module &M;
18169548b74aSJohannes Doerfert 
18179548b74aSJohannes Doerfert   /// The SCC we are operating on.
1818ee17263aSJohannes Doerfert   SmallVectorImpl<Function *> &SCC;
18199548b74aSJohannes Doerfert 
18209548b74aSJohannes Doerfert   /// Callback to update the call graph, the first argument is a removed call,
18219548b74aSJohannes Doerfert   /// the second an optional replacement call.
18229548b74aSJohannes Doerfert   CallGraphUpdater &CGUpdater;
18239548b74aSJohannes Doerfert 
18244d4ea9acSHuber, Joseph   /// Callback to get an OptimizationRemarkEmitter from a Function *
18254d4ea9acSHuber, Joseph   OptimizationRemarkGetter OREGetter;
18264d4ea9acSHuber, Joseph 
18277cfd267cSsstefan1   /// OpenMP-specific information cache. Also Used for Attributor runs.
18287cfd267cSsstefan1   OMPInformationCache &OMPInfoCache;
1829b8235d2bSsstefan1 
1830b8235d2bSsstefan1   /// Attributor instance.
1831b8235d2bSsstefan1   Attributor &A;
1832b8235d2bSsstefan1 
1833b8235d2bSsstefan1   /// Helper function to run Attributor on SCC.
1834d9659bf6SJohannes Doerfert   bool runAttributor(bool IsModulePass) {
1835b8235d2bSsstefan1     if (SCC.empty())
1836b8235d2bSsstefan1       return false;
1837b8235d2bSsstefan1 
183858725c12SJoseph Huber     // Temporarily make these function have external linkage so the Attributor
183958725c12SJoseph Huber     // doesn't remove them when we try to look them up later.
184058725c12SJoseph Huber     ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
184158725c12SJoseph Huber     ExternalizationRAII EndParallel(OMPInfoCache,
184258725c12SJoseph Huber                                     OMPRTL___kmpc_kernel_end_parallel);
184358725c12SJoseph Huber     ExternalizationRAII BarrierSPMD(OMPInfoCache,
184458725c12SJoseph Huber                                     OMPRTL___kmpc_barrier_simple_spmd);
184558725c12SJoseph Huber 
1846d9659bf6SJohannes Doerfert     registerAAs(IsModulePass);
1847b8235d2bSsstefan1 
1848b8235d2bSsstefan1     ChangeStatus Changed = A.run();
1849b8235d2bSsstefan1 
1850b8235d2bSsstefan1     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
1851b8235d2bSsstefan1                       << " functions, result: " << Changed << ".\n");
1852b8235d2bSsstefan1 
1853b8235d2bSsstefan1     return Changed == ChangeStatus::CHANGED;
1854b8235d2bSsstefan1   }
1855b8235d2bSsstefan1 
18565ab6aeddSJose M Monsalve Diaz   void registerFoldRuntimeCall(RuntimeFunction RF);
18575ab6aeddSJose M Monsalve Diaz 
1858b8235d2bSsstefan1   /// Populate the Attributor with abstract attribute opportunities in the
1859b8235d2bSsstefan1   /// function.
1860d9659bf6SJohannes Doerfert   void registerAAs(bool IsModulePass);
1861b8235d2bSsstefan1 };
1862b8235d2bSsstefan1 
18635b0581aeSJohannes Doerfert Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
18645b0581aeSJohannes Doerfert   if (!OMPInfoCache.ModuleSlice.count(&F))
18655b0581aeSJohannes Doerfert     return nullptr;
18665b0581aeSJohannes Doerfert 
18675b0581aeSJohannes Doerfert   // Use a scope to keep the lifetime of the CachedKernel short.
18685b0581aeSJohannes Doerfert   {
18695b0581aeSJohannes Doerfert     Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
18705b0581aeSJohannes Doerfert     if (CachedKernel)
18715b0581aeSJohannes Doerfert       return *CachedKernel;
18725b0581aeSJohannes Doerfert 
18735b0581aeSJohannes Doerfert     // TODO: We should use an AA to create an (optimistic and callback
18745b0581aeSJohannes Doerfert     //       call-aware) call graph. For now we stick to simple patterns that
18755b0581aeSJohannes Doerfert     //       are less powerful, basically the worst fixpoint.
18765b0581aeSJohannes Doerfert     if (isKernel(F)) {
18775b0581aeSJohannes Doerfert       CachedKernel = Kernel(&F);
18785b0581aeSJohannes Doerfert       return *CachedKernel;
18795b0581aeSJohannes Doerfert     }
18805b0581aeSJohannes Doerfert 
18815b0581aeSJohannes Doerfert     CachedKernel = nullptr;
1882994bb6ebSJohannes Doerfert     if (!F.hasLocalLinkage()) {
1883994bb6ebSJohannes Doerfert 
1884994bb6ebSJohannes Doerfert       // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
18852db182ffSJoseph Huber       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1886eef6601bSJoseph Huber         return ORA << "Potentially unknown OpenMP target region caller.";
1887994bb6ebSJohannes Doerfert       };
18882db182ffSJoseph Huber       emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
1889994bb6ebSJohannes Doerfert 
18905b0581aeSJohannes Doerfert       return nullptr;
18915b0581aeSJohannes Doerfert     }
1892994bb6ebSJohannes Doerfert   }
18935b0581aeSJohannes Doerfert 
18945b0581aeSJohannes Doerfert   auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
18955b0581aeSJohannes Doerfert     if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
18965b0581aeSJohannes Doerfert       // Allow use in equality comparisons.
18975b0581aeSJohannes Doerfert       if (Cmp->isEquality())
18985b0581aeSJohannes Doerfert         return getUniqueKernelFor(*Cmp);
18995b0581aeSJohannes Doerfert       return nullptr;
19005b0581aeSJohannes Doerfert     }
19015b0581aeSJohannes Doerfert     if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
19025b0581aeSJohannes Doerfert       // Allow direct calls.
19035b0581aeSJohannes Doerfert       if (CB->isCallee(&U))
19045b0581aeSJohannes Doerfert         return getUniqueKernelFor(*CB);
1905a2dbfb6bSGiorgis Georgakoudis 
1906a2dbfb6bSGiorgis Georgakoudis       OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1907a2dbfb6bSGiorgis Georgakoudis           OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1908a2dbfb6bSGiorgis Georgakoudis       // Allow the use in __kmpc_parallel_51 calls.
1909a2dbfb6bSGiorgis Georgakoudis       if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
19105b0581aeSJohannes Doerfert         return getUniqueKernelFor(*CB);
19115b0581aeSJohannes Doerfert       return nullptr;
19125b0581aeSJohannes Doerfert     }
19135b0581aeSJohannes Doerfert     // Disallow every other use.
19145b0581aeSJohannes Doerfert     return nullptr;
19155b0581aeSJohannes Doerfert   };
19165b0581aeSJohannes Doerfert 
19175b0581aeSJohannes Doerfert   // TODO: In the future we want to track more than just a unique kernel.
19185b0581aeSJohannes Doerfert   SmallPtrSet<Kernel, 2> PotentialKernels;
19198d8ce85bSsstefan1   OMPInformationCache::foreachUse(F, [&](const Use &U) {
19205b0581aeSJohannes Doerfert     PotentialKernels.insert(GetUniqueKernelForUse(U));
19215b0581aeSJohannes Doerfert   });
19225b0581aeSJohannes Doerfert 
19235b0581aeSJohannes Doerfert   Kernel K = nullptr;
19245b0581aeSJohannes Doerfert   if (PotentialKernels.size() == 1)
19255b0581aeSJohannes Doerfert     K = *PotentialKernels.begin();
19265b0581aeSJohannes Doerfert 
19275b0581aeSJohannes Doerfert   // Cache the result.
19285b0581aeSJohannes Doerfert   UniqueKernelMap[&F] = K;
19295b0581aeSJohannes Doerfert 
19305b0581aeSJohannes Doerfert   return K;
19315b0581aeSJohannes Doerfert }
19325b0581aeSJohannes Doerfert 
19335b0581aeSJohannes Doerfert bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1934a2dbfb6bSGiorgis Georgakoudis   OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1935a2dbfb6bSGiorgis Georgakoudis       OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
19365b0581aeSJohannes Doerfert 
19375b0581aeSJohannes Doerfert   bool Changed = false;
1938a2dbfb6bSGiorgis Georgakoudis   if (!KernelParallelRFI)
19395b0581aeSJohannes Doerfert     return Changed;
19405b0581aeSJohannes Doerfert 
1941cd0dd8ecSJoseph Huber   // If we have disabled state machine changes, exit
1942cd0dd8ecSJoseph Huber   if (DisableOpenMPOptStateMachineRewrite)
1943cd0dd8ecSJoseph Huber     return Changed;
1944cd0dd8ecSJoseph Huber 
19455b0581aeSJohannes Doerfert   for (Function *F : SCC) {
19465b0581aeSJohannes Doerfert 
1947a2dbfb6bSGiorgis Georgakoudis     // Check if the function is a use in a __kmpc_parallel_51 call at
19485b0581aeSJohannes Doerfert     // all.
19495b0581aeSJohannes Doerfert     bool UnknownUse = false;
1950a2dbfb6bSGiorgis Georgakoudis     bool KernelParallelUse = false;
19515b0581aeSJohannes Doerfert     unsigned NumDirectCalls = 0;
19525b0581aeSJohannes Doerfert 
19535b0581aeSJohannes Doerfert     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
19548d8ce85bSsstefan1     OMPInformationCache::foreachUse(*F, [&](Use &U) {
19555b0581aeSJohannes Doerfert       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
19565b0581aeSJohannes Doerfert         if (CB->isCallee(&U)) {
19575b0581aeSJohannes Doerfert           ++NumDirectCalls;
19585b0581aeSJohannes Doerfert           return;
19595b0581aeSJohannes Doerfert         }
19605b0581aeSJohannes Doerfert 
196181db6144SMichael Liao       if (isa<ICmpInst>(U.getUser())) {
19625b0581aeSJohannes Doerfert         ToBeReplacedStateMachineUses.push_back(&U);
19635b0581aeSJohannes Doerfert         return;
19645b0581aeSJohannes Doerfert       }
1965a2dbfb6bSGiorgis Georgakoudis 
1966a2dbfb6bSGiorgis Georgakoudis       // Find wrapper functions that represent parallel kernels.
1967a2dbfb6bSGiorgis Georgakoudis       CallInst *CI =
1968a2dbfb6bSGiorgis Georgakoudis           OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
1969a2dbfb6bSGiorgis Georgakoudis       const unsigned int WrapperFunctionArgNo = 6;
1970a2dbfb6bSGiorgis Georgakoudis       if (!KernelParallelUse && CI &&
1971a2dbfb6bSGiorgis Georgakoudis           CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
1972a2dbfb6bSGiorgis Georgakoudis         KernelParallelUse = true;
19735b0581aeSJohannes Doerfert         ToBeReplacedStateMachineUses.push_back(&U);
19745b0581aeSJohannes Doerfert         return;
19755b0581aeSJohannes Doerfert       }
19765b0581aeSJohannes Doerfert       UnknownUse = true;
19775b0581aeSJohannes Doerfert     });
19785b0581aeSJohannes Doerfert 
1979a2dbfb6bSGiorgis Georgakoudis     // Do not emit a remark if we haven't seen a __kmpc_parallel_51
1980fec1f210SJohannes Doerfert     // use.
1981a2dbfb6bSGiorgis Georgakoudis     if (!KernelParallelUse)
19825b0581aeSJohannes Doerfert       continue;
19835b0581aeSJohannes Doerfert 
1984fec1f210SJohannes Doerfert     // If this ever hits, we should investigate.
1985fec1f210SJohannes Doerfert     // TODO: Checking the number of uses is not a necessary restriction and
1986fec1f210SJohannes Doerfert     // should be lifted.
1987fec1f210SJohannes Doerfert     if (UnknownUse || NumDirectCalls != 1 ||
1988d9659bf6SJohannes Doerfert         ToBeReplacedStateMachineUses.size() > 2) {
19892db182ffSJoseph Huber       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
19902db182ffSJoseph Huber         return ORA << "Parallel region is used in "
1991fec1f210SJohannes Doerfert                    << (UnknownUse ? "unknown" : "unexpected")
1992eef6601bSJoseph Huber                    << " ways. Will not attempt to rewrite the state machine.";
1993fec1f210SJohannes Doerfert       };
19942c31d5ebSJoseph Huber       emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
19955b0581aeSJohannes Doerfert       continue;
1996fec1f210SJohannes Doerfert     }
19975b0581aeSJohannes Doerfert 
1998a2dbfb6bSGiorgis Georgakoudis     // Even if we have __kmpc_parallel_51 calls, we (for now) give
19995b0581aeSJohannes Doerfert     // up if the function is not called from a unique kernel.
20005b0581aeSJohannes Doerfert     Kernel K = getUniqueKernelFor(*F);
2001fec1f210SJohannes Doerfert     if (!K) {
20022db182ffSJoseph Huber       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2003eef6601bSJoseph Huber         return ORA << "Parallel region is not called from a unique kernel. "
2004eef6601bSJoseph Huber                       "Will not attempt to rewrite the state machine.";
2005fec1f210SJohannes Doerfert       };
20062c31d5ebSJoseph Huber       emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
20075b0581aeSJohannes Doerfert       continue;
2008fec1f210SJohannes Doerfert     }
20095b0581aeSJohannes Doerfert 
20105b0581aeSJohannes Doerfert     // We now know F is a parallel body function called only from the kernel K.
20115b0581aeSJohannes Doerfert     // We also identified the state machine uses in which we replace the
20125b0581aeSJohannes Doerfert     // function pointer by a new global symbol for identification purposes. This
20135b0581aeSJohannes Doerfert     // ensures only direct calls to the function are left.
20145b0581aeSJohannes Doerfert 
20155b0581aeSJohannes Doerfert     Module &M = *F->getParent();
20165b0581aeSJohannes Doerfert     Type *Int8Ty = Type::getInt8Ty(M.getContext());
20175b0581aeSJohannes Doerfert 
20185b0581aeSJohannes Doerfert     auto *ID = new GlobalVariable(
20195b0581aeSJohannes Doerfert         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
20205b0581aeSJohannes Doerfert         UndefValue::get(Int8Ty), F->getName() + ".ID");
20215b0581aeSJohannes Doerfert 
20225b0581aeSJohannes Doerfert     for (Use *U : ToBeReplacedStateMachineUses)
20235b0581aeSJohannes Doerfert       U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));
20245b0581aeSJohannes Doerfert 
20255b0581aeSJohannes Doerfert     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
20265b0581aeSJohannes Doerfert 
20275b0581aeSJohannes Doerfert     Changed = true;
20285b0581aeSJohannes Doerfert   }
20295b0581aeSJohannes Doerfert 
20305b0581aeSJohannes Doerfert   return Changed;
20315b0581aeSJohannes Doerfert }
20325b0581aeSJohannes Doerfert 
2033b8235d2bSsstefan1 /// Abstract Attribute for tracking ICV values.
2034b8235d2bSsstefan1 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2035b8235d2bSsstefan1   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2036b8235d2bSsstefan1   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2037b8235d2bSsstefan1 
20385dfd7cc4Ssstefan1   void initialize(Attributor &A) override {
20395dfd7cc4Ssstefan1     Function *F = getAnchorScope();
20405dfd7cc4Ssstefan1     if (!F || !A.isFunctionIPOAmendable(*F))
20415dfd7cc4Ssstefan1       indicatePessimisticFixpoint();
20425dfd7cc4Ssstefan1   }
20435dfd7cc4Ssstefan1 
2044b8235d2bSsstefan1   /// Returns true if value is assumed to be tracked.
2045b8235d2bSsstefan1   bool isAssumedTracked() const { return getAssumed(); }
2046b8235d2bSsstefan1 
2047b8235d2bSsstefan1   /// Returns true if value is known to be tracked.
2048b8235d2bSsstefan1   bool isKnownTracked() const { return getAssumed(); }
2049b8235d2bSsstefan1 
2050b8235d2bSsstefan1   /// Create an abstract attribute biew for the position \p IRP.
2051b8235d2bSsstefan1   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2052b8235d2bSsstefan1 
2053b8235d2bSsstefan1   /// Return the value with which \p I can be replaced for specific \p ICV.
20545dfd7cc4Ssstefan1   virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
20555dfd7cc4Ssstefan1                                                 const Instruction *I,
20565dfd7cc4Ssstefan1                                                 Attributor &A) const {
20575dfd7cc4Ssstefan1     return None;
20585dfd7cc4Ssstefan1   }
20595dfd7cc4Ssstefan1 
20605dfd7cc4Ssstefan1   /// Return an assumed unique ICV value if a single candidate is found. If
20615dfd7cc4Ssstefan1   /// there cannot be one, return a nullptr. If it is not clear yet, return the
20625dfd7cc4Ssstefan1   /// Optional::NoneType.
20635dfd7cc4Ssstefan1   virtual Optional<Value *>
20645dfd7cc4Ssstefan1   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
20655dfd7cc4Ssstefan1 
20665dfd7cc4Ssstefan1   // Currently only nthreads is being tracked.
20675dfd7cc4Ssstefan1   // this array will only grow with time.
20685dfd7cc4Ssstefan1   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2069b8235d2bSsstefan1 
2070b8235d2bSsstefan1   /// See AbstractAttribute::getName()
2071b8235d2bSsstefan1   const std::string getName() const override { return "AAICVTracker"; }
2072b8235d2bSsstefan1 
2073233af895SLuofan Chen   /// See AbstractAttribute::getIdAddr()
2074233af895SLuofan Chen   const char *getIdAddr() const override { return &ID; }
2075233af895SLuofan Chen 
2076233af895SLuofan Chen   /// This function should return true if the type of the \p AA is AAICVTracker
2077233af895SLuofan Chen   static bool classof(const AbstractAttribute *AA) {
2078233af895SLuofan Chen     return (AA->getIdAddr() == &ID);
2079233af895SLuofan Chen   }
2080233af895SLuofan Chen 
2081b8235d2bSsstefan1   static const char ID;
2082b8235d2bSsstefan1 };
2083b8235d2bSsstefan1 
2084b8235d2bSsstefan1 struct AAICVTrackerFunction : public AAICVTracker {
2085b8235d2bSsstefan1   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2086b8235d2bSsstefan1       : AAICVTracker(IRP, A) {}
2087b8235d2bSsstefan1 
2088b8235d2bSsstefan1   // FIXME: come up with better string.
20895dfd7cc4Ssstefan1   const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2090b8235d2bSsstefan1 
2091b8235d2bSsstefan1   // FIXME: come up with some stats.
2092b8235d2bSsstefan1   void trackStatistics() const override {}
2093b8235d2bSsstefan1 
20945dfd7cc4Ssstefan1   /// We don't manifest anything for this AA.
2095b8235d2bSsstefan1   ChangeStatus manifest(Attributor &A) override {
20965dfd7cc4Ssstefan1     return ChangeStatus::UNCHANGED;
2097b8235d2bSsstefan1   }
2098b8235d2bSsstefan1 
2099b8235d2bSsstefan1   // Map of ICV to their values at specific program point.
21005dfd7cc4Ssstefan1   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2101b8235d2bSsstefan1                   InternalControlVar::ICV___last>
21025dfd7cc4Ssstefan1       ICVReplacementValuesMap;
2103b8235d2bSsstefan1 
2104b8235d2bSsstefan1   ChangeStatus updateImpl(Attributor &A) override {
2105b8235d2bSsstefan1     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2106b8235d2bSsstefan1 
2107b8235d2bSsstefan1     Function *F = getAnchorScope();
2108b8235d2bSsstefan1 
2109b8235d2bSsstefan1     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2110b8235d2bSsstefan1 
2111b8235d2bSsstefan1     for (InternalControlVar ICV : TrackableICVs) {
2112b8235d2bSsstefan1       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2113b8235d2bSsstefan1 
21145dfd7cc4Ssstefan1       auto &ValuesMap = ICVReplacementValuesMap[ICV];
2115b8235d2bSsstefan1       auto TrackValues = [&](Use &U, Function &) {
2116b8235d2bSsstefan1         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2117b8235d2bSsstefan1         if (!CI)
2118b8235d2bSsstefan1           return false;
2119b8235d2bSsstefan1 
2120b8235d2bSsstefan1         // FIXME: handle setters with more that 1 arguments.
2121b8235d2bSsstefan1         /// Track new value.
21225dfd7cc4Ssstefan1         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2123b8235d2bSsstefan1           HasChanged = ChangeStatus::CHANGED;
2124b8235d2bSsstefan1 
2125b8235d2bSsstefan1         return false;
2126b8235d2bSsstefan1       };
2127b8235d2bSsstefan1 
21285dfd7cc4Ssstefan1       auto CallCheck = [&](Instruction &I) {
21295dfd7cc4Ssstefan1         Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
21305dfd7cc4Ssstefan1         if (ReplVal.hasValue() &&
21315dfd7cc4Ssstefan1             ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
21325dfd7cc4Ssstefan1           HasChanged = ChangeStatus::CHANGED;
21335dfd7cc4Ssstefan1 
21345dfd7cc4Ssstefan1         return true;
21355dfd7cc4Ssstefan1       };
21365dfd7cc4Ssstefan1 
21375dfd7cc4Ssstefan1       // Track all changes of an ICV.
2138b8235d2bSsstefan1       SetterRFI.foreachUse(TrackValues, F);
21395dfd7cc4Ssstefan1 
2140792aac98SJohannes Doerfert       bool UsedAssumedInformation = false;
21415dfd7cc4Ssstefan1       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2142792aac98SJohannes Doerfert                                 UsedAssumedInformation,
21435dfd7cc4Ssstefan1                                 /* CheckBBLivenessOnly */ true);
21445dfd7cc4Ssstefan1 
21455dfd7cc4Ssstefan1       /// TODO: Figure out a way to avoid adding entry in
21465dfd7cc4Ssstefan1       /// ICVReplacementValuesMap
21475dfd7cc4Ssstefan1       Instruction *Entry = &F->getEntryBlock().front();
21485dfd7cc4Ssstefan1       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
21495dfd7cc4Ssstefan1         ValuesMap.insert(std::make_pair(Entry, nullptr));
2150b8235d2bSsstefan1     }
2151b8235d2bSsstefan1 
2152b8235d2bSsstefan1     return HasChanged;
2153b8235d2bSsstefan1   }
2154b8235d2bSsstefan1 
21555dfd7cc4Ssstefan1   /// Hepler to check if \p I is a call and get the value for it if it is
21565dfd7cc4Ssstefan1   /// unique.
21575dfd7cc4Ssstefan1   Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
21585dfd7cc4Ssstefan1                                     InternalControlVar &ICV) const {
2159b8235d2bSsstefan1 
21605dfd7cc4Ssstefan1     const auto *CB = dyn_cast<CallBase>(I);
2161dcaec812SJohannes Doerfert     if (!CB || CB->hasFnAttr("no_openmp") ||
2162dcaec812SJohannes Doerfert         CB->hasFnAttr("no_openmp_routines"))
21635dfd7cc4Ssstefan1       return None;
21645dfd7cc4Ssstefan1 
2165b8235d2bSsstefan1     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2166b8235d2bSsstefan1     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
21675dfd7cc4Ssstefan1     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
21685dfd7cc4Ssstefan1     Function *CalledFunction = CB->getCalledFunction();
2169b8235d2bSsstefan1 
21704eef14f9SWei Wang     // Indirect call, assume ICV changes.
21714eef14f9SWei Wang     if (CalledFunction == nullptr)
21724eef14f9SWei Wang       return nullptr;
21735dfd7cc4Ssstefan1     if (CalledFunction == GetterRFI.Declaration)
21745dfd7cc4Ssstefan1       return None;
21755dfd7cc4Ssstefan1     if (CalledFunction == SetterRFI.Declaration) {
21765dfd7cc4Ssstefan1       if (ICVReplacementValuesMap[ICV].count(I))
21775dfd7cc4Ssstefan1         return ICVReplacementValuesMap[ICV].lookup(I);
21785dfd7cc4Ssstefan1 
21795dfd7cc4Ssstefan1       return nullptr;
21805dfd7cc4Ssstefan1     }
21815dfd7cc4Ssstefan1 
21825dfd7cc4Ssstefan1     // Since we don't know, assume it changes the ICV.
21835dfd7cc4Ssstefan1     if (CalledFunction->isDeclaration())
21845dfd7cc4Ssstefan1       return nullptr;
21855dfd7cc4Ssstefan1 
21865b70c12fSJohannes Doerfert     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
21875b70c12fSJohannes Doerfert         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
21885dfd7cc4Ssstefan1 
21895dfd7cc4Ssstefan1     if (ICVTrackingAA.isAssumedTracked())
21905dfd7cc4Ssstefan1       return ICVTrackingAA.getUniqueReplacementValue(ICV);
21915dfd7cc4Ssstefan1 
21925dfd7cc4Ssstefan1     // If we don't know, assume it changes.
21935dfd7cc4Ssstefan1     return nullptr;
21945dfd7cc4Ssstefan1   }
21955dfd7cc4Ssstefan1 
21965dfd7cc4Ssstefan1   // We don't check unique value for a function, so return None.
21975dfd7cc4Ssstefan1   Optional<Value *>
21985dfd7cc4Ssstefan1   getUniqueReplacementValue(InternalControlVar ICV) const override {
21995dfd7cc4Ssstefan1     return None;
22005dfd7cc4Ssstefan1   }
22015dfd7cc4Ssstefan1 
22025dfd7cc4Ssstefan1   /// Return the value with which \p I can be replaced for specific \p ICV.
22035dfd7cc4Ssstefan1   Optional<Value *> getReplacementValue(InternalControlVar ICV,
22045dfd7cc4Ssstefan1                                         const Instruction *I,
22055dfd7cc4Ssstefan1                                         Attributor &A) const override {
22065dfd7cc4Ssstefan1     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
22075dfd7cc4Ssstefan1     if (ValuesMap.count(I))
22085dfd7cc4Ssstefan1       return ValuesMap.lookup(I);
22095dfd7cc4Ssstefan1 
22105dfd7cc4Ssstefan1     SmallVector<const Instruction *, 16> Worklist;
22115dfd7cc4Ssstefan1     SmallPtrSet<const Instruction *, 16> Visited;
22125dfd7cc4Ssstefan1     Worklist.push_back(I);
22135dfd7cc4Ssstefan1 
22145dfd7cc4Ssstefan1     Optional<Value *> ReplVal;
22155dfd7cc4Ssstefan1 
22165dfd7cc4Ssstefan1     while (!Worklist.empty()) {
22175dfd7cc4Ssstefan1       const Instruction *CurrInst = Worklist.pop_back_val();
22185dfd7cc4Ssstefan1       if (!Visited.insert(CurrInst).second)
2219b8235d2bSsstefan1         continue;
2220b8235d2bSsstefan1 
22215dfd7cc4Ssstefan1       const BasicBlock *CurrBB = CurrInst->getParent();
22225dfd7cc4Ssstefan1 
22235dfd7cc4Ssstefan1       // Go up and look for all potential setters/calls that might change the
22245dfd7cc4Ssstefan1       // ICV.
22255dfd7cc4Ssstefan1       while ((CurrInst = CurrInst->getPrevNode())) {
22265dfd7cc4Ssstefan1         if (ValuesMap.count(CurrInst)) {
22275dfd7cc4Ssstefan1           Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
22285dfd7cc4Ssstefan1           // Unknown value, track new.
22295dfd7cc4Ssstefan1           if (!ReplVal.hasValue()) {
22305dfd7cc4Ssstefan1             ReplVal = NewReplVal;
22315dfd7cc4Ssstefan1             break;
22325dfd7cc4Ssstefan1           }
22335dfd7cc4Ssstefan1 
22345dfd7cc4Ssstefan1           // If we found a new value, we can't know the icv value anymore.
22355dfd7cc4Ssstefan1           if (NewReplVal.hasValue())
22365dfd7cc4Ssstefan1             if (ReplVal != NewReplVal)
2237b8235d2bSsstefan1               return nullptr;
2238b8235d2bSsstefan1 
22395dfd7cc4Ssstefan1           break;
2240b8235d2bSsstefan1         }
2241b8235d2bSsstefan1 
22425dfd7cc4Ssstefan1         Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
22435dfd7cc4Ssstefan1         if (!NewReplVal.hasValue())
22445dfd7cc4Ssstefan1           continue;
22455dfd7cc4Ssstefan1 
22465dfd7cc4Ssstefan1         // Unknown value, track new.
22475dfd7cc4Ssstefan1         if (!ReplVal.hasValue()) {
22485dfd7cc4Ssstefan1           ReplVal = NewReplVal;
22495dfd7cc4Ssstefan1           break;
2250b8235d2bSsstefan1         }
2251b8235d2bSsstefan1 
22525dfd7cc4Ssstefan1         // if (NewReplVal.hasValue())
22535dfd7cc4Ssstefan1         // We found a new value, we can't know the icv value anymore.
22545dfd7cc4Ssstefan1         if (ReplVal != NewReplVal)
2255b8235d2bSsstefan1           return nullptr;
2256b8235d2bSsstefan1       }
22575dfd7cc4Ssstefan1 
22585dfd7cc4Ssstefan1       // If we are in the same BB and we have a value, we are done.
22595dfd7cc4Ssstefan1       if (CurrBB == I->getParent() && ReplVal.hasValue())
22605dfd7cc4Ssstefan1         return ReplVal;
22615dfd7cc4Ssstefan1 
22625dfd7cc4Ssstefan1       // Go through all predecessors and add terminators for analysis.
22635dfd7cc4Ssstefan1       for (const BasicBlock *Pred : predecessors(CurrBB))
22645dfd7cc4Ssstefan1         if (const Instruction *Terminator = Pred->getTerminator())
22655dfd7cc4Ssstefan1           Worklist.push_back(Terminator);
22665dfd7cc4Ssstefan1     }
22675dfd7cc4Ssstefan1 
22685dfd7cc4Ssstefan1     return ReplVal;
22695dfd7cc4Ssstefan1   }
22705dfd7cc4Ssstefan1 };
22715dfd7cc4Ssstefan1 
22725dfd7cc4Ssstefan1 struct AAICVTrackerFunctionReturned : AAICVTracker {
22735dfd7cc4Ssstefan1   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
22745dfd7cc4Ssstefan1       : AAICVTracker(IRP, A) {}
22755dfd7cc4Ssstefan1 
22765dfd7cc4Ssstefan1   // FIXME: come up with better string.
22775dfd7cc4Ssstefan1   const std::string getAsStr() const override {
22785dfd7cc4Ssstefan1     return "ICVTrackerFunctionReturned";
22795dfd7cc4Ssstefan1   }
22805dfd7cc4Ssstefan1 
22815dfd7cc4Ssstefan1   // FIXME: come up with some stats.
22825dfd7cc4Ssstefan1   void trackStatistics() const override {}
22835dfd7cc4Ssstefan1 
22845dfd7cc4Ssstefan1   /// We don't manifest anything for this AA.
22855dfd7cc4Ssstefan1   ChangeStatus manifest(Attributor &A) override {
22865dfd7cc4Ssstefan1     return ChangeStatus::UNCHANGED;
22875dfd7cc4Ssstefan1   }
22885dfd7cc4Ssstefan1 
22895dfd7cc4Ssstefan1   // Map of ICV to their values at specific program point.
22905dfd7cc4Ssstefan1   EnumeratedArray<Optional<Value *>, InternalControlVar,
22915dfd7cc4Ssstefan1                   InternalControlVar::ICV___last>
22925dfd7cc4Ssstefan1       ICVReplacementValuesMap;
22935dfd7cc4Ssstefan1 
22945dfd7cc4Ssstefan1   /// Return the value with which \p I can be replaced for specific \p ICV.
22955dfd7cc4Ssstefan1   Optional<Value *>
22965dfd7cc4Ssstefan1   getUniqueReplacementValue(InternalControlVar ICV) const override {
22975dfd7cc4Ssstefan1     return ICVReplacementValuesMap[ICV];
22985dfd7cc4Ssstefan1   }
22995dfd7cc4Ssstefan1 
23005dfd7cc4Ssstefan1   ChangeStatus updateImpl(Attributor &A) override {
23015dfd7cc4Ssstefan1     ChangeStatus Changed = ChangeStatus::UNCHANGED;
23025dfd7cc4Ssstefan1     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
23035b70c12fSJohannes Doerfert         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
23045dfd7cc4Ssstefan1 
23055dfd7cc4Ssstefan1     if (!ICVTrackingAA.isAssumedTracked())
23065dfd7cc4Ssstefan1       return indicatePessimisticFixpoint();
23075dfd7cc4Ssstefan1 
23085dfd7cc4Ssstefan1     for (InternalControlVar ICV : TrackableICVs) {
23095dfd7cc4Ssstefan1       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
23105dfd7cc4Ssstefan1       Optional<Value *> UniqueICVValue;
23115dfd7cc4Ssstefan1 
23125dfd7cc4Ssstefan1       auto CheckReturnInst = [&](Instruction &I) {
23135dfd7cc4Ssstefan1         Optional<Value *> NewReplVal =
23145dfd7cc4Ssstefan1             ICVTrackingAA.getReplacementValue(ICV, &I, A);
23155dfd7cc4Ssstefan1 
23165dfd7cc4Ssstefan1         // If we found a second ICV value there is no unique returned value.
23175dfd7cc4Ssstefan1         if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
23185dfd7cc4Ssstefan1           return false;
23195dfd7cc4Ssstefan1 
23205dfd7cc4Ssstefan1         UniqueICVValue = NewReplVal;
23215dfd7cc4Ssstefan1 
23225dfd7cc4Ssstefan1         return true;
23235dfd7cc4Ssstefan1       };
23245dfd7cc4Ssstefan1 
2325792aac98SJohannes Doerfert       bool UsedAssumedInformation = false;
23265dfd7cc4Ssstefan1       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2327792aac98SJohannes Doerfert                                      UsedAssumedInformation,
23285dfd7cc4Ssstefan1                                      /* CheckBBLivenessOnly */ true))
23295dfd7cc4Ssstefan1         UniqueICVValue = nullptr;
23305dfd7cc4Ssstefan1 
23315dfd7cc4Ssstefan1       if (UniqueICVValue == ReplVal)
23325dfd7cc4Ssstefan1         continue;
23335dfd7cc4Ssstefan1 
23345dfd7cc4Ssstefan1       ReplVal = UniqueICVValue;
23355dfd7cc4Ssstefan1       Changed = ChangeStatus::CHANGED;
23365dfd7cc4Ssstefan1     }
23375dfd7cc4Ssstefan1 
23385dfd7cc4Ssstefan1     return Changed;
23395dfd7cc4Ssstefan1   }
23405dfd7cc4Ssstefan1 };
23415dfd7cc4Ssstefan1 
23425dfd7cc4Ssstefan1 struct AAICVTrackerCallSite : AAICVTracker {
23435dfd7cc4Ssstefan1   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
23445dfd7cc4Ssstefan1       : AAICVTracker(IRP, A) {}
23455dfd7cc4Ssstefan1 
23465dfd7cc4Ssstefan1   void initialize(Attributor &A) override {
23475dfd7cc4Ssstefan1     Function *F = getAnchorScope();
23485dfd7cc4Ssstefan1     if (!F || !A.isFunctionIPOAmendable(*F))
23495dfd7cc4Ssstefan1       indicatePessimisticFixpoint();
23505dfd7cc4Ssstefan1 
23515dfd7cc4Ssstefan1     // We only initialize this AA for getters, so we need to know which ICV it
23525dfd7cc4Ssstefan1     // gets.
23535dfd7cc4Ssstefan1     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
23545dfd7cc4Ssstefan1     for (InternalControlVar ICV : TrackableICVs) {
23555dfd7cc4Ssstefan1       auto ICVInfo = OMPInfoCache.ICVs[ICV];
23565dfd7cc4Ssstefan1       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
23575dfd7cc4Ssstefan1       if (Getter.Declaration == getAssociatedFunction()) {
23585dfd7cc4Ssstefan1         AssociatedICV = ICVInfo.Kind;
23595dfd7cc4Ssstefan1         return;
23605dfd7cc4Ssstefan1       }
23615dfd7cc4Ssstefan1     }
23625dfd7cc4Ssstefan1 
23635dfd7cc4Ssstefan1     /// Unknown ICV.
23645dfd7cc4Ssstefan1     indicatePessimisticFixpoint();
23655dfd7cc4Ssstefan1   }
23665dfd7cc4Ssstefan1 
23675dfd7cc4Ssstefan1   ChangeStatus manifest(Attributor &A) override {
23685dfd7cc4Ssstefan1     if (!ReplVal.hasValue() || !ReplVal.getValue())
23695dfd7cc4Ssstefan1       return ChangeStatus::UNCHANGED;
23705dfd7cc4Ssstefan1 
23715dfd7cc4Ssstefan1     A.changeValueAfterManifest(*getCtxI(), **ReplVal);
23725dfd7cc4Ssstefan1     A.deleteAfterManifest(*getCtxI());
23735dfd7cc4Ssstefan1 
23745dfd7cc4Ssstefan1     return ChangeStatus::CHANGED;
23755dfd7cc4Ssstefan1   }
23765dfd7cc4Ssstefan1 
23775dfd7cc4Ssstefan1   // FIXME: come up with better string.
23785dfd7cc4Ssstefan1   const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
23795dfd7cc4Ssstefan1 
23805dfd7cc4Ssstefan1   // FIXME: come up with some stats.
23815dfd7cc4Ssstefan1   void trackStatistics() const override {}
23825dfd7cc4Ssstefan1 
23835dfd7cc4Ssstefan1   InternalControlVar AssociatedICV;
23845dfd7cc4Ssstefan1   Optional<Value *> ReplVal;
23855dfd7cc4Ssstefan1 
23865dfd7cc4Ssstefan1   ChangeStatus updateImpl(Attributor &A) override {
23875dfd7cc4Ssstefan1     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
23885b70c12fSJohannes Doerfert         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
23895dfd7cc4Ssstefan1 
23905dfd7cc4Ssstefan1     // We don't have any information, so we assume it changes the ICV.
23915dfd7cc4Ssstefan1     if (!ICVTrackingAA.isAssumedTracked())
23925dfd7cc4Ssstefan1       return indicatePessimisticFixpoint();
23935dfd7cc4Ssstefan1 
23945dfd7cc4Ssstefan1     Optional<Value *> NewReplVal =
23955dfd7cc4Ssstefan1         ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
23965dfd7cc4Ssstefan1 
23975dfd7cc4Ssstefan1     if (ReplVal == NewReplVal)
23985dfd7cc4Ssstefan1       return ChangeStatus::UNCHANGED;
23995dfd7cc4Ssstefan1 
24005dfd7cc4Ssstefan1     ReplVal = NewReplVal;
24015dfd7cc4Ssstefan1     return ChangeStatus::CHANGED;
24025dfd7cc4Ssstefan1   }
24035dfd7cc4Ssstefan1 
24045dfd7cc4Ssstefan1   // Return the value with which associated value can be replaced for specific
24055dfd7cc4Ssstefan1   // \p ICV.
24065dfd7cc4Ssstefan1   Optional<Value *>
24075dfd7cc4Ssstefan1   getUniqueReplacementValue(InternalControlVar ICV) const override {
24085dfd7cc4Ssstefan1     return ReplVal;
24095dfd7cc4Ssstefan1   }
24105dfd7cc4Ssstefan1 };
24115dfd7cc4Ssstefan1 
24125dfd7cc4Ssstefan1 struct AAICVTrackerCallSiteReturned : AAICVTracker {
24135dfd7cc4Ssstefan1   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
24145dfd7cc4Ssstefan1       : AAICVTracker(IRP, A) {}
24155dfd7cc4Ssstefan1 
24165dfd7cc4Ssstefan1   // FIXME: come up with better string.
24175dfd7cc4Ssstefan1   const std::string getAsStr() const override {
24185dfd7cc4Ssstefan1     return "ICVTrackerCallSiteReturned";
24195dfd7cc4Ssstefan1   }
24205dfd7cc4Ssstefan1 
24215dfd7cc4Ssstefan1   // FIXME: come up with some stats.
24225dfd7cc4Ssstefan1   void trackStatistics() const override {}
24235dfd7cc4Ssstefan1 
24245dfd7cc4Ssstefan1   /// We don't manifest anything for this AA.
24255dfd7cc4Ssstefan1   ChangeStatus manifest(Attributor &A) override {
24265dfd7cc4Ssstefan1     return ChangeStatus::UNCHANGED;
24275dfd7cc4Ssstefan1   }
24285dfd7cc4Ssstefan1 
24295dfd7cc4Ssstefan1   // Map of ICV to their values at specific program point.
24305dfd7cc4Ssstefan1   EnumeratedArray<Optional<Value *>, InternalControlVar,
24315dfd7cc4Ssstefan1                   InternalControlVar::ICV___last>
24325dfd7cc4Ssstefan1       ICVReplacementValuesMap;
24335dfd7cc4Ssstefan1 
24345dfd7cc4Ssstefan1   /// Return the value with which associated value can be replaced for specific
24355dfd7cc4Ssstefan1   /// \p ICV.
24365dfd7cc4Ssstefan1   Optional<Value *>
24375dfd7cc4Ssstefan1   getUniqueReplacementValue(InternalControlVar ICV) const override {
24385dfd7cc4Ssstefan1     return ICVReplacementValuesMap[ICV];
24395dfd7cc4Ssstefan1   }
24405dfd7cc4Ssstefan1 
24415dfd7cc4Ssstefan1   ChangeStatus updateImpl(Attributor &A) override {
24425dfd7cc4Ssstefan1     ChangeStatus Changed = ChangeStatus::UNCHANGED;
24435dfd7cc4Ssstefan1     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
24445b70c12fSJohannes Doerfert         *this, IRPosition::returned(*getAssociatedFunction()),
24455b70c12fSJohannes Doerfert         DepClassTy::REQUIRED);
24465dfd7cc4Ssstefan1 
24475dfd7cc4Ssstefan1     // We don't have any information, so we assume it changes the ICV.
24485dfd7cc4Ssstefan1     if (!ICVTrackingAA.isAssumedTracked())
24495dfd7cc4Ssstefan1       return indicatePessimisticFixpoint();
24505dfd7cc4Ssstefan1 
24515dfd7cc4Ssstefan1     for (InternalControlVar ICV : TrackableICVs) {
24525dfd7cc4Ssstefan1       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
24535dfd7cc4Ssstefan1       Optional<Value *> NewReplVal =
24545dfd7cc4Ssstefan1           ICVTrackingAA.getUniqueReplacementValue(ICV);
24555dfd7cc4Ssstefan1 
24565dfd7cc4Ssstefan1       if (ReplVal == NewReplVal)
24575dfd7cc4Ssstefan1         continue;
24585dfd7cc4Ssstefan1 
24595dfd7cc4Ssstefan1       ReplVal = NewReplVal;
24605dfd7cc4Ssstefan1       Changed = ChangeStatus::CHANGED;
24615dfd7cc4Ssstefan1     }
24625dfd7cc4Ssstefan1     return Changed;
24635dfd7cc4Ssstefan1   }
24649548b74aSJohannes Doerfert };
246518283125SJoseph Huber 
246618283125SJoseph Huber struct AAExecutionDomainFunction : public AAExecutionDomain {
246718283125SJoseph Huber   AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
246818283125SJoseph Huber       : AAExecutionDomain(IRP, A) {}
246918283125SJoseph Huber 
247018283125SJoseph Huber   const std::string getAsStr() const override {
247118283125SJoseph Huber     return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
247218283125SJoseph Huber            "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
247318283125SJoseph Huber   }
247418283125SJoseph Huber 
247518283125SJoseph Huber   /// See AbstractAttribute::trackStatistics().
247618283125SJoseph Huber   void trackStatistics() const override {}
247718283125SJoseph Huber 
247818283125SJoseph Huber   void initialize(Attributor &A) override {
247918283125SJoseph Huber     Function *F = getAnchorScope();
248018283125SJoseph Huber     for (const auto &BB : *F)
248118283125SJoseph Huber       SingleThreadedBBs.insert(&BB);
248218283125SJoseph Huber     NumBBs = SingleThreadedBBs.size();
248318283125SJoseph Huber   }
248418283125SJoseph Huber 
248518283125SJoseph Huber   ChangeStatus manifest(Attributor &A) override {
248618283125SJoseph Huber     LLVM_DEBUG({
248718283125SJoseph Huber       for (const BasicBlock *BB : SingleThreadedBBs)
248818283125SJoseph Huber         dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
248918283125SJoseph Huber                << BB->getName() << " is executed by a single thread.\n";
249018283125SJoseph Huber     });
249118283125SJoseph Huber     return ChangeStatus::UNCHANGED;
249218283125SJoseph Huber   }
249318283125SJoseph Huber 
249418283125SJoseph Huber   ChangeStatus updateImpl(Attributor &A) override;
249518283125SJoseph Huber 
249618283125SJoseph Huber   /// Check if an instruction is executed by a single thread.
24979a23e673SJohannes Doerfert   bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
24989a23e673SJohannes Doerfert     return isExecutedByInitialThreadOnly(*I.getParent());
249918283125SJoseph Huber   }
250018283125SJoseph Huber 
25019a23e673SJohannes Doerfert   bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
25021cfdcae6SJoseph Huber     return isValidState() && SingleThreadedBBs.contains(&BB);
250318283125SJoseph Huber   }
250418283125SJoseph Huber 
250518283125SJoseph Huber   /// Set of basic blocks that are executed by a single thread.
250618283125SJoseph Huber   DenseSet<const BasicBlock *> SingleThreadedBBs;
250718283125SJoseph Huber 
250818283125SJoseph Huber   /// Total number of basic blocks in this function.
250918283125SJoseph Huber   long unsigned NumBBs;
251018283125SJoseph Huber };
251118283125SJoseph Huber 
251218283125SJoseph Huber ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
251318283125SJoseph Huber   Function *F = getAnchorScope();
251418283125SJoseph Huber   ReversePostOrderTraversal<Function *> RPOT(F);
251518283125SJoseph Huber   auto NumSingleThreadedBBs = SingleThreadedBBs.size();
251618283125SJoseph Huber 
251718283125SJoseph Huber   bool AllCallSitesKnown;
251818283125SJoseph Huber   auto PredForCallSite = [&](AbstractCallSite ACS) {
251918283125SJoseph Huber     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
252018283125SJoseph Huber         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
252118283125SJoseph Huber         DepClassTy::REQUIRED);
25221cfdcae6SJoseph Huber     return ACS.isDirectCall() &&
25231cfdcae6SJoseph Huber            ExecutionDomainAA.isExecutedByInitialThreadOnly(
25249a23e673SJohannes Doerfert                *ACS.getInstruction());
252518283125SJoseph Huber   };
252618283125SJoseph Huber 
252718283125SJoseph Huber   if (!A.checkForAllCallSites(PredForCallSite, *this,
252818283125SJoseph Huber                               /* RequiresAllCallSites */ true,
252918283125SJoseph Huber                               AllCallSitesKnown))
253018283125SJoseph Huber     SingleThreadedBBs.erase(&F->getEntryBlock());
253118283125SJoseph Huber 
2532e2cfbfccSJohannes Doerfert   auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2533e2cfbfccSJohannes Doerfert   auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2534e2cfbfccSJohannes Doerfert 
2535e2cfbfccSJohannes Doerfert   // Check if the edge into the successor block compares the __kmpc_target_init
2536e2cfbfccSJohannes Doerfert   // result with -1. If we are in non-SPMD-mode that signals only the main
2537e2cfbfccSJohannes Doerfert   // thread will execute the edge.
25386fc51c9fSJoseph Huber   auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
253918283125SJoseph Huber     if (!Edge || !Edge->isConditional())
254018283125SJoseph Huber       return false;
254118283125SJoseph Huber     if (Edge->getSuccessor(0) != SuccessorBB)
254218283125SJoseph Huber       return false;
254318283125SJoseph Huber 
254418283125SJoseph Huber     auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
254518283125SJoseph Huber     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
254618283125SJoseph Huber       return false;
254718283125SJoseph Huber 
254818283125SJoseph Huber     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2549e2cfbfccSJohannes Doerfert     if (!C)
255018283125SJoseph Huber       return false;
255118283125SJoseph Huber 
2552e2cfbfccSJohannes Doerfert     // Match:  -1 == __kmpc_target_init (for non-SPMD kernels only!)
2553e2cfbfccSJohannes Doerfert     if (C->isAllOnesValue()) {
2554e2cfbfccSJohannes Doerfert       auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2555c4b1fe05SJohannes Doerfert       CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2556c4b1fe05SJohannes Doerfert       if (!CB)
2557e2cfbfccSJohannes Doerfert         return false;
2558e2cfbfccSJohannes Doerfert       const int InitIsSPMDArgNo = 1;
2559e2cfbfccSJohannes Doerfert       auto *IsSPMDModeCI =
2560e2cfbfccSJohannes Doerfert           dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo));
2561e2cfbfccSJohannes Doerfert       return IsSPMDModeCI && IsSPMDModeCI->isZero();
2562e2cfbfccSJohannes Doerfert     }
256318283125SJoseph Huber 
256418283125SJoseph Huber     return false;
256518283125SJoseph Huber   };
256618283125SJoseph Huber 
256718283125SJoseph Huber   // Merge all the predecessor states into the current basic block. A basic
256818283125SJoseph Huber   // block is executed by a single thread if all of its predecessors are.
256918283125SJoseph Huber   auto MergePredecessorStates = [&](BasicBlock *BB) {
257018283125SJoseph Huber     if (pred_begin(BB) == pred_end(BB))
257118283125SJoseph Huber       return SingleThreadedBBs.contains(BB);
257218283125SJoseph Huber 
25736fc51c9fSJoseph Huber     bool IsInitialThread = true;
257418283125SJoseph Huber     for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
257518283125SJoseph Huber          PredBB != PredEndBB; ++PredBB) {
25766fc51c9fSJoseph Huber       if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
257718283125SJoseph Huber                                BB))
25786fc51c9fSJoseph Huber         IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
257918283125SJoseph Huber     }
258018283125SJoseph Huber 
25816fc51c9fSJoseph Huber     return IsInitialThread;
258218283125SJoseph Huber   };
258318283125SJoseph Huber 
258418283125SJoseph Huber   for (auto *BB : RPOT) {
258518283125SJoseph Huber     if (!MergePredecessorStates(BB))
258618283125SJoseph Huber       SingleThreadedBBs.erase(BB);
258718283125SJoseph Huber   }
258818283125SJoseph Huber 
258918283125SJoseph Huber   return (NumSingleThreadedBBs == SingleThreadedBBs.size())
259018283125SJoseph Huber              ? ChangeStatus::UNCHANGED
259118283125SJoseph Huber              : ChangeStatus::CHANGED;
259218283125SJoseph Huber }
259318283125SJoseph Huber 
25946fc51c9fSJoseph Huber /// Try to replace memory allocation calls called by a single thread with a
25956fc51c9fSJoseph Huber /// static buffer of shared memory.
25966fc51c9fSJoseph Huber struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
25976fc51c9fSJoseph Huber   using Base = StateWrapper<BooleanState, AbstractAttribute>;
25986fc51c9fSJoseph Huber   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
25996fc51c9fSJoseph Huber 
26006fc51c9fSJoseph Huber   /// Create an abstract attribute view for the position \p IRP.
26016fc51c9fSJoseph Huber   static AAHeapToShared &createForPosition(const IRPosition &IRP,
26026fc51c9fSJoseph Huber                                            Attributor &A);
26036fc51c9fSJoseph Huber 
2604f8c40ed8SGiorgis Georgakoudis   /// Returns true if HeapToShared conversion is assumed to be possible.
2605f8c40ed8SGiorgis Georgakoudis   virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
2606f8c40ed8SGiorgis Georgakoudis 
2607f8c40ed8SGiorgis Georgakoudis   /// Returns true if HeapToShared conversion is assumed and the CB is a
2608f8c40ed8SGiorgis Georgakoudis   /// callsite to a free operation to be removed.
2609f8c40ed8SGiorgis Georgakoudis   virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
2610f8c40ed8SGiorgis Georgakoudis 
26116fc51c9fSJoseph Huber   /// See AbstractAttribute::getName().
26126fc51c9fSJoseph Huber   const std::string getName() const override { return "AAHeapToShared"; }
26136fc51c9fSJoseph Huber 
26146fc51c9fSJoseph Huber   /// See AbstractAttribute::getIdAddr().
26156fc51c9fSJoseph Huber   const char *getIdAddr() const override { return &ID; }
26166fc51c9fSJoseph Huber 
26176fc51c9fSJoseph Huber   /// This function should return true if the type of the \p AA is
26186fc51c9fSJoseph Huber   /// AAHeapToShared.
26196fc51c9fSJoseph Huber   static bool classof(const AbstractAttribute *AA) {
26206fc51c9fSJoseph Huber     return (AA->getIdAddr() == &ID);
26216fc51c9fSJoseph Huber   }
26226fc51c9fSJoseph Huber 
26236fc51c9fSJoseph Huber   /// Unique ID (due to the unique address)
26246fc51c9fSJoseph Huber   static const char ID;
26256fc51c9fSJoseph Huber };
26266fc51c9fSJoseph Huber 
26276fc51c9fSJoseph Huber struct AAHeapToSharedFunction : public AAHeapToShared {
26286fc51c9fSJoseph Huber   AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
26296fc51c9fSJoseph Huber       : AAHeapToShared(IRP, A) {}
26306fc51c9fSJoseph Huber 
26316fc51c9fSJoseph Huber   const std::string getAsStr() const override {
26326fc51c9fSJoseph Huber     return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
26336fc51c9fSJoseph Huber            " malloc calls eligible.";
26346fc51c9fSJoseph Huber   }
26356fc51c9fSJoseph Huber 
26366fc51c9fSJoseph Huber   /// See AbstractAttribute::trackStatistics().
26376fc51c9fSJoseph Huber   void trackStatistics() const override {}
26386fc51c9fSJoseph Huber 
2639f8c40ed8SGiorgis Georgakoudis   /// This functions finds free calls that will be removed by the
2640f8c40ed8SGiorgis Georgakoudis   /// HeapToShared transformation.
2641f8c40ed8SGiorgis Georgakoudis   void findPotentialRemovedFreeCalls(Attributor &A) {
2642f8c40ed8SGiorgis Georgakoudis     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2643f8c40ed8SGiorgis Georgakoudis     auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2644f8c40ed8SGiorgis Georgakoudis 
2645f8c40ed8SGiorgis Georgakoudis     PotentialRemovedFreeCalls.clear();
2646f8c40ed8SGiorgis Georgakoudis     // Update free call users of found malloc calls.
2647f8c40ed8SGiorgis Georgakoudis     for (CallBase *CB : MallocCalls) {
2648f8c40ed8SGiorgis Georgakoudis       SmallVector<CallBase *, 4> FreeCalls;
2649f8c40ed8SGiorgis Georgakoudis       for (auto *U : CB->users()) {
2650f8c40ed8SGiorgis Georgakoudis         CallBase *C = dyn_cast<CallBase>(U);
2651f8c40ed8SGiorgis Georgakoudis         if (C && C->getCalledFunction() == FreeRFI.Declaration)
2652f8c40ed8SGiorgis Georgakoudis           FreeCalls.push_back(C);
2653f8c40ed8SGiorgis Georgakoudis       }
2654f8c40ed8SGiorgis Georgakoudis 
2655f8c40ed8SGiorgis Georgakoudis       if (FreeCalls.size() != 1)
2656f8c40ed8SGiorgis Georgakoudis         continue;
2657f8c40ed8SGiorgis Georgakoudis 
2658f8c40ed8SGiorgis Georgakoudis       PotentialRemovedFreeCalls.insert(FreeCalls.front());
2659f8c40ed8SGiorgis Georgakoudis     }
2660f8c40ed8SGiorgis Georgakoudis   }
2661f8c40ed8SGiorgis Georgakoudis 
26626fc51c9fSJoseph Huber   void initialize(Attributor &A) override {
26636fc51c9fSJoseph Huber     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
26646fc51c9fSJoseph Huber     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
26656fc51c9fSJoseph Huber 
26666fc51c9fSJoseph Huber     for (User *U : RFI.Declaration->users())
26676fc51c9fSJoseph Huber       if (CallBase *CB = dyn_cast<CallBase>(U))
26686fc51c9fSJoseph Huber         MallocCalls.insert(CB);
2669f8c40ed8SGiorgis Georgakoudis 
2670f8c40ed8SGiorgis Georgakoudis     findPotentialRemovedFreeCalls(A);
2671f8c40ed8SGiorgis Georgakoudis   }
2672f8c40ed8SGiorgis Georgakoudis 
2673eaab880eSGiorgis Georgakoudis   bool isAssumedHeapToShared(CallBase &CB) const override {
2674f8c40ed8SGiorgis Georgakoudis     return isValidState() && MallocCalls.count(&CB);
2675f8c40ed8SGiorgis Georgakoudis   }
2676f8c40ed8SGiorgis Georgakoudis 
2677eaab880eSGiorgis Georgakoudis   bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
2678f8c40ed8SGiorgis Georgakoudis     return isValidState() && PotentialRemovedFreeCalls.count(&CB);
26796fc51c9fSJoseph Huber   }
26806fc51c9fSJoseph Huber 
26816fc51c9fSJoseph Huber   ChangeStatus manifest(Attributor &A) override {
26826fc51c9fSJoseph Huber     if (MallocCalls.empty())
26836fc51c9fSJoseph Huber       return ChangeStatus::UNCHANGED;
26846fc51c9fSJoseph Huber 
26856fc51c9fSJoseph Huber     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
26866fc51c9fSJoseph Huber     auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
26876fc51c9fSJoseph Huber 
26886fc51c9fSJoseph Huber     Function *F = getAnchorScope();
26896fc51c9fSJoseph Huber     auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
26906fc51c9fSJoseph Huber                                             DepClassTy::OPTIONAL);
26916fc51c9fSJoseph Huber 
26926fc51c9fSJoseph Huber     ChangeStatus Changed = ChangeStatus::UNCHANGED;
26936fc51c9fSJoseph Huber     for (CallBase *CB : MallocCalls) {
26946fc51c9fSJoseph Huber       // Skip replacing this if HeapToStack has already claimed it.
2695c1c1fe93SJohannes Doerfert       if (HS && HS->isAssumedHeapToStack(*CB))
26966fc51c9fSJoseph Huber         continue;
26976fc51c9fSJoseph Huber 
26986fc51c9fSJoseph Huber       // Find the unique free call to remove it.
26996fc51c9fSJoseph Huber       SmallVector<CallBase *, 4> FreeCalls;
27006fc51c9fSJoseph Huber       for (auto *U : CB->users()) {
27016fc51c9fSJoseph Huber         CallBase *C = dyn_cast<CallBase>(U);
27026fc51c9fSJoseph Huber         if (C && C->getCalledFunction() == FreeCall.Declaration)
27036fc51c9fSJoseph Huber           FreeCalls.push_back(C);
27046fc51c9fSJoseph Huber       }
27056fc51c9fSJoseph Huber       if (FreeCalls.size() != 1)
27066fc51c9fSJoseph Huber         continue;
27076fc51c9fSJoseph Huber 
27086fc51c9fSJoseph Huber       ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0));
27096fc51c9fSJoseph Huber 
27106fc51c9fSJoseph Huber       LLVM_DEBUG(dbgs() << TAG << "Replace globalization call in "
27116fc51c9fSJoseph Huber                         << CB->getCaller()->getName() << " with "
27126fc51c9fSJoseph Huber                         << AllocSize->getZExtValue()
27136fc51c9fSJoseph Huber                         << " bytes of shared memory\n");
27146fc51c9fSJoseph Huber 
27156fc51c9fSJoseph Huber       // Create a new shared memory buffer of the same size as the allocation
27166fc51c9fSJoseph Huber       // and replace all the uses of the original allocation with it.
27176fc51c9fSJoseph Huber       Module *M = CB->getModule();
27186fc51c9fSJoseph Huber       Type *Int8Ty = Type::getInt8Ty(M->getContext());
27196fc51c9fSJoseph Huber       Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
27206fc51c9fSJoseph Huber       auto *SharedMem = new GlobalVariable(
27216fc51c9fSJoseph Huber           *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
27226fc51c9fSJoseph Huber           UndefValue::get(Int8ArrTy), CB->getName(), nullptr,
27236fc51c9fSJoseph Huber           GlobalValue::NotThreadLocal,
27246fc51c9fSJoseph Huber           static_cast<unsigned>(AddressSpace::Shared));
27256fc51c9fSJoseph Huber       auto *NewBuffer =
27266fc51c9fSJoseph Huber           ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
27276fc51c9fSJoseph Huber 
272830e36c9bSJoseph Huber       auto Remark = [&](OptimizationRemark OR) {
272930e36c9bSJoseph Huber         return OR << "Replaced globalized variable with "
273030e36c9bSJoseph Huber                   << ore::NV("SharedMemory", AllocSize->getZExtValue())
273130e36c9bSJoseph Huber                   << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
2732eef6601bSJoseph Huber                   << "of shared memory.";
273330e36c9bSJoseph Huber       };
27342c31d5ebSJoseph Huber       A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
273530e36c9bSJoseph Huber 
27366fc51c9fSJoseph Huber       SharedMem->setAlignment(MaybeAlign(32));
27376fc51c9fSJoseph Huber 
27386fc51c9fSJoseph Huber       A.changeValueAfterManifest(*CB, *NewBuffer);
27396fc51c9fSJoseph Huber       A.deleteAfterManifest(*CB);
27406fc51c9fSJoseph Huber       A.deleteAfterManifest(*FreeCalls.front());
27416fc51c9fSJoseph Huber 
27426fc51c9fSJoseph Huber       NumBytesMovedToSharedMemory += AllocSize->getZExtValue();
27436fc51c9fSJoseph Huber       Changed = ChangeStatus::CHANGED;
27446fc51c9fSJoseph Huber     }
27456fc51c9fSJoseph Huber 
27466fc51c9fSJoseph Huber     return Changed;
27476fc51c9fSJoseph Huber   }
27486fc51c9fSJoseph Huber 
27496fc51c9fSJoseph Huber   ChangeStatus updateImpl(Attributor &A) override {
27506fc51c9fSJoseph Huber     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
27516fc51c9fSJoseph Huber     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
27526fc51c9fSJoseph Huber     Function *F = getAnchorScope();
27536fc51c9fSJoseph Huber 
27546fc51c9fSJoseph Huber     auto NumMallocCalls = MallocCalls.size();
27556fc51c9fSJoseph Huber 
27566fc51c9fSJoseph Huber     // Only consider malloc calls executed by a single thread with a constant.
27576fc51c9fSJoseph Huber     for (User *U : RFI.Declaration->users()) {
27586fc51c9fSJoseph Huber       const auto &ED = A.getAAFor<AAExecutionDomain>(
27596fc51c9fSJoseph Huber           *this, IRPosition::function(*F), DepClassTy::REQUIRED);
27606fc51c9fSJoseph Huber       if (CallBase *CB = dyn_cast<CallBase>(U))
27616fc51c9fSJoseph Huber         if (!dyn_cast<ConstantInt>(CB->getArgOperand(0)) ||
27626fc51c9fSJoseph Huber             !ED.isExecutedByInitialThreadOnly(*CB))
27636fc51c9fSJoseph Huber           MallocCalls.erase(CB);
27646fc51c9fSJoseph Huber     }
27656fc51c9fSJoseph Huber 
2766f8c40ed8SGiorgis Georgakoudis     findPotentialRemovedFreeCalls(A);
2767f8c40ed8SGiorgis Georgakoudis 
27686fc51c9fSJoseph Huber     if (NumMallocCalls != MallocCalls.size())
27696fc51c9fSJoseph Huber       return ChangeStatus::CHANGED;
27706fc51c9fSJoseph Huber 
27716fc51c9fSJoseph Huber     return ChangeStatus::UNCHANGED;
27726fc51c9fSJoseph Huber   }
27736fc51c9fSJoseph Huber 
27746fc51c9fSJoseph Huber   /// Collection of all malloc calls in a function.
27756fc51c9fSJoseph Huber   SmallPtrSet<CallBase *, 4> MallocCalls;
2776f8c40ed8SGiorgis Georgakoudis   /// Collection of potentially removed free calls in a function.
2777f8c40ed8SGiorgis Georgakoudis   SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
27786fc51c9fSJoseph Huber };
27796fc51c9fSJoseph Huber 
2780d9659bf6SJohannes Doerfert struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
2781d9659bf6SJohannes Doerfert   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
2782d9659bf6SJohannes Doerfert   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2783d9659bf6SJohannes Doerfert 
2784d9659bf6SJohannes Doerfert   /// Statistics are tracked as part of manifest for now.
2785d9659bf6SJohannes Doerfert   void trackStatistics() const override {}
2786d9659bf6SJohannes Doerfert 
2787d9659bf6SJohannes Doerfert   /// See AbstractAttribute::getAsStr()
2788d9659bf6SJohannes Doerfert   const std::string getAsStr() const override {
2789d9659bf6SJohannes Doerfert     if (!isValidState())
2790d9659bf6SJohannes Doerfert       return "<invalid>";
2791514c033dSJohannes Doerfert     return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
2792514c033dSJohannes Doerfert                                                             : "generic") +
2793514c033dSJohannes Doerfert            std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
2794514c033dSJohannes Doerfert                                                                : "") +
2795d9659bf6SJohannes Doerfert            std::string(" #PRs: ") +
2796d9659bf6SJohannes Doerfert            std::to_string(ReachedKnownParallelRegions.size()) +
2797d9659bf6SJohannes Doerfert            ", #Unknown PRs: " +
2798d9659bf6SJohannes Doerfert            std::to_string(ReachedUnknownParallelRegions.size());
2799d9659bf6SJohannes Doerfert   }
2800d9659bf6SJohannes Doerfert 
2801d9659bf6SJohannes Doerfert   /// Create an abstract attribute biew for the position \p IRP.
2802d9659bf6SJohannes Doerfert   static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
2803d9659bf6SJohannes Doerfert 
2804d9659bf6SJohannes Doerfert   /// See AbstractAttribute::getName()
2805d9659bf6SJohannes Doerfert   const std::string getName() const override { return "AAKernelInfo"; }
2806d9659bf6SJohannes Doerfert 
2807d9659bf6SJohannes Doerfert   /// See AbstractAttribute::getIdAddr()
2808d9659bf6SJohannes Doerfert   const char *getIdAddr() const override { return &ID; }
2809d9659bf6SJohannes Doerfert 
2810d9659bf6SJohannes Doerfert   /// This function should return true if the type of the \p AA is AAKernelInfo
2811d9659bf6SJohannes Doerfert   static bool classof(const AbstractAttribute *AA) {
2812d9659bf6SJohannes Doerfert     return (AA->getIdAddr() == &ID);
2813d9659bf6SJohannes Doerfert   }
2814d9659bf6SJohannes Doerfert 
2815d9659bf6SJohannes Doerfert   static const char ID;
2816d9659bf6SJohannes Doerfert };
2817d9659bf6SJohannes Doerfert 
2818d9659bf6SJohannes Doerfert /// The function kernel info abstract attribute, basically, what can we say
2819d9659bf6SJohannes Doerfert /// about a function with regards to the KernelInfoState.
2820d9659bf6SJohannes Doerfert struct AAKernelInfoFunction : AAKernelInfo {
2821d9659bf6SJohannes Doerfert   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
2822d9659bf6SJohannes Doerfert       : AAKernelInfo(IRP, A) {}
2823d9659bf6SJohannes Doerfert 
282429a3e3ddSGiorgis Georgakoudis   SmallPtrSet<Instruction *, 4> GuardedInstructions;
282529a3e3ddSGiorgis Georgakoudis 
282629a3e3ddSGiorgis Georgakoudis   SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
282729a3e3ddSGiorgis Georgakoudis     return GuardedInstructions;
282829a3e3ddSGiorgis Georgakoudis   }
282929a3e3ddSGiorgis Georgakoudis 
2830d9659bf6SJohannes Doerfert   /// See AbstractAttribute::initialize(...).
2831d9659bf6SJohannes Doerfert   void initialize(Attributor &A) override {
2832d9659bf6SJohannes Doerfert     // This is a high-level transform that might change the constant arguments
2833d9659bf6SJohannes Doerfert     // of the init and dinit calls. We need to tell the Attributor about this
2834d9659bf6SJohannes Doerfert     // to avoid other parts using the current constant value for simpliication.
2835d9659bf6SJohannes Doerfert     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2836d9659bf6SJohannes Doerfert 
2837d9659bf6SJohannes Doerfert     Function *Fn = getAnchorScope();
2838d9659bf6SJohannes Doerfert     if (!OMPInfoCache.Kernels.count(Fn))
2839d9659bf6SJohannes Doerfert       return;
2840d9659bf6SJohannes Doerfert 
2841ca662297SShilei Tian     // Add itself to the reaching kernel and set IsKernelEntry.
2842ca662297SShilei Tian     ReachingKernelEntries.insert(Fn);
2843ca662297SShilei Tian     IsKernelEntry = true;
2844ca662297SShilei Tian 
2845d9659bf6SJohannes Doerfert     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
2846d9659bf6SJohannes Doerfert         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2847d9659bf6SJohannes Doerfert     OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
2848d9659bf6SJohannes Doerfert         OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
2849d9659bf6SJohannes Doerfert 
2850d9659bf6SJohannes Doerfert     // For kernels we perform more initialization work, first we find the init
2851d9659bf6SJohannes Doerfert     // and deinit calls.
2852d9659bf6SJohannes Doerfert     auto StoreCallBase = [](Use &U,
2853d9659bf6SJohannes Doerfert                             OMPInformationCache::RuntimeFunctionInfo &RFI,
2854d9659bf6SJohannes Doerfert                             CallBase *&Storage) {
2855d9659bf6SJohannes Doerfert       CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
2856d9659bf6SJohannes Doerfert       assert(CB &&
2857d9659bf6SJohannes Doerfert              "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
2858d9659bf6SJohannes Doerfert       assert(!Storage &&
2859d9659bf6SJohannes Doerfert              "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
2860d9659bf6SJohannes Doerfert       Storage = CB;
2861d9659bf6SJohannes Doerfert       return false;
2862d9659bf6SJohannes Doerfert     };
2863d9659bf6SJohannes Doerfert     InitRFI.foreachUse(
2864d9659bf6SJohannes Doerfert         [&](Use &U, Function &) {
2865d9659bf6SJohannes Doerfert           StoreCallBase(U, InitRFI, KernelInitCB);
2866d9659bf6SJohannes Doerfert           return false;
2867d9659bf6SJohannes Doerfert         },
2868d9659bf6SJohannes Doerfert         Fn);
2869d9659bf6SJohannes Doerfert     DeinitRFI.foreachUse(
2870d9659bf6SJohannes Doerfert         [&](Use &U, Function &) {
2871d9659bf6SJohannes Doerfert           StoreCallBase(U, DeinitRFI, KernelDeinitCB);
2872d9659bf6SJohannes Doerfert           return false;
2873d9659bf6SJohannes Doerfert         },
2874d9659bf6SJohannes Doerfert         Fn);
2875d9659bf6SJohannes Doerfert 
2876d9659bf6SJohannes Doerfert     assert((KernelInitCB && KernelDeinitCB) &&
2877d9659bf6SJohannes Doerfert            "Kernel without __kmpc_target_init or __kmpc_target_deinit!");
2878d9659bf6SJohannes Doerfert 
2879514c033dSJohannes Doerfert     // For kernels we might need to initialize/finalize the IsSPMD state and
2880514c033dSJohannes Doerfert     // we need to register a simplification callback so that the Attributor
2881514c033dSJohannes Doerfert     // knows the constant arguments to __kmpc_target_init and
2882d9659bf6SJohannes Doerfert     // __kmpc_target_deinit might actually change.
2883d9659bf6SJohannes Doerfert 
2884d9659bf6SJohannes Doerfert     Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
2885d9659bf6SJohannes Doerfert         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2886d9659bf6SJohannes Doerfert             bool &UsedAssumedInformation) -> Optional<Value *> {
2887d9659bf6SJohannes Doerfert       // IRP represents the "use generic state machine" argument of an
2888d9659bf6SJohannes Doerfert       // __kmpc_target_init call. We will answer this one with the internal
2889d9659bf6SJohannes Doerfert       // state. As long as we are not in an invalid state, we will create a
2890d9659bf6SJohannes Doerfert       // custom state machine so the value should be a `i1 false`. If we are
2891d9659bf6SJohannes Doerfert       // in an invalid state, we won't change the value that is in the IR.
2892d9659bf6SJohannes Doerfert       if (!isValidState())
2893d9659bf6SJohannes Doerfert         return nullptr;
2894e0c5d83aSJohannes Doerfert       // If we have disabled state machine rewrites, don't make a custom one.
2895e0c5d83aSJohannes Doerfert       if (DisableOpenMPOptStateMachineRewrite)
2896e0c5d83aSJohannes Doerfert         return nullptr;
2897d9659bf6SJohannes Doerfert       if (AA)
2898d9659bf6SJohannes Doerfert         A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2899d9659bf6SJohannes Doerfert       UsedAssumedInformation = !isAtFixpoint();
2900d9659bf6SJohannes Doerfert       auto *FalseVal =
2901d9659bf6SJohannes Doerfert           ConstantInt::getBool(IRP.getAnchorValue().getContext(), 0);
2902d9659bf6SJohannes Doerfert       return FalseVal;
2903d9659bf6SJohannes Doerfert     };
2904d9659bf6SJohannes Doerfert 
2905514c033dSJohannes Doerfert     Attributor::SimplifictionCallbackTy IsSPMDModeSimplifyCB =
2906514c033dSJohannes Doerfert         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2907514c033dSJohannes Doerfert             bool &UsedAssumedInformation) -> Optional<Value *> {
2908514c033dSJohannes Doerfert       // IRP represents the "SPMDCompatibilityTracker" argument of an
2909514c033dSJohannes Doerfert       // __kmpc_target_init or
2910514c033dSJohannes Doerfert       // __kmpc_target_deinit call. We will answer this one with the internal
2911514c033dSJohannes Doerfert       // state.
291297387fdfSJohannes Doerfert       if (!SPMDCompatibilityTracker.isValidState())
2913514c033dSJohannes Doerfert         return nullptr;
2914514c033dSJohannes Doerfert       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2915514c033dSJohannes Doerfert         if (AA)
2916514c033dSJohannes Doerfert           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2917514c033dSJohannes Doerfert         UsedAssumedInformation = true;
2918514c033dSJohannes Doerfert       } else {
2919514c033dSJohannes Doerfert         UsedAssumedInformation = false;
2920514c033dSJohannes Doerfert       }
2921514c033dSJohannes Doerfert       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2922514c033dSJohannes Doerfert                                        SPMDCompatibilityTracker.isAssumed());
2923514c033dSJohannes Doerfert       return Val;
2924514c033dSJohannes Doerfert     };
2925514c033dSJohannes Doerfert 
2926e8439ec8SGiorgis Georgakoudis     Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
2927e8439ec8SGiorgis Georgakoudis         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2928e8439ec8SGiorgis Georgakoudis             bool &UsedAssumedInformation) -> Optional<Value *> {
2929e8439ec8SGiorgis Georgakoudis       // IRP represents the "RequiresFullRuntime" argument of an
2930e8439ec8SGiorgis Georgakoudis       // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
2931e8439ec8SGiorgis Georgakoudis       // one with the internal state of the SPMDCompatibilityTracker, so if
2932e8439ec8SGiorgis Georgakoudis       // generic then true, if SPMD then false.
2933e8439ec8SGiorgis Georgakoudis       if (!SPMDCompatibilityTracker.isValidState())
2934e8439ec8SGiorgis Georgakoudis         return nullptr;
2935e8439ec8SGiorgis Georgakoudis       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2936e8439ec8SGiorgis Georgakoudis         if (AA)
2937e8439ec8SGiorgis Georgakoudis           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2938e8439ec8SGiorgis Georgakoudis         UsedAssumedInformation = true;
2939e8439ec8SGiorgis Georgakoudis       } else {
2940e8439ec8SGiorgis Georgakoudis         UsedAssumedInformation = false;
2941e8439ec8SGiorgis Georgakoudis       }
2942e8439ec8SGiorgis Georgakoudis       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2943e8439ec8SGiorgis Georgakoudis                                        !SPMDCompatibilityTracker.isAssumed());
2944e8439ec8SGiorgis Georgakoudis       return Val;
2945e8439ec8SGiorgis Georgakoudis     };
2946e8439ec8SGiorgis Georgakoudis 
2947514c033dSJohannes Doerfert     constexpr const int InitIsSPMDArgNo = 1;
2948514c033dSJohannes Doerfert     constexpr const int DeinitIsSPMDArgNo = 1;
2949d9659bf6SJohannes Doerfert     constexpr const int InitUseStateMachineArgNo = 2;
2950e8439ec8SGiorgis Georgakoudis     constexpr const int InitRequiresFullRuntimeArgNo = 3;
2951e8439ec8SGiorgis Georgakoudis     constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
2952d9659bf6SJohannes Doerfert     A.registerSimplificationCallback(
2953d9659bf6SJohannes Doerfert         IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
2954d9659bf6SJohannes Doerfert         StateMachineSimplifyCB);
2955514c033dSJohannes Doerfert     A.registerSimplificationCallback(
2956514c033dSJohannes Doerfert         IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo),
2957514c033dSJohannes Doerfert         IsSPMDModeSimplifyCB);
2958514c033dSJohannes Doerfert     A.registerSimplificationCallback(
2959514c033dSJohannes Doerfert         IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo),
2960514c033dSJohannes Doerfert         IsSPMDModeSimplifyCB);
2961e8439ec8SGiorgis Georgakoudis     A.registerSimplificationCallback(
2962e8439ec8SGiorgis Georgakoudis         IRPosition::callsite_argument(*KernelInitCB,
2963e8439ec8SGiorgis Georgakoudis                                       InitRequiresFullRuntimeArgNo),
2964e8439ec8SGiorgis Georgakoudis         IsGenericModeSimplifyCB);
2965e8439ec8SGiorgis Georgakoudis     A.registerSimplificationCallback(
2966e8439ec8SGiorgis Georgakoudis         IRPosition::callsite_argument(*KernelDeinitCB,
2967e8439ec8SGiorgis Georgakoudis                                       DeinitRequiresFullRuntimeArgNo),
2968e8439ec8SGiorgis Georgakoudis         IsGenericModeSimplifyCB);
2969514c033dSJohannes Doerfert 
2970514c033dSJohannes Doerfert     // Check if we know we are in SPMD-mode already.
2971514c033dSJohannes Doerfert     ConstantInt *IsSPMDArg =
2972514c033dSJohannes Doerfert         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
2973514c033dSJohannes Doerfert     if (IsSPMDArg && !IsSPMDArg->isZero())
2974514c033dSJohannes Doerfert       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
297560e643feSGiorgis Georgakoudis     // This is a generic region but SPMDization is disabled so stop tracking.
297660e643feSGiorgis Georgakoudis     else if (DisableOpenMPOptSPMDization)
297760e643feSGiorgis Georgakoudis       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
2978d9659bf6SJohannes Doerfert   }
2979d9659bf6SJohannes Doerfert 
2980d9659bf6SJohannes Doerfert   /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
2981d9659bf6SJohannes Doerfert   /// finished now.
2982d9659bf6SJohannes Doerfert   ChangeStatus manifest(Attributor &A) override {
2983d9659bf6SJohannes Doerfert     // If we are not looking at a kernel with __kmpc_target_init and
2984d9659bf6SJohannes Doerfert     // __kmpc_target_deinit call we cannot actually manifest the information.
2985d9659bf6SJohannes Doerfert     if (!KernelInitCB || !KernelDeinitCB)
2986d9659bf6SJohannes Doerfert       return ChangeStatus::UNCHANGED;
2987d9659bf6SJohannes Doerfert 
2988514c033dSJohannes Doerfert     // Known SPMD-mode kernels need no manifest changes.
2989514c033dSJohannes Doerfert     if (SPMDCompatibilityTracker.isKnown())
2990514c033dSJohannes Doerfert       return ChangeStatus::UNCHANGED;
2991514c033dSJohannes Doerfert 
2992514c033dSJohannes Doerfert     // If we can we change the execution mode to SPMD-mode otherwise we build a
2993514c033dSJohannes Doerfert     // custom state machine.
2994514c033dSJohannes Doerfert     if (!changeToSPMDMode(A))
2995d9659bf6SJohannes Doerfert       buildCustomStateMachine(A);
2996d9659bf6SJohannes Doerfert 
2997d9659bf6SJohannes Doerfert     return ChangeStatus::CHANGED;
2998d9659bf6SJohannes Doerfert   }
2999d9659bf6SJohannes Doerfert 
3000514c033dSJohannes Doerfert   bool changeToSPMDMode(Attributor &A) {
3001eef6601bSJoseph Huber     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3002eef6601bSJoseph Huber 
3003514c033dSJohannes Doerfert     if (!SPMDCompatibilityTracker.isAssumed()) {
3004514c033dSJohannes Doerfert       for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
3005514c033dSJohannes Doerfert         if (!NonCompatibleI)
3006514c033dSJohannes Doerfert           continue;
3007eef6601bSJoseph Huber 
3008eef6601bSJoseph Huber         // Skip diagnostics on calls to known OpenMP runtime functions for now.
3009eef6601bSJoseph Huber         if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
3010eef6601bSJoseph Huber           if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
3011eef6601bSJoseph Huber             continue;
3012eef6601bSJoseph Huber 
3013514c033dSJohannes Doerfert         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3014eef6601bSJoseph Huber           ORA << "Value has potential side effects preventing SPMD-mode "
3015eef6601bSJoseph Huber                  "execution";
3016eef6601bSJoseph Huber           if (isa<CallBase>(NonCompatibleI)) {
3017eef6601bSJoseph Huber             ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
3018eef6601bSJoseph Huber                    "the called function to override";
3019514c033dSJohannes Doerfert           }
3020514c033dSJohannes Doerfert           return ORA << ".";
3021514c033dSJohannes Doerfert         };
30222c31d5ebSJoseph Huber         A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
30232c31d5ebSJoseph Huber                                                  Remark);
3024514c033dSJohannes Doerfert 
3025514c033dSJohannes Doerfert         LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
3026514c033dSJohannes Doerfert                           << *NonCompatibleI << "\n");
3027514c033dSJohannes Doerfert       }
3028514c033dSJohannes Doerfert 
3029514c033dSJohannes Doerfert       return false;
3030514c033dSJohannes Doerfert     }
3031514c033dSJohannes Doerfert 
303229a3e3ddSGiorgis Georgakoudis     auto CreateGuardedRegion = [&](Instruction *RegionStartI,
303329a3e3ddSGiorgis Georgakoudis                                    Instruction *RegionEndI) {
303429a3e3ddSGiorgis Georgakoudis       LoopInfo *LI = nullptr;
303529a3e3ddSGiorgis Georgakoudis       DominatorTree *DT = nullptr;
303629a3e3ddSGiorgis Georgakoudis       MemorySSAUpdater *MSU = nullptr;
303729a3e3ddSGiorgis Georgakoudis       using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
303829a3e3ddSGiorgis Georgakoudis 
303929a3e3ddSGiorgis Georgakoudis       BasicBlock *ParentBB = RegionStartI->getParent();
304029a3e3ddSGiorgis Georgakoudis       Function *Fn = ParentBB->getParent();
304129a3e3ddSGiorgis Georgakoudis       Module &M = *Fn->getParent();
304229a3e3ddSGiorgis Georgakoudis 
304329a3e3ddSGiorgis Georgakoudis       // Create all the blocks and logic.
304429a3e3ddSGiorgis Georgakoudis       // ParentBB:
304529a3e3ddSGiorgis Georgakoudis       //    goto RegionCheckTidBB
304629a3e3ddSGiorgis Georgakoudis       // RegionCheckTidBB:
304729a3e3ddSGiorgis Georgakoudis       //    Tid = __kmpc_hardware_thread_id()
304829a3e3ddSGiorgis Georgakoudis       //    if (Tid != 0)
304929a3e3ddSGiorgis Georgakoudis       //        goto RegionBarrierBB
305029a3e3ddSGiorgis Georgakoudis       // RegionStartBB:
305129a3e3ddSGiorgis Georgakoudis       //    <execute instructions guarded>
305229a3e3ddSGiorgis Georgakoudis       //    goto RegionEndBB
305329a3e3ddSGiorgis Georgakoudis       // RegionEndBB:
305429a3e3ddSGiorgis Georgakoudis       //    <store escaping values to shared mem>
305529a3e3ddSGiorgis Georgakoudis       //    goto RegionBarrierBB
305629a3e3ddSGiorgis Georgakoudis       //  RegionBarrierBB:
305729a3e3ddSGiorgis Georgakoudis       //    __kmpc_simple_barrier_spmd()
305829a3e3ddSGiorgis Georgakoudis       //    // second barrier is omitted if lacking escaping values.
305929a3e3ddSGiorgis Georgakoudis       //    <load escaping values from shared mem>
306029a3e3ddSGiorgis Georgakoudis       //    __kmpc_simple_barrier_spmd()
306129a3e3ddSGiorgis Georgakoudis       //    goto RegionExitBB
306229a3e3ddSGiorgis Georgakoudis       // RegionExitBB:
306329a3e3ddSGiorgis Georgakoudis       //    <execute rest of instructions>
306429a3e3ddSGiorgis Georgakoudis 
306529a3e3ddSGiorgis Georgakoudis       BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
306629a3e3ddSGiorgis Georgakoudis                                            DT, LI, MSU, "region.guarded.end");
306729a3e3ddSGiorgis Georgakoudis       BasicBlock *RegionBarrierBB =
306829a3e3ddSGiorgis Georgakoudis           SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
306929a3e3ddSGiorgis Georgakoudis                      MSU, "region.barrier");
307029a3e3ddSGiorgis Georgakoudis       BasicBlock *RegionExitBB =
307129a3e3ddSGiorgis Georgakoudis           SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
307229a3e3ddSGiorgis Georgakoudis                      DT, LI, MSU, "region.exit");
307329a3e3ddSGiorgis Georgakoudis       BasicBlock *RegionStartBB =
307429a3e3ddSGiorgis Georgakoudis           SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
307529a3e3ddSGiorgis Georgakoudis 
307629a3e3ddSGiorgis Georgakoudis       assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
307729a3e3ddSGiorgis Georgakoudis              "Expected a different CFG");
307829a3e3ddSGiorgis Georgakoudis 
307929a3e3ddSGiorgis Georgakoudis       BasicBlock *RegionCheckTidBB = SplitBlock(
308029a3e3ddSGiorgis Georgakoudis           ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
308129a3e3ddSGiorgis Georgakoudis 
308229a3e3ddSGiorgis Georgakoudis       // Register basic blocks with the Attributor.
308329a3e3ddSGiorgis Georgakoudis       A.registerManifestAddedBasicBlock(*RegionEndBB);
308429a3e3ddSGiorgis Georgakoudis       A.registerManifestAddedBasicBlock(*RegionBarrierBB);
308529a3e3ddSGiorgis Georgakoudis       A.registerManifestAddedBasicBlock(*RegionExitBB);
308629a3e3ddSGiorgis Georgakoudis       A.registerManifestAddedBasicBlock(*RegionStartBB);
308729a3e3ddSGiorgis Georgakoudis       A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
308829a3e3ddSGiorgis Georgakoudis 
308929a3e3ddSGiorgis Georgakoudis       bool HasBroadcastValues = false;
309029a3e3ddSGiorgis Georgakoudis       // Find escaping outputs from the guarded region to outside users and
309129a3e3ddSGiorgis Georgakoudis       // broadcast their values to them.
309229a3e3ddSGiorgis Georgakoudis       for (Instruction &I : *RegionStartBB) {
309329a3e3ddSGiorgis Georgakoudis         SmallPtrSet<Instruction *, 4> OutsideUsers;
309429a3e3ddSGiorgis Georgakoudis         for (User *Usr : I.users()) {
309529a3e3ddSGiorgis Georgakoudis           Instruction &UsrI = *cast<Instruction>(Usr);
309629a3e3ddSGiorgis Georgakoudis           if (UsrI.getParent() != RegionStartBB)
309729a3e3ddSGiorgis Georgakoudis             OutsideUsers.insert(&UsrI);
309829a3e3ddSGiorgis Georgakoudis         }
309929a3e3ddSGiorgis Georgakoudis 
310029a3e3ddSGiorgis Georgakoudis         if (OutsideUsers.empty())
310129a3e3ddSGiorgis Georgakoudis           continue;
310229a3e3ddSGiorgis Georgakoudis 
310329a3e3ddSGiorgis Georgakoudis         HasBroadcastValues = true;
310429a3e3ddSGiorgis Georgakoudis 
310529a3e3ddSGiorgis Georgakoudis         // Emit a global variable in shared memory to store the broadcasted
310629a3e3ddSGiorgis Georgakoudis         // value.
310729a3e3ddSGiorgis Georgakoudis         auto *SharedMem = new GlobalVariable(
310829a3e3ddSGiorgis Georgakoudis             M, I.getType(), /* IsConstant */ false,
310929a3e3ddSGiorgis Georgakoudis             GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
311029a3e3ddSGiorgis Georgakoudis             I.getName() + ".guarded.output.alloc", nullptr,
311129a3e3ddSGiorgis Georgakoudis             GlobalValue::NotThreadLocal,
311229a3e3ddSGiorgis Georgakoudis             static_cast<unsigned>(AddressSpace::Shared));
311329a3e3ddSGiorgis Georgakoudis 
311429a3e3ddSGiorgis Georgakoudis         // Emit a store instruction to update the value.
311529a3e3ddSGiorgis Georgakoudis         new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
311629a3e3ddSGiorgis Georgakoudis 
311729a3e3ddSGiorgis Georgakoudis         LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
311829a3e3ddSGiorgis Georgakoudis                                        I.getName() + ".guarded.output.load",
311929a3e3ddSGiorgis Georgakoudis                                        RegionBarrierBB->getTerminator());
312029a3e3ddSGiorgis Georgakoudis 
312129a3e3ddSGiorgis Georgakoudis         // Emit a load instruction and replace uses of the output value.
312229a3e3ddSGiorgis Georgakoudis         for (Instruction *UsrI : OutsideUsers) {
312329a3e3ddSGiorgis Georgakoudis           assert(UsrI->getParent() == RegionExitBB &&
312429a3e3ddSGiorgis Georgakoudis                  "Expected escaping users in exit region");
312529a3e3ddSGiorgis Georgakoudis           UsrI->replaceUsesOfWith(&I, LoadI);
312629a3e3ddSGiorgis Georgakoudis         }
312729a3e3ddSGiorgis Georgakoudis       }
312829a3e3ddSGiorgis Georgakoudis 
312929a3e3ddSGiorgis Georgakoudis       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
313029a3e3ddSGiorgis Georgakoudis 
313129a3e3ddSGiorgis Georgakoudis       // Go to tid check BB in ParentBB.
313229a3e3ddSGiorgis Georgakoudis       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
313329a3e3ddSGiorgis Georgakoudis       ParentBB->getTerminator()->eraseFromParent();
313429a3e3ddSGiorgis Georgakoudis       OpenMPIRBuilder::LocationDescription Loc(
313529a3e3ddSGiorgis Georgakoudis           InsertPointTy(ParentBB, ParentBB->end()), DL);
313629a3e3ddSGiorgis Georgakoudis       OMPInfoCache.OMPBuilder.updateToLocation(Loc);
313729a3e3ddSGiorgis Georgakoudis       auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc);
313829a3e3ddSGiorgis Georgakoudis       Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr);
313929a3e3ddSGiorgis Georgakoudis       BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
314029a3e3ddSGiorgis Georgakoudis 
314129a3e3ddSGiorgis Georgakoudis       // Add check for Tid in RegionCheckTidBB
314229a3e3ddSGiorgis Georgakoudis       RegionCheckTidBB->getTerminator()->eraseFromParent();
314329a3e3ddSGiorgis Georgakoudis       OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
314429a3e3ddSGiorgis Georgakoudis           InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
314529a3e3ddSGiorgis Georgakoudis       OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
314629a3e3ddSGiorgis Georgakoudis       FunctionCallee HardwareTidFn =
314729a3e3ddSGiorgis Georgakoudis           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
314829a3e3ddSGiorgis Georgakoudis               M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
314929a3e3ddSGiorgis Georgakoudis       Value *Tid =
315029a3e3ddSGiorgis Georgakoudis           OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
315129a3e3ddSGiorgis Georgakoudis       Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
315229a3e3ddSGiorgis Georgakoudis       OMPInfoCache.OMPBuilder.Builder
315329a3e3ddSGiorgis Georgakoudis           .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
315429a3e3ddSGiorgis Georgakoudis           ->setDebugLoc(DL);
315529a3e3ddSGiorgis Georgakoudis 
315629a3e3ddSGiorgis Georgakoudis       // First barrier for synchronization, ensures main thread has updated
315729a3e3ddSGiorgis Georgakoudis       // values.
315829a3e3ddSGiorgis Georgakoudis       FunctionCallee BarrierFn =
315929a3e3ddSGiorgis Georgakoudis           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
316029a3e3ddSGiorgis Georgakoudis               M, OMPRTL___kmpc_barrier_simple_spmd);
316129a3e3ddSGiorgis Georgakoudis       OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
316229a3e3ddSGiorgis Georgakoudis           RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
316329a3e3ddSGiorgis Georgakoudis       OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid})
316429a3e3ddSGiorgis Georgakoudis           ->setDebugLoc(DL);
316529a3e3ddSGiorgis Georgakoudis 
316629a3e3ddSGiorgis Georgakoudis       // Second barrier ensures workers have read broadcast values.
316729a3e3ddSGiorgis Georgakoudis       if (HasBroadcastValues)
316829a3e3ddSGiorgis Georgakoudis         CallInst::Create(BarrierFn, {Ident, Tid}, "",
316929a3e3ddSGiorgis Georgakoudis                          RegionBarrierBB->getTerminator())
317029a3e3ddSGiorgis Georgakoudis             ->setDebugLoc(DL);
317129a3e3ddSGiorgis Georgakoudis     };
317229a3e3ddSGiorgis Georgakoudis 
317329a3e3ddSGiorgis Georgakoudis     SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
317429a3e3ddSGiorgis Georgakoudis 
317529a3e3ddSGiorgis Georgakoudis     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
317629a3e3ddSGiorgis Georgakoudis       BasicBlock *BB = GuardedI->getParent();
317729a3e3ddSGiorgis Georgakoudis       auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
317829a3e3ddSGiorgis Georgakoudis           IRPosition::function(*GuardedI->getFunction()), nullptr,
317929a3e3ddSGiorgis Georgakoudis           DepClassTy::NONE);
318029a3e3ddSGiorgis Georgakoudis       assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
318129a3e3ddSGiorgis Georgakoudis       auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
318229a3e3ddSGiorgis Georgakoudis       // Continue if instruction is already guarded.
318329a3e3ddSGiorgis Georgakoudis       if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
318429a3e3ddSGiorgis Georgakoudis         continue;
318529a3e3ddSGiorgis Georgakoudis 
318629a3e3ddSGiorgis Georgakoudis       Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
318729a3e3ddSGiorgis Georgakoudis       for (Instruction &I : *BB) {
318829a3e3ddSGiorgis Georgakoudis         // If instruction I needs to be guarded update the guarded region
318929a3e3ddSGiorgis Georgakoudis         // bounds.
319029a3e3ddSGiorgis Georgakoudis         if (SPMDCompatibilityTracker.contains(&I)) {
319129a3e3ddSGiorgis Georgakoudis           CalleeAAFunction.getGuardedInstructions().insert(&I);
319229a3e3ddSGiorgis Georgakoudis           if (GuardedRegionStart)
319329a3e3ddSGiorgis Georgakoudis             GuardedRegionEnd = &I;
319429a3e3ddSGiorgis Georgakoudis           else
319529a3e3ddSGiorgis Georgakoudis             GuardedRegionStart = GuardedRegionEnd = &I;
319629a3e3ddSGiorgis Georgakoudis 
319729a3e3ddSGiorgis Georgakoudis           continue;
319829a3e3ddSGiorgis Georgakoudis         }
319929a3e3ddSGiorgis Georgakoudis 
320029a3e3ddSGiorgis Georgakoudis         // Instruction I does not need guarding, store
320129a3e3ddSGiorgis Georgakoudis         // any region found and reset bounds.
320229a3e3ddSGiorgis Georgakoudis         if (GuardedRegionStart) {
320329a3e3ddSGiorgis Georgakoudis           GuardedRegions.push_back(
320429a3e3ddSGiorgis Georgakoudis               std::make_pair(GuardedRegionStart, GuardedRegionEnd));
320529a3e3ddSGiorgis Georgakoudis           GuardedRegionStart = nullptr;
320629a3e3ddSGiorgis Georgakoudis           GuardedRegionEnd = nullptr;
320729a3e3ddSGiorgis Georgakoudis         }
320829a3e3ddSGiorgis Georgakoudis       }
320929a3e3ddSGiorgis Georgakoudis     }
321029a3e3ddSGiorgis Georgakoudis 
321129a3e3ddSGiorgis Georgakoudis     for (auto &GR : GuardedRegions)
321229a3e3ddSGiorgis Georgakoudis       CreateGuardedRegion(GR.first, GR.second);
321329a3e3ddSGiorgis Georgakoudis 
3214514c033dSJohannes Doerfert     // Adjust the global exec mode flag that tells the runtime what mode this
3215514c033dSJohannes Doerfert     // kernel is executed in.
3216514c033dSJohannes Doerfert     Function *Kernel = getAnchorScope();
3217514c033dSJohannes Doerfert     GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
3218514c033dSJohannes Doerfert         (Kernel->getName() + "_exec_mode").str());
3219514c033dSJohannes Doerfert     assert(ExecMode && "Kernel without exec mode?");
3220514c033dSJohannes Doerfert     assert(ExecMode->getInitializer() &&
3221514c033dSJohannes Doerfert            ExecMode->getInitializer()->isOneValue() &&
3222514c033dSJohannes Doerfert            "Initially non-SPMD kernel has SPMD exec mode!");
32237d576392SJoseph Huber 
32247d576392SJoseph Huber     // Set the global exec mode flag to indicate SPMD-Generic mode.
32257d576392SJoseph Huber     constexpr int SPMDGeneric = 2;
32267d576392SJoseph Huber     if (!ExecMode->getInitializer()->isZeroValue())
3227514c033dSJohannes Doerfert       ExecMode->setInitializer(
32287d576392SJoseph Huber           ConstantInt::get(ExecMode->getInitializer()->getType(), SPMDGeneric));
3229514c033dSJohannes Doerfert 
3230514c033dSJohannes Doerfert     // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
3231514c033dSJohannes Doerfert     const int InitIsSPMDArgNo = 1;
3232514c033dSJohannes Doerfert     const int DeinitIsSPMDArgNo = 1;
3233514c033dSJohannes Doerfert     const int InitUseStateMachineArgNo = 2;
3234e8439ec8SGiorgis Georgakoudis     const int InitRequiresFullRuntimeArgNo = 3;
3235e8439ec8SGiorgis Georgakoudis     const int DeinitRequiresFullRuntimeArgNo = 2;
3236514c033dSJohannes Doerfert 
3237514c033dSJohannes Doerfert     auto &Ctx = getAnchorValue().getContext();
3238514c033dSJohannes Doerfert     A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo),
3239514c033dSJohannes Doerfert                              *ConstantInt::getBool(Ctx, 1));
3240514c033dSJohannes Doerfert     A.changeUseAfterManifest(
3241514c033dSJohannes Doerfert         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
3242514c033dSJohannes Doerfert         *ConstantInt::getBool(Ctx, 0));
3243514c033dSJohannes Doerfert     A.changeUseAfterManifest(
3244514c033dSJohannes Doerfert         KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo),
3245514c033dSJohannes Doerfert         *ConstantInt::getBool(Ctx, 1));
3246e8439ec8SGiorgis Georgakoudis     A.changeUseAfterManifest(
3247e8439ec8SGiorgis Georgakoudis         KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
3248e8439ec8SGiorgis Georgakoudis         *ConstantInt::getBool(Ctx, 0));
3249e8439ec8SGiorgis Georgakoudis     A.changeUseAfterManifest(
3250e8439ec8SGiorgis Georgakoudis         KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
3251e8439ec8SGiorgis Georgakoudis         *ConstantInt::getBool(Ctx, 0));
3252e8439ec8SGiorgis Georgakoudis 
3253514c033dSJohannes Doerfert     ++NumOpenMPTargetRegionKernelsSPMD;
3254514c033dSJohannes Doerfert 
3255514c033dSJohannes Doerfert     auto Remark = [&](OptimizationRemark OR) {
3256eef6601bSJoseph Huber       return OR << "Transformed generic-mode kernel to SPMD-mode.";
3257514c033dSJohannes Doerfert     };
32582c31d5ebSJoseph Huber     A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
3259514c033dSJohannes Doerfert     return true;
3260514c033dSJohannes Doerfert   };
3261514c033dSJohannes Doerfert 
3262d9659bf6SJohannes Doerfert   ChangeStatus buildCustomStateMachine(Attributor &A) {
3263cd0dd8ecSJoseph Huber     // If we have disabled state machine rewrites, don't make a custom one
3264cd0dd8ecSJoseph Huber     if (DisableOpenMPOptStateMachineRewrite)
3265cd0dd8ecSJoseph Huber       return indicatePessimisticFixpoint();
3266cd0dd8ecSJoseph Huber 
3267d9659bf6SJohannes Doerfert     assert(ReachedKnownParallelRegions.isValidState() &&
3268d9659bf6SJohannes Doerfert            "Custom state machine with invalid parallel region states?");
3269d9659bf6SJohannes Doerfert 
3270d9659bf6SJohannes Doerfert     const int InitIsSPMDArgNo = 1;
3271d9659bf6SJohannes Doerfert     const int InitUseStateMachineArgNo = 2;
3272d9659bf6SJohannes Doerfert 
3273d9659bf6SJohannes Doerfert     // Check if the current configuration is non-SPMD and generic state machine.
3274d9659bf6SJohannes Doerfert     // If we already have SPMD mode or a custom state machine we do not need to
3275d9659bf6SJohannes Doerfert     // go any further. If it is anything but a constant something is weird and
3276d9659bf6SJohannes Doerfert     // we give up.
3277d9659bf6SJohannes Doerfert     ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
3278d9659bf6SJohannes Doerfert         KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
3279d9659bf6SJohannes Doerfert     ConstantInt *IsSPMD =
3280d9659bf6SJohannes Doerfert         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
3281d9659bf6SJohannes Doerfert 
3282d9659bf6SJohannes Doerfert     // If we are stuck with generic mode, try to create a custom device (=GPU)
3283d9659bf6SJohannes Doerfert     // state machine which is specialized for the parallel regions that are
3284d9659bf6SJohannes Doerfert     // reachable by the kernel.
3285d9659bf6SJohannes Doerfert     if (!UseStateMachine || UseStateMachine->isZero() || !IsSPMD ||
3286d9659bf6SJohannes Doerfert         !IsSPMD->isZero())
3287d9659bf6SJohannes Doerfert       return ChangeStatus::UNCHANGED;
3288d9659bf6SJohannes Doerfert 
3289514c033dSJohannes Doerfert     // If not SPMD mode, indicate we use a custom state machine now.
3290d9659bf6SJohannes Doerfert     auto &Ctx = getAnchorValue().getContext();
3291d9659bf6SJohannes Doerfert     auto *FalseVal = ConstantInt::getBool(Ctx, 0);
3292d9659bf6SJohannes Doerfert     A.changeUseAfterManifest(
3293d9659bf6SJohannes Doerfert         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
3294d9659bf6SJohannes Doerfert 
3295d9659bf6SJohannes Doerfert     // If we don't actually need a state machine we are done here. This can
3296d9659bf6SJohannes Doerfert     // happen if there simply are no parallel regions. In the resulting kernel
3297d9659bf6SJohannes Doerfert     // all worker threads will simply exit right away, leaving the main thread
3298d9659bf6SJohannes Doerfert     // to do the work alone.
3299d9659bf6SJohannes Doerfert     if (ReachedKnownParallelRegions.empty() &&
3300d9659bf6SJohannes Doerfert         ReachedUnknownParallelRegions.empty()) {
3301d9659bf6SJohannes Doerfert       ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
3302d9659bf6SJohannes Doerfert 
3303d9659bf6SJohannes Doerfert       auto Remark = [&](OptimizationRemark OR) {
3304eef6601bSJoseph Huber         return OR << "Removing unused state machine from generic-mode kernel.";
3305d9659bf6SJohannes Doerfert       };
33062c31d5ebSJoseph Huber       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
3307d9659bf6SJohannes Doerfert 
3308d9659bf6SJohannes Doerfert       return ChangeStatus::CHANGED;
3309d9659bf6SJohannes Doerfert     }
3310d9659bf6SJohannes Doerfert 
3311d9659bf6SJohannes Doerfert     // Keep track in the statistics of our new shiny custom state machine.
3312d9659bf6SJohannes Doerfert     if (ReachedUnknownParallelRegions.empty()) {
3313d9659bf6SJohannes Doerfert       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
3314d9659bf6SJohannes Doerfert 
3315d9659bf6SJohannes Doerfert       auto Remark = [&](OptimizationRemark OR) {
3316eef6601bSJoseph Huber         return OR << "Rewriting generic-mode kernel with a customized state "
3317eef6601bSJoseph Huber                      "machine.";
3318d9659bf6SJohannes Doerfert       };
33192c31d5ebSJoseph Huber       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
3320d9659bf6SJohannes Doerfert     } else {
3321d9659bf6SJohannes Doerfert       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
3322d9659bf6SJohannes Doerfert 
3323eef6601bSJoseph Huber       auto Remark = [&](OptimizationRemarkAnalysis OR) {
3324d9659bf6SJohannes Doerfert         return OR << "Generic-mode kernel is executed with a customized state "
3325eef6601bSJoseph Huber                      "machine that requires a fallback.";
3326d9659bf6SJohannes Doerfert       };
33272c31d5ebSJoseph Huber       A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
3328d9659bf6SJohannes Doerfert 
3329d9659bf6SJohannes Doerfert       // Tell the user why we ended up with a fallback.
3330d9659bf6SJohannes Doerfert       for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
3331d9659bf6SJohannes Doerfert         if (!UnknownParallelRegionCB)
3332d9659bf6SJohannes Doerfert           continue;
3333d9659bf6SJohannes Doerfert         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3334eef6601bSJoseph Huber           return ORA << "Call may contain unknown parallel regions. Use "
3335eef6601bSJoseph Huber                      << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
3336eef6601bSJoseph Huber                         "override.";
3337d9659bf6SJohannes Doerfert         };
33382c31d5ebSJoseph Huber         A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
33392c31d5ebSJoseph Huber                                                  "OMP133", Remark);
3340d9659bf6SJohannes Doerfert       }
3341d9659bf6SJohannes Doerfert     }
3342d9659bf6SJohannes Doerfert 
3343d9659bf6SJohannes Doerfert     // Create all the blocks:
3344d9659bf6SJohannes Doerfert     //
3345d9659bf6SJohannes Doerfert     //                       InitCB = __kmpc_target_init(...)
3346d9659bf6SJohannes Doerfert     //                       bool IsWorker = InitCB >= 0;
3347d9659bf6SJohannes Doerfert     //                       if (IsWorker) {
3348d9659bf6SJohannes Doerfert     // SMBeginBB:               __kmpc_barrier_simple_spmd(...);
3349d9659bf6SJohannes Doerfert     //                         void *WorkFn;
3350d9659bf6SJohannes Doerfert     //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
3351d9659bf6SJohannes Doerfert     //                         if (!WorkFn) return;
3352d9659bf6SJohannes Doerfert     // SMIsActiveCheckBB:       if (Active) {
3353d9659bf6SJohannes Doerfert     // SMIfCascadeCurrentBB:      if      (WorkFn == <ParFn0>)
3354d9659bf6SJohannes Doerfert     //                              ParFn0(...);
3355d9659bf6SJohannes Doerfert     // SMIfCascadeCurrentBB:      else if (WorkFn == <ParFn1>)
3356d9659bf6SJohannes Doerfert     //                              ParFn1(...);
3357d9659bf6SJohannes Doerfert     //                            ...
3358d9659bf6SJohannes Doerfert     // SMIfCascadeCurrentBB:      else
3359d9659bf6SJohannes Doerfert     //                              ((WorkFnTy*)WorkFn)(...);
3360d9659bf6SJohannes Doerfert     // SMEndParallelBB:           __kmpc_kernel_end_parallel(...);
3361d9659bf6SJohannes Doerfert     //                          }
3362d9659bf6SJohannes Doerfert     // SMDoneBB:                __kmpc_barrier_simple_spmd(...);
3363d9659bf6SJohannes Doerfert     //                          goto SMBeginBB;
3364d9659bf6SJohannes Doerfert     //                       }
3365d9659bf6SJohannes Doerfert     // UserCodeEntryBB:      // user code
3366d9659bf6SJohannes Doerfert     //                       __kmpc_target_deinit(...)
3367d9659bf6SJohannes Doerfert     //
3368d9659bf6SJohannes Doerfert     Function *Kernel = getAssociatedFunction();
3369d9659bf6SJohannes Doerfert     assert(Kernel && "Expected an associated function!");
3370d9659bf6SJohannes Doerfert 
3371d9659bf6SJohannes Doerfert     BasicBlock *InitBB = KernelInitCB->getParent();
3372d9659bf6SJohannes Doerfert     BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3373d9659bf6SJohannes Doerfert         KernelInitCB->getNextNode(), "thread.user_code.check");
3374d9659bf6SJohannes Doerfert     BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3375d9659bf6SJohannes Doerfert         Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3376d9659bf6SJohannes Doerfert     BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3377d9659bf6SJohannes Doerfert         Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3378d9659bf6SJohannes Doerfert     BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3379d9659bf6SJohannes Doerfert         Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3380d9659bf6SJohannes Doerfert     BasicBlock *StateMachineIfCascadeCurrentBB =
3381d9659bf6SJohannes Doerfert         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3382d9659bf6SJohannes Doerfert                            Kernel, UserCodeEntryBB);
3383d9659bf6SJohannes Doerfert     BasicBlock *StateMachineEndParallelBB =
3384d9659bf6SJohannes Doerfert         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3385d9659bf6SJohannes Doerfert                            Kernel, UserCodeEntryBB);
3386d9659bf6SJohannes Doerfert     BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3387d9659bf6SJohannes Doerfert         Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
33883f71b425SGiorgis Georgakoudis     A.registerManifestAddedBasicBlock(*InitBB);
33893f71b425SGiorgis Georgakoudis     A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
33903f71b425SGiorgis Georgakoudis     A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
33913f71b425SGiorgis Georgakoudis     A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
33923f71b425SGiorgis Georgakoudis     A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
33933f71b425SGiorgis Georgakoudis     A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
33943f71b425SGiorgis Georgakoudis     A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
33953f71b425SGiorgis Georgakoudis     A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
3396d9659bf6SJohannes Doerfert 
3397d9659bf6SJohannes Doerfert     const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3398d9659bf6SJohannes Doerfert     ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3399d9659bf6SJohannes Doerfert 
3400d9659bf6SJohannes Doerfert     InitBB->getTerminator()->eraseFromParent();
3401d9659bf6SJohannes Doerfert     Instruction *IsWorker =
3402d9659bf6SJohannes Doerfert         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3403d9659bf6SJohannes Doerfert                          ConstantInt::get(KernelInitCB->getType(), -1),
3404d9659bf6SJohannes Doerfert                          "thread.is_worker", InitBB);
3405d9659bf6SJohannes Doerfert     IsWorker->setDebugLoc(DLoc);
3406d9659bf6SJohannes Doerfert     BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
3407d9659bf6SJohannes Doerfert 
3408d9659bf6SJohannes Doerfert     // Create local storage for the work function pointer.
3409d9659bf6SJohannes Doerfert     Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3410d9659bf6SJohannes Doerfert     AllocaInst *WorkFnAI = new AllocaInst(VoidPtrTy, 0, "worker.work_fn.addr",
3411d9659bf6SJohannes Doerfert                                           &Kernel->getEntryBlock().front());
3412d9659bf6SJohannes Doerfert     WorkFnAI->setDebugLoc(DLoc);
3413d9659bf6SJohannes Doerfert 
3414d9659bf6SJohannes Doerfert     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3415d9659bf6SJohannes Doerfert     OMPInfoCache.OMPBuilder.updateToLocation(
3416d9659bf6SJohannes Doerfert         OpenMPIRBuilder::LocationDescription(
3417d9659bf6SJohannes Doerfert             IRBuilder<>::InsertPoint(StateMachineBeginBB,
3418d9659bf6SJohannes Doerfert                                      StateMachineBeginBB->end()),
3419d9659bf6SJohannes Doerfert             DLoc));
3420d9659bf6SJohannes Doerfert 
3421d9659bf6SJohannes Doerfert     Value *Ident = KernelInitCB->getArgOperand(0);
3422d9659bf6SJohannes Doerfert     Value *GTid = KernelInitCB;
3423d9659bf6SJohannes Doerfert 
3424d9659bf6SJohannes Doerfert     Module &M = *Kernel->getParent();
3425d9659bf6SJohannes Doerfert     FunctionCallee BarrierFn =
3426d9659bf6SJohannes Doerfert         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3427d9659bf6SJohannes Doerfert             M, OMPRTL___kmpc_barrier_simple_spmd);
3428d9659bf6SJohannes Doerfert     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB)
3429d9659bf6SJohannes Doerfert         ->setDebugLoc(DLoc);
3430d9659bf6SJohannes Doerfert 
3431d9659bf6SJohannes Doerfert     FunctionCallee KernelParallelFn =
3432d9659bf6SJohannes Doerfert         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3433d9659bf6SJohannes Doerfert             M, OMPRTL___kmpc_kernel_parallel);
3434d9659bf6SJohannes Doerfert     Instruction *IsActiveWorker = CallInst::Create(
3435d9659bf6SJohannes Doerfert         KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3436d9659bf6SJohannes Doerfert     IsActiveWorker->setDebugLoc(DLoc);
3437d9659bf6SJohannes Doerfert     Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3438d9659bf6SJohannes Doerfert                                        StateMachineBeginBB);
3439d9659bf6SJohannes Doerfert     WorkFn->setDebugLoc(DLoc);
3440d9659bf6SJohannes Doerfert 
3441d9659bf6SJohannes Doerfert     FunctionType *ParallelRegionFnTy = FunctionType::get(
3442d9659bf6SJohannes Doerfert         Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3443d9659bf6SJohannes Doerfert         false);
3444d9659bf6SJohannes Doerfert     Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3445d9659bf6SJohannes Doerfert         WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3446d9659bf6SJohannes Doerfert         StateMachineBeginBB);
3447d9659bf6SJohannes Doerfert 
3448d9659bf6SJohannes Doerfert     Instruction *IsDone =
3449d9659bf6SJohannes Doerfert         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3450d9659bf6SJohannes Doerfert                          Constant::getNullValue(VoidPtrTy), "worker.is_done",
3451d9659bf6SJohannes Doerfert                          StateMachineBeginBB);
3452d9659bf6SJohannes Doerfert     IsDone->setDebugLoc(DLoc);
3453d9659bf6SJohannes Doerfert     BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3454d9659bf6SJohannes Doerfert                        IsDone, StateMachineBeginBB)
3455d9659bf6SJohannes Doerfert         ->setDebugLoc(DLoc);
3456d9659bf6SJohannes Doerfert 
3457d9659bf6SJohannes Doerfert     BranchInst::Create(StateMachineIfCascadeCurrentBB,
3458d9659bf6SJohannes Doerfert                        StateMachineDoneBarrierBB, IsActiveWorker,
3459d9659bf6SJohannes Doerfert                        StateMachineIsActiveCheckBB)
3460d9659bf6SJohannes Doerfert         ->setDebugLoc(DLoc);
3461d9659bf6SJohannes Doerfert 
3462d9659bf6SJohannes Doerfert     Value *ZeroArg =
3463d9659bf6SJohannes Doerfert         Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
3464d9659bf6SJohannes Doerfert 
3465d9659bf6SJohannes Doerfert     // Now that we have most of the CFG skeleton it is time for the if-cascade
3466d9659bf6SJohannes Doerfert     // that checks the function pointer we got from the runtime against the
3467d9659bf6SJohannes Doerfert     // parallel regions we expect, if there are any.
3468d9659bf6SJohannes Doerfert     for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) {
3469d9659bf6SJohannes Doerfert       auto *ParallelRegion = ReachedKnownParallelRegions[i];
3470d9659bf6SJohannes Doerfert       BasicBlock *PRExecuteBB = BasicBlock::Create(
3471d9659bf6SJohannes Doerfert           Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3472d9659bf6SJohannes Doerfert           StateMachineEndParallelBB);
3473d9659bf6SJohannes Doerfert       CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3474d9659bf6SJohannes Doerfert           ->setDebugLoc(DLoc);
3475d9659bf6SJohannes Doerfert       BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3476d9659bf6SJohannes Doerfert           ->setDebugLoc(DLoc);
3477d9659bf6SJohannes Doerfert 
3478d9659bf6SJohannes Doerfert       BasicBlock *PRNextBB =
3479d9659bf6SJohannes Doerfert           BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3480d9659bf6SJohannes Doerfert                              Kernel, StateMachineEndParallelBB);
3481d9659bf6SJohannes Doerfert 
3482d9659bf6SJohannes Doerfert       // Check if we need to compare the pointer at all or if we can just
3483d9659bf6SJohannes Doerfert       // call the parallel region function.
3484d9659bf6SJohannes Doerfert       Value *IsPR;
3485d9659bf6SJohannes Doerfert       if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) {
3486d9659bf6SJohannes Doerfert         Instruction *CmpI = ICmpInst::Create(
3487d9659bf6SJohannes Doerfert             ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3488d9659bf6SJohannes Doerfert             "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
3489d9659bf6SJohannes Doerfert         CmpI->setDebugLoc(DLoc);
3490d9659bf6SJohannes Doerfert         IsPR = CmpI;
3491d9659bf6SJohannes Doerfert       } else {
3492d9659bf6SJohannes Doerfert         IsPR = ConstantInt::getTrue(Ctx);
3493d9659bf6SJohannes Doerfert       }
3494d9659bf6SJohannes Doerfert 
3495d9659bf6SJohannes Doerfert       BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
3496d9659bf6SJohannes Doerfert                          StateMachineIfCascadeCurrentBB)
3497d9659bf6SJohannes Doerfert           ->setDebugLoc(DLoc);
3498d9659bf6SJohannes Doerfert       StateMachineIfCascadeCurrentBB = PRNextBB;
3499d9659bf6SJohannes Doerfert     }
3500d9659bf6SJohannes Doerfert 
3501d9659bf6SJohannes Doerfert     // At the end of the if-cascade we place the indirect function pointer call
3502d9659bf6SJohannes Doerfert     // in case we might need it, that is if there can be parallel regions we
3503d9659bf6SJohannes Doerfert     // have not handled in the if-cascade above.
3504d9659bf6SJohannes Doerfert     if (!ReachedUnknownParallelRegions.empty()) {
3505d9659bf6SJohannes Doerfert       StateMachineIfCascadeCurrentBB->setName(
3506d9659bf6SJohannes Doerfert           "worker_state_machine.parallel_region.fallback.execute");
3507d9659bf6SJohannes Doerfert       CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
3508d9659bf6SJohannes Doerfert                        StateMachineIfCascadeCurrentBB)
3509d9659bf6SJohannes Doerfert           ->setDebugLoc(DLoc);
3510d9659bf6SJohannes Doerfert     }
3511d9659bf6SJohannes Doerfert     BranchInst::Create(StateMachineEndParallelBB,
3512d9659bf6SJohannes Doerfert                        StateMachineIfCascadeCurrentBB)
3513d9659bf6SJohannes Doerfert         ->setDebugLoc(DLoc);
3514d9659bf6SJohannes Doerfert 
3515d9659bf6SJohannes Doerfert     CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3516d9659bf6SJohannes Doerfert                          M, OMPRTL___kmpc_kernel_end_parallel),
3517d9659bf6SJohannes Doerfert                      {}, "", StateMachineEndParallelBB)
3518d9659bf6SJohannes Doerfert         ->setDebugLoc(DLoc);
3519d9659bf6SJohannes Doerfert     BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
3520d9659bf6SJohannes Doerfert         ->setDebugLoc(DLoc);
3521d9659bf6SJohannes Doerfert 
3522d9659bf6SJohannes Doerfert     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
3523d9659bf6SJohannes Doerfert         ->setDebugLoc(DLoc);
3524d9659bf6SJohannes Doerfert     BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
3525d9659bf6SJohannes Doerfert         ->setDebugLoc(DLoc);
3526d9659bf6SJohannes Doerfert 
3527d9659bf6SJohannes Doerfert     return ChangeStatus::CHANGED;
3528d9659bf6SJohannes Doerfert   }
3529d9659bf6SJohannes Doerfert 
3530d9659bf6SJohannes Doerfert   /// Fixpoint iteration update function. Will be called every time a dependence
3531d9659bf6SJohannes Doerfert   /// changed its state (and in the beginning).
3532d9659bf6SJohannes Doerfert   ChangeStatus updateImpl(Attributor &A) override {
    // One fixpoint-iteration step for the kernel info of this function:
    //   1. Scan read/write instructions and record every non-call writer that
    //      is not provably harmless in SPMDCompatibilityTracker.
    //   2. If this is not a kernel entry, refresh the reaching kernel entries
    //      and the parallel levels from the call sites of this function.
    //   3. Merge the AAKernelInfo state of every call-like instruction.
    // Returns CHANGED iff the state differs from the snapshot taken on entry.
3533d9659bf6SJohannes Doerfert     KernelInfoState StateBefore = getState();
3534d9659bf6SJohannes Doerfert 
3535514c033dSJohannes Doerfert     // Callback to check a read/write instruction.
    // Every path returns true, so the callback itself never aborts the
    // traversal; it only records instructions in SPMDCompatibilityTracker.
3536514c033dSJohannes Doerfert     auto CheckRWInst = [&](Instruction &I) {
3537514c033dSJohannes Doerfert       // We handle calls later.
3538514c033dSJohannes Doerfert       if (isa<CallBase>(I))
3539514c033dSJohannes Doerfert         return true;
3540514c033dSJohannes Doerfert       // We only care about write effects.
3541514c033dSJohannes Doerfert       if (!I.mayWriteToMemory())
3542514c033dSJohannes Doerfert         return true;
3543514c033dSJohannes Doerfert       if (auto *SI = dyn_cast<StoreInst>(&I)) {
3544514c033dSJohannes Doerfert         SmallVector<const Value *> Objects;
3545514c033dSJohannes Doerfert         getUnderlyingObjects(SI->getPointerOperand(), Objects);
        // A store whose underlying objects are all allocas is not recorded.
3546514c033dSJohannes Doerfert         if (llvm::all_of(Objects,
3547514c033dSJohannes Doerfert                          [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
3548514c033dSJohannes Doerfert           return true;
354929a3e3ddSGiorgis Georgakoudis         // Check for AAHeapToStack moved objects which must not be guarded.
        // The AAHeapToStack query is registered as a REQUIRED dependence of
        // this attribute. The store is skipped only if every underlying
        // object is a call that is assumed to be converted heap-to-stack.
355029a3e3ddSGiorgis Georgakoudis         auto &HS = A.getAAFor<AAHeapToStack>(
355129a3e3ddSGiorgis Georgakoudis             *this, IRPosition::function(*I.getFunction()),
355229a3e3ddSGiorgis Georgakoudis             DepClassTy::REQUIRED);
355329a3e3ddSGiorgis Georgakoudis         if (llvm::all_of(Objects, [&HS](const Value *Obj) {
355429a3e3ddSGiorgis Georgakoudis               auto *CB = dyn_cast<CallBase>(Obj);
355529a3e3ddSGiorgis Georgakoudis               if (!CB)
355629a3e3ddSGiorgis Georgakoudis                 return false;
355729a3e3ddSGiorgis Georgakoudis               return HS.isAssumedHeapToStack(*CB);
355829a3e3ddSGiorgis Georgakoudis             })) {
355929a3e3ddSGiorgis Georgakoudis           return true;
3560514c033dSJohannes Doerfert         }
356129a3e3ddSGiorgis Georgakoudis       }
356229a3e3ddSGiorgis Georgakoudis 
356329a3e3ddSGiorgis Georgakoudis       // Insert instruction that needs guarding.
3564514c033dSJohannes Doerfert       SPMDCompatibilityTracker.insert(&I);
3565514c033dSJohannes Doerfert       return true;
3566514c033dSJohannes Doerfert     };
3567792aac98SJohannes Doerfert 
    // The read/write scan is skipped once the SPMD tracker is at a fixpoint.
    // If the traversal itself fails, SPMD compatibility is given up.
3568792aac98SJohannes Doerfert     bool UsedAssumedInformationInCheckRWInst = false;
356997387fdfSJohannes Doerfert     if (!SPMDCompatibilityTracker.isAtFixpoint())
3570792aac98SJohannes Doerfert       if (!A.checkForAllReadWriteInstructions(
3571792aac98SJohannes Doerfert               CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
3572514c033dSJohannes Doerfert         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3573514c033dSJohannes Doerfert 
    // Only non-entry functions derive information from their callers. An
    // invalid ParallelLevels state also forces the SPMD tracker to a
    // pessimistic fixpoint.
3574e97e0a4fSShilei Tian     if (!IsKernelEntry) {
3575ca662297SShilei Tian       updateReachingKernelEntries(A);
3576e97e0a4fSShilei Tian       updateParallelLevels(A);
357729a3e3ddSGiorgis Georgakoudis 
357829a3e3ddSGiorgis Georgakoudis       if (!ParallelLevels.isValidState())
357929a3e3ddSGiorgis Georgakoudis         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3580e97e0a4fSShilei Tian     }
3581ca662297SShilei Tian 
3582d9659bf6SJohannes Doerfert     // Callback to check a call instruction.
    // The call site AAKernelInfo is queried with an OPTIONAL dependence and
    // combined into this state through operator^= (defined outside this
    // view). AllSPMDStatesWereFixed stays true only if the SPMD tracker of
    // every visited call site was already at a fixpoint.
358397387fdfSJohannes Doerfert     bool AllSPMDStatesWereFixed = true;
3584d9659bf6SJohannes Doerfert     auto CheckCallInst = [&](Instruction &I) {
3585d9659bf6SJohannes Doerfert       auto &CB = cast<CallBase>(I);
3586d9659bf6SJohannes Doerfert       auto &CBAA = A.getAAFor<AAKernelInfo>(
3587d9659bf6SJohannes Doerfert           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
3588d9659bf6SJohannes Doerfert       getState() ^= CBAA.getState();
358997387fdfSJohannes Doerfert       AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
3590d9659bf6SJohannes Doerfert       return true;
3591d9659bf6SJohannes Doerfert     };
3592d9659bf6SJohannes Doerfert 
    // If the call-like instructions cannot all be visited, the whole state
    // (not just the SPMD tracker) is fixed pessimistically.
3593792aac98SJohannes Doerfert     bool UsedAssumedInformationInCheckCallInst = false;
3594792aac98SJohannes Doerfert     if (!A.checkForAllCallLikeInstructions(
3595792aac98SJohannes Doerfert             CheckCallInst, *this, UsedAssumedInformationInCheckCallInst))
3596d9659bf6SJohannes Doerfert       return indicatePessimisticFixpoint();
3597d9659bf6SJohannes Doerfert 
359897387fdfSJohannes Doerfert     // If we haven't used any assumed information for the SPMD state we can fix
359997387fdfSJohannes Doerfert     // it.
360097387fdfSJohannes Doerfert     if (!UsedAssumedInformationInCheckRWInst &&
360197387fdfSJohannes Doerfert         !UsedAssumedInformationInCheckCallInst && AllSPMDStatesWereFixed)
360297387fdfSJohannes Doerfert       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
360397387fdfSJohannes Doerfert 
3604d9659bf6SJohannes Doerfert     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3605d9659bf6SJohannes Doerfert                                      : ChangeStatus::CHANGED;
3606d9659bf6SJohannes Doerfert   }
3607ca662297SShilei Tian 
3608ca662297SShilei Tian private:
3609ca662297SShilei Tian   /// Update info regarding reaching kernels.
3610ca662297SShilei Tian   void updateReachingKernelEntries(Attributor &A) {
    // Combines the ReachingKernelEntries of every caller of the associated
    // function into this attribute. The result is only reflected in the
    // ReachingKernelEntries member; nothing is returned.
    //
    // Per call site: look up (or create) the AAKernelInfo of the calling
    // function with a REQUIRED dependence. A valid caller state is merged via
    // operator^=; an invalid one forces a pessimistic fixpoint. The callback
    // returns true in both cases, so it never stops the traversal itself.
3611ca662297SShilei Tian     auto PredCallSite = [&](AbstractCallSite ACS) {
3612ca662297SShilei Tian       Function *Caller = ACS.getInstruction()->getFunction();
3613ca662297SShilei Tian 
3614ca662297SShilei Tian       assert(Caller && "Caller is nullptr");
3615ca662297SShilei Tian 
3616d3454ee8SShilei Tian       auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
3617d3454ee8SShilei Tian           IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
3618ca662297SShilei Tian       if (CAA.ReachingKernelEntries.isValidState()) {
3619ca662297SShilei Tian         ReachingKernelEntries ^= CAA.ReachingKernelEntries;
3620ca662297SShilei Tian         return true;
3621ca662297SShilei Tian       }
3622ca662297SShilei Tian 
3623ca662297SShilei Tian       // We lost track of the caller of the associated function, any kernel
3624ca662297SShilei Tian       // could reach now.
3625ca662297SShilei Tian       ReachingKernelEntries.indicatePessimisticFixpoint();
3626ca662297SShilei Tian 
3627ca662297SShilei Tian       return true;
3628ca662297SShilei Tian     };
3629ca662297SShilei Tian 
    // AllCallSitesKnown is handed to checkForAllCallSites and never read
    // afterwards in this function. If the call-site traversal fails, the
    // reaching kernel entries are fixed pessimistically.
3630ca662297SShilei Tian     bool AllCallSitesKnown;
3631ca662297SShilei Tian     if (!A.checkForAllCallSites(PredCallSite, *this,
3632ca662297SShilei Tian                                 true /* RequireAllCallSites */,
3633ca662297SShilei Tian                                 AllCallSitesKnown))
3634ca662297SShilei Tian       ReachingKernelEntries.indicatePessimisticFixpoint();
3635ca662297SShilei Tian   }
3636e97e0a4fSShilei Tian 
3637e97e0a4fSShilei Tian   /// Update info regarding parallel levels.
3638e97e0a4fSShilei Tian   void updateParallelLevels(Attributor &A) {
    // Combines the ParallelLevels of every caller of the associated function
    // into this attribute. The result is only reflected in the ParallelLevels
    // member; nothing is returned.
3639e97e0a4fSShilei Tian     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3640e97e0a4fSShilei Tian     OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
3641e97e0a4fSShilei Tian         OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
3642e97e0a4fSShilei Tian 
    // Per call site callback; it returns true on every path, so it never
    // stops the traversal itself.
3643e97e0a4fSShilei Tian     auto PredCallSite = [&](AbstractCallSite ACS) {
3644e97e0a4fSShilei Tian       Function *Caller = ACS.getInstruction()->getFunction();
3645e97e0a4fSShilei Tian 
3646e97e0a4fSShilei Tian       assert(Caller && "Caller is nullptr");
3647e97e0a4fSShilei Tian 
      // NOTE(review): unlike updateReachingKernelEntries, no querying AA and
      // no DepClassTy are passed to getOrCreateAAFor here - confirm that
      // omitting the dependence is intended.
3648e97e0a4fSShilei Tian       auto &CAA =
3649e97e0a4fSShilei Tian           A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
3650e97e0a4fSShilei Tian       if (CAA.ParallelLevels.isValidState()) {
3651e97e0a4fSShilei Tian         // Any function that is called by `__kmpc_parallel_51` will not be
3652e97e0a4fSShilei Tian         // folded as the parallel level in the function is updated. To get it
3653e97e0a4fSShilei Tian         // right, the analysis would have to depend on the implementation, so
3654e97e0a4fSShilei Tian         // any future change to the implementation could make the analysis
3655e97e0a4fSShilei Tian         // wrong. As a consequence, we are just conservative here.
3656e97e0a4fSShilei Tian         if (Caller == Parallel51RFI.Declaration) {
3657e97e0a4fSShilei Tian           ParallelLevels.indicatePessimisticFixpoint();
3658e97e0a4fSShilei Tian           return true;
3659e97e0a4fSShilei Tian         }
3660e97e0a4fSShilei Tian 
        // Otherwise merge the valid caller state via operator^=.
3661e97e0a4fSShilei Tian         ParallelLevels ^= CAA.ParallelLevels;
3662e97e0a4fSShilei Tian 
3663e97e0a4fSShilei Tian         return true;
3664e97e0a4fSShilei Tian       }
3665e97e0a4fSShilei Tian 
3666e97e0a4fSShilei Tian       // We lost track of the caller of the associated function, any kernel
3667e97e0a4fSShilei Tian       // could reach now.
3668e97e0a4fSShilei Tian       ParallelLevels.indicatePessimisticFixpoint();
3669e97e0a4fSShilei Tian 
3670e97e0a4fSShilei Tian       return true;
3671e97e0a4fSShilei Tian     };
3672e97e0a4fSShilei Tian 
    // AllCallSitesKnown is handed to checkForAllCallSites and never read
    // afterwards in this function. If the call-site traversal fails, the
    // parallel levels are fixed pessimistically.
3673e97e0a4fSShilei Tian     bool AllCallSitesKnown = true;
3674e97e0a4fSShilei Tian     if (!A.checkForAllCallSites(PredCallSite, *this,
3675e97e0a4fSShilei Tian                                 true /* RequireAllCallSites */,
3676e97e0a4fSShilei Tian                                 AllCallSitesKnown))
3677e97e0a4fSShilei Tian       ParallelLevels.indicatePessimisticFixpoint();
3678e97e0a4fSShilei Tian   }
3679d9659bf6SJohannes Doerfert };
3680d9659bf6SJohannes Doerfert 
3681d9659bf6SJohannes Doerfert /// The call site kernel info abstract attribute, basically, what can we say
3682d9659bf6SJohannes Doerfert /// about a call site with regards to the KernelInfoState. For now this simply
3683d9659bf6SJohannes Doerfert /// forwards the information from the callee.
3684d9659bf6SJohannes Doerfert struct AAKernelInfoCallSite : AAKernelInfo {
3685d9659bf6SJohannes Doerfert   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
3686d9659bf6SJohannes Doerfert       : AAKernelInfo(IRP, A) {}
3687d9659bf6SJohannes Doerfert 
3688d9659bf6SJohannes Doerfert   /// See AbstractAttribute::initialize(...).
3689d9659bf6SJohannes Doerfert   void initialize(Attributor &A) override {
    // Seeds the kernel info state for a single call site. Depending on the
    // callee this either reaches a fixpoint right here (calls that do not
    // write memory, intrinsics, callees that cannot be analyzed, known OpenMP
    // runtime functions) or returns without a fixpoint so that the state is
    // resolved later in updateImpl (analyzable callees and the
    // __kmpc_alloc_shared / __kmpc_free_shared runtime calls).
3690d9659bf6SJohannes Doerfert     AAKernelInfo::initialize(A);
3691d9659bf6SJohannes Doerfert 
3692d9659bf6SJohannes Doerfert     CallBase &CB = cast<CallBase>(getAssociatedValue());
3693d9659bf6SJohannes Doerfert     Function *Callee = getAssociatedFunction();
3694d9659bf6SJohannes Doerfert 
3695d9659bf6SJohannes Doerfert     // Helper to lookup an assumption string.
    // A null function never has the assumption.
3696d9659bf6SJohannes Doerfert     auto HasAssumption = [](Function *Fn, StringRef AssumptionStr) {
3697d9659bf6SJohannes Doerfert       return Fn && hasAssumption(*Fn, AssumptionStr);
3698d9659bf6SJohannes Doerfert     };
3699d9659bf6SJohannes Doerfert 
3700514c033dSJohannes Doerfert     // Check for SPMD-mode assumptions.
    // This runs before everything else, so the isAtFixpoint() check in the
    // unknown-callee path below sees the tracker as already fixed.
3701514c033dSJohannes Doerfert     if (HasAssumption(Callee, "ompx_spmd_amenable"))
3702514c033dSJohannes Doerfert       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3703514c033dSJohannes Doerfert 
3704d9659bf6SJohannes Doerfert     // First weed out calls we do not care about, that is readonly/readnone
3705d9659bf6SJohannes Doerfert     // calls, intrinsics, and "no_openmp" calls. None of these can reach a
3706d9659bf6SJohannes Doerfert     // parallel region or anything else we are looking for.
    // NOTE(review): the condition below only tests mayWriteToMemory() and
    // IntrinsicInst; no "no_openmp" check is visible in it - confirm.
3707d9659bf6SJohannes Doerfert     if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
3708d9659bf6SJohannes Doerfert       indicateOptimisticFixpoint();
3709d9659bf6SJohannes Doerfert       return;
3710d9659bf6SJohannes Doerfert     }
3711d9659bf6SJohannes Doerfert 
3712d9659bf6SJohannes Doerfert     // Next we check if we know the callee. If it is a known OpenMP function
3713d9659bf6SJohannes Doerfert     // we will handle them explicitly in the switch below. If it is not, we
3714d9659bf6SJohannes Doerfert     // will use an AAKernelInfo object on the callee to gather information and
3715d9659bf6SJohannes Doerfert     // merge that into the current state. The latter happens in the updateImpl.
3716d9659bf6SJohannes Doerfert     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3717d9659bf6SJohannes Doerfert     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3718d9659bf6SJohannes Doerfert     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3719d9659bf6SJohannes Doerfert       // Unknown callees or declarations are not analyzable, we give up.
3720d9659bf6SJohannes Doerfert       if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
3721d9659bf6SJohannes Doerfert 
3722d9659bf6SJohannes Doerfert         // Unknown callees might contain parallel regions, except if they have
3723d9659bf6SJohannes Doerfert         // an appropriate assumption attached.
3724d9659bf6SJohannes Doerfert         if (!(HasAssumption(Callee, "omp_no_openmp") ||
3725d9659bf6SJohannes Doerfert               HasAssumption(Callee, "omp_no_parallelism")))
3726d9659bf6SJohannes Doerfert           ReachedUnknownParallelRegions.insert(&CB);
3727d9659bf6SJohannes Doerfert 
3728514c033dSJohannes Doerfert         // If SPMDCompatibilityTracker is not fixed, we need to give up on the
3729514c033dSJohannes Doerfert         // idea we can run something unknown in SPMD-mode.
373029a3e3ddSGiorgis Georgakoudis         if (!SPMDCompatibilityTracker.isAtFixpoint()) {
373129a3e3ddSGiorgis Georgakoudis           SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3732514c033dSJohannes Doerfert           SPMDCompatibilityTracker.insert(&CB);
373329a3e3ddSGiorgis Georgakoudis         }
3734514c033dSJohannes Doerfert 
3735d9659bf6SJohannes Doerfert         // We have updated the state for this unknown call properly, there won't
3736d9659bf6SJohannes Doerfert         // be any change so we indicate a fixpoint.
3737d9659bf6SJohannes Doerfert         indicateOptimisticFixpoint();
3738d9659bf6SJohannes Doerfert       }
3739d9659bf6SJohannes Doerfert       // If the callee is known and can be used in IPO, we will update the state
3740d9659bf6SJohannes Doerfert       // based on the callee state in updateImpl.
3741d9659bf6SJohannes Doerfert       return;
3742d9659bf6SJohannes Doerfert     }
3743d9659bf6SJohannes Doerfert 
    // From here on the callee is a known OpenMP runtime function.
    // WrapperFunctionArgNo is the operand index of a __kmpc_parallel_51 call
    // that is inspected below for the parallel region function.
3744d9659bf6SJohannes Doerfert     const unsigned int WrapperFunctionArgNo = 6;
3745d9659bf6SJohannes Doerfert     RuntimeFunction RF = It->getSecond();
3746d9659bf6SJohannes Doerfert     switch (RF) {
3747514c033dSJohannes Doerfert     // All the functions we know are compatible with SPMD mode.
3748514c033dSJohannes Doerfert     case OMPRTL___kmpc_is_spmd_exec_mode:
3749514c033dSJohannes Doerfert     case OMPRTL___kmpc_for_static_fini:
3750514c033dSJohannes Doerfert     case OMPRTL___kmpc_global_thread_num:
37515ab6aeddSJose M Monsalve Diaz     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
37525ab6aeddSJose M Monsalve Diaz     case OMPRTL___kmpc_get_hardware_num_blocks:
3753514c033dSJohannes Doerfert     case OMPRTL___kmpc_single:
3754514c033dSJohannes Doerfert     case OMPRTL___kmpc_end_single:
3755514c033dSJohannes Doerfert     case OMPRTL___kmpc_master:
3756514c033dSJohannes Doerfert     case OMPRTL___kmpc_end_master:
3757514c033dSJohannes Doerfert     case OMPRTL___kmpc_barrier:
3758514c033dSJohannes Doerfert       break;
3759514c033dSJohannes Doerfert     case OMPRTL___kmpc_for_static_init_4:
3760514c033dSJohannes Doerfert     case OMPRTL___kmpc_for_static_init_4u:
3761514c033dSJohannes Doerfert     case OMPRTL___kmpc_for_static_init_8:
3762514c033dSJohannes Doerfert     case OMPRTL___kmpc_for_static_init_8u: {
3763514c033dSJohannes Doerfert       // Check the schedule and allow static schedule in SPMD mode.
      // A schedule operand that is not a ConstantInt is treated as value 0.
      // NOTE(review): presumably 0 ends up in the default case below -
      // confirm against the OMPScheduleType enumerators.
3764514c033dSJohannes Doerfert       unsigned ScheduleArgOpNo = 2;
3765514c033dSJohannes Doerfert       auto *ScheduleTypeCI =
3766514c033dSJohannes Doerfert           dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
3767514c033dSJohannes Doerfert       unsigned ScheduleTypeVal =
3768514c033dSJohannes Doerfert           ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
3769514c033dSJohannes Doerfert       switch (OMPScheduleType(ScheduleTypeVal)) {
3770514c033dSJohannes Doerfert       case OMPScheduleType::Static:
3771514c033dSJohannes Doerfert       case OMPScheduleType::StaticChunked:
3772514c033dSJohannes Doerfert       case OMPScheduleType::Distribute:
3773514c033dSJohannes Doerfert       case OMPScheduleType::DistributeChunked:
3774514c033dSJohannes Doerfert         break;
3775514c033dSJohannes Doerfert       default:
377629a3e3ddSGiorgis Georgakoudis         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3777514c033dSJohannes Doerfert         SPMDCompatibilityTracker.insert(&CB);
3778514c033dSJohannes Doerfert         break;
3779514c033dSJohannes Doerfert       };
3780514c033dSJohannes Doerfert     } break;
    // Remember the kernel init and deinit calls in the state.
3781d9659bf6SJohannes Doerfert     case OMPRTL___kmpc_target_init:
3782d9659bf6SJohannes Doerfert       KernelInitCB = &CB;
3783d9659bf6SJohannes Doerfert       break;
3784d9659bf6SJohannes Doerfert     case OMPRTL___kmpc_target_deinit:
3785d9659bf6SJohannes Doerfert       KernelDeinitCB = &CB;
3786d9659bf6SJohannes Doerfert       break;
3787d9659bf6SJohannes Doerfert     case OMPRTL___kmpc_parallel_51:
      // Record the parallel region if the operand, after stripping pointer
      // casts, is a Function.
3788d9659bf6SJohannes Doerfert       if (auto *ParallelRegion = dyn_cast<Function>(
3789d9659bf6SJohannes Doerfert               CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
3790d9659bf6SJohannes Doerfert         ReachedKnownParallelRegions.insert(ParallelRegion);
3791d9659bf6SJohannes Doerfert         break;
3792d9659bf6SJohannes Doerfert       }
3793d9659bf6SJohannes Doerfert       // The condition above should usually get the parallel region function
3794d9659bf6SJohannes Doerfert       // pointer and record it. In the off chance it doesn't we assume the
3795d9659bf6SJohannes Doerfert       // worst.
3796d9659bf6SJohannes Doerfert       ReachedUnknownParallelRegions.insert(&CB);
3797d9659bf6SJohannes Doerfert       break;
3798d9659bf6SJohannes Doerfert     case OMPRTL___kmpc_omp_task:
3799d9659bf6SJohannes Doerfert       // We do not look into tasks right now, just give up.
      // The early return skips the optimistic fixpoint after the switch.
3800514c033dSJohannes Doerfert       SPMDCompatibilityTracker.insert(&CB);
3801d9659bf6SJohannes Doerfert       ReachedUnknownParallelRegions.insert(&CB);
380229a3e3ddSGiorgis Georgakoudis       indicatePessimisticFixpoint();
380329a3e3ddSGiorgis Georgakoudis       return;
3804f8c40ed8SGiorgis Georgakoudis     case OMPRTL___kmpc_alloc_shared:
3805f8c40ed8SGiorgis Georgakoudis     case OMPRTL___kmpc_free_shared:
3806f8c40ed8SGiorgis Georgakoudis       // Return without setting a fixpoint, to be resolved in updateImpl.
3807f8c40ed8SGiorgis Georgakoudis       return;
3808d9659bf6SJohannes Doerfert     default:
3809514c033dSJohannes Doerfert       // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
3810514c033dSJohannes Doerfert       // generally.
      // The early return skips the optimistic fixpoint after the switch.
3811514c033dSJohannes Doerfert       SPMDCompatibilityTracker.insert(&CB);
381229a3e3ddSGiorgis Georgakoudis       indicatePessimisticFixpoint();
381329a3e3ddSGiorgis Georgakoudis       return;
3814d9659bf6SJohannes Doerfert     }
3815d9659bf6SJohannes Doerfert     // All other OpenMP runtime calls will not reach parallel regions so they
3816d9659bf6SJohannes Doerfert     // can be safely ignored for now. Since it is a known OpenMP runtime call we
3817d9659bf6SJohannes Doerfert     // have now modeled all effects and there is no need for any update.
3818d9659bf6SJohannes Doerfert     indicateOptimisticFixpoint();
3819d9659bf6SJohannes Doerfert   }
3820d9659bf6SJohannes Doerfert 
3821d9659bf6SJohannes Doerfert   ChangeStatus updateImpl(Attributor &A) override {
3822d9659bf6SJohannes Doerfert     // TODO: Once we have call site specific value information we can provide
3823d9659bf6SJohannes Doerfert     //       call site specific liveness information and then it makes
3824d9659bf6SJohannes Doerfert     //       sense to specialize attributes for call sites arguments instead of
3825d9659bf6SJohannes Doerfert     //       redirecting requests to the callee argument.
3826d9659bf6SJohannes Doerfert     Function *F = getAssociatedFunction();
3827f8c40ed8SGiorgis Georgakoudis 
3828f8c40ed8SGiorgis Georgakoudis     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3829f8c40ed8SGiorgis Georgakoudis     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
3830f8c40ed8SGiorgis Georgakoudis 
3831f8c40ed8SGiorgis Georgakoudis     // If F is not a runtime function, propagate the AAKernelInfo of the callee.
3832f8c40ed8SGiorgis Georgakoudis     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3833d9659bf6SJohannes Doerfert       const IRPosition &FnPos = IRPosition::function(*F);
3834d9659bf6SJohannes Doerfert       auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
3835d9659bf6SJohannes Doerfert       if (getState() == FnAA.getState())
3836d9659bf6SJohannes Doerfert         return ChangeStatus::UNCHANGED;
3837d9659bf6SJohannes Doerfert       getState() = FnAA.getState();
3838d9659bf6SJohannes Doerfert       return ChangeStatus::CHANGED;
3839d9659bf6SJohannes Doerfert     }
3840f8c40ed8SGiorgis Georgakoudis 
3841f8c40ed8SGiorgis Georgakoudis     // F is a runtime function that allocates or frees memory, check
3842f8c40ed8SGiorgis Georgakoudis     // AAHeapToStack and AAHeapToShared.
3843f8c40ed8SGiorgis Georgakoudis     KernelInfoState StateBefore = getState();
3844f8c40ed8SGiorgis Georgakoudis     assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
3845f8c40ed8SGiorgis Georgakoudis             It->getSecond() == OMPRTL___kmpc_free_shared) &&
3846f8c40ed8SGiorgis Georgakoudis            "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
3847f8c40ed8SGiorgis Georgakoudis 
3848f8c40ed8SGiorgis Georgakoudis     CallBase &CB = cast<CallBase>(getAssociatedValue());
3849f8c40ed8SGiorgis Georgakoudis 
3850f8c40ed8SGiorgis Georgakoudis     auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
3851f8c40ed8SGiorgis Georgakoudis         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3852f8c40ed8SGiorgis Georgakoudis     auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
3853f8c40ed8SGiorgis Georgakoudis         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3854f8c40ed8SGiorgis Georgakoudis 
3855f8c40ed8SGiorgis Georgakoudis     RuntimeFunction RF = It->getSecond();
3856f8c40ed8SGiorgis Georgakoudis 
3857f8c40ed8SGiorgis Georgakoudis     switch (RF) {
3858f8c40ed8SGiorgis Georgakoudis     // If neither HeapToStack nor HeapToShared assume the call is removed,
3859f8c40ed8SGiorgis Georgakoudis     // assume SPMD incompatibility.
3860f8c40ed8SGiorgis Georgakoudis     case OMPRTL___kmpc_alloc_shared:
3861f8c40ed8SGiorgis Georgakoudis       if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
3862f8c40ed8SGiorgis Georgakoudis           !HeapToSharedAA.isAssumedHeapToShared(CB))
3863f8c40ed8SGiorgis Georgakoudis         SPMDCompatibilityTracker.insert(&CB);
3864f8c40ed8SGiorgis Georgakoudis       break;
3865f8c40ed8SGiorgis Georgakoudis     case OMPRTL___kmpc_free_shared:
3866f8c40ed8SGiorgis Georgakoudis       if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
3867f8c40ed8SGiorgis Georgakoudis           !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
3868f8c40ed8SGiorgis Georgakoudis         SPMDCompatibilityTracker.insert(&CB);
3869f8c40ed8SGiorgis Georgakoudis       break;
3870f8c40ed8SGiorgis Georgakoudis     default:
3871f8c40ed8SGiorgis Georgakoudis       SPMDCompatibilityTracker.insert(&CB);
3872f8c40ed8SGiorgis Georgakoudis     }
3873f8c40ed8SGiorgis Georgakoudis 
3874f8c40ed8SGiorgis Georgakoudis     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3875f8c40ed8SGiorgis Georgakoudis                                      : ChangeStatus::CHANGED;
3876f8c40ed8SGiorgis Georgakoudis   }
3877d9659bf6SJohannes Doerfert };
3878d9659bf6SJohannes Doerfert 
3879ca662297SShilei Tian struct AAFoldRuntimeCall
3880ca662297SShilei Tian     : public StateWrapper<BooleanState, AbstractAttribute> {
3881ca662297SShilei Tian   using Base = StateWrapper<BooleanState, AbstractAttribute>;
3882ca662297SShilei Tian 
3883ca662297SShilei Tian   AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3884ca662297SShilei Tian 
3885ca662297SShilei Tian   /// Statistics are tracked as part of manifest for now.
3886ca662297SShilei Tian   void trackStatistics() const override {}
3887ca662297SShilei Tian 
3888ca662297SShilei Tian   /// Create an abstract attribute biew for the position \p IRP.
3889ca662297SShilei Tian   static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
3890ca662297SShilei Tian                                               Attributor &A);
3891ca662297SShilei Tian 
3892ca662297SShilei Tian   /// See AbstractAttribute::getName()
3893ca662297SShilei Tian   const std::string getName() const override { return "AAFoldRuntimeCall"; }
3894ca662297SShilei Tian 
3895ca662297SShilei Tian   /// See AbstractAttribute::getIdAddr()
3896ca662297SShilei Tian   const char *getIdAddr() const override { return &ID; }
3897ca662297SShilei Tian 
3898ca662297SShilei Tian   /// This function should return true if the type of the \p AA is
3899ca662297SShilei Tian   /// AAFoldRuntimeCall
3900ca662297SShilei Tian   static bool classof(const AbstractAttribute *AA) {
3901ca662297SShilei Tian     return (AA->getIdAddr() == &ID);
3902ca662297SShilei Tian   }
3903ca662297SShilei Tian 
3904ca662297SShilei Tian   static const char ID;
3905ca662297SShilei Tian };
3906ca662297SShilei Tian 
3907ca662297SShilei Tian struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
3908ca662297SShilei Tian   AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
3909ca662297SShilei Tian       : AAFoldRuntimeCall(IRP, A) {}
3910ca662297SShilei Tian 
3911ca662297SShilei Tian   /// See AbstractAttribute::getAsStr()
3912ca662297SShilei Tian   const std::string getAsStr() const override {
3913ca662297SShilei Tian     if (!isValidState())
3914ca662297SShilei Tian       return "<invalid>";
3915ca662297SShilei Tian 
3916ca662297SShilei Tian     std::string Str("simplified value: ");
3917ca662297SShilei Tian 
3918ca662297SShilei Tian     if (!SimplifiedValue.hasValue())
3919ca662297SShilei Tian       return Str + std::string("none");
3920ca662297SShilei Tian 
3921ca662297SShilei Tian     if (!SimplifiedValue.getValue())
3922ca662297SShilei Tian       return Str + std::string("nullptr");
3923ca662297SShilei Tian 
3924ca662297SShilei Tian     if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue()))
3925ca662297SShilei Tian       return Str + std::to_string(CI->getSExtValue());
3926ca662297SShilei Tian 
3927ca662297SShilei Tian     return Str + std::string("unknown");
3928ca662297SShilei Tian   }
3929ca662297SShilei Tian 
3930ca662297SShilei Tian   void initialize(Attributor &A) override {
3931cd0dd8ecSJoseph Huber     if (DisableOpenMPOptFolding)
3932cd0dd8ecSJoseph Huber       indicatePessimisticFixpoint();
3933cd0dd8ecSJoseph Huber 
3934ca662297SShilei Tian     Function *Callee = getAssociatedFunction();
3935ca662297SShilei Tian 
3936ca662297SShilei Tian     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3937ca662297SShilei Tian     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3938ca662297SShilei Tian     assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
3939ca662297SShilei Tian            "Expected a known OpenMP runtime function");
3940ca662297SShilei Tian 
3941ca662297SShilei Tian     RFKind = It->getSecond();
3942ca662297SShilei Tian 
3943ca662297SShilei Tian     CallBase &CB = cast<CallBase>(getAssociatedValue());
3944ca662297SShilei Tian     A.registerSimplificationCallback(
3945ca662297SShilei Tian         IRPosition::callsite_returned(CB),
3946ca662297SShilei Tian         [&](const IRPosition &IRP, const AbstractAttribute *AA,
3947ca662297SShilei Tian             bool &UsedAssumedInformation) -> Optional<Value *> {
3948ca662297SShilei Tian           assert((isValidState() || (SimplifiedValue.hasValue() &&
3949ca662297SShilei Tian                                      SimplifiedValue.getValue() == nullptr)) &&
3950ca662297SShilei Tian                  "Unexpected invalid state!");
3951ca662297SShilei Tian 
3952ca662297SShilei Tian           if (!isAtFixpoint()) {
3953ca662297SShilei Tian             UsedAssumedInformation = true;
3954ca662297SShilei Tian             if (AA)
3955ca662297SShilei Tian               A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3956ca662297SShilei Tian           }
3957ca662297SShilei Tian           return SimplifiedValue;
3958ca662297SShilei Tian         });
3959ca662297SShilei Tian   }
3960ca662297SShilei Tian 
3961ca662297SShilei Tian   ChangeStatus updateImpl(Attributor &A) override {
3962ca662297SShilei Tian     ChangeStatus Changed = ChangeStatus::UNCHANGED;
3963ca662297SShilei Tian     switch (RFKind) {
3964ca662297SShilei Tian     case OMPRTL___kmpc_is_spmd_exec_mode:
3965c23da666SShilei Tian       Changed |= foldIsSPMDExecMode(A);
3966ca662297SShilei Tian       break;
3967196fe994SJoseph Huber     case OMPRTL___kmpc_is_generic_main_thread_id:
3968196fe994SJoseph Huber       Changed |= foldIsGenericMainThread(A);
3969196fe994SJoseph Huber       break;
3970e97e0a4fSShilei Tian     case OMPRTL___kmpc_parallel_level:
3971e97e0a4fSShilei Tian       Changed |= foldParallelLevel(A);
3972e97e0a4fSShilei Tian       break;
39735ab6aeddSJose M Monsalve Diaz     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
39745ab6aeddSJose M Monsalve Diaz       Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
39755ab6aeddSJose M Monsalve Diaz       break;
39765ab6aeddSJose M Monsalve Diaz     case OMPRTL___kmpc_get_hardware_num_blocks:
39775ab6aeddSJose M Monsalve Diaz       Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
39785ab6aeddSJose M Monsalve Diaz       break;
3979ca662297SShilei Tian     default:
3980ca662297SShilei Tian       llvm_unreachable("Unhandled OpenMP runtime function!");
3981ca662297SShilei Tian     }
3982ca662297SShilei Tian 
3983ca662297SShilei Tian     return Changed;
3984ca662297SShilei Tian   }
3985ca662297SShilei Tian 
3986ca662297SShilei Tian   ChangeStatus manifest(Attributor &A) override {
3987ca662297SShilei Tian     ChangeStatus Changed = ChangeStatus::UNCHANGED;
3988ca662297SShilei Tian 
3989ca662297SShilei Tian     if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) {
3990ca662297SShilei Tian       Instruction &CB = *getCtxI();
3991ca662297SShilei Tian       A.changeValueAfterManifest(CB, **SimplifiedValue);
3992ca662297SShilei Tian       A.deleteAfterManifest(CB);
3993196fe994SJoseph Huber 
3994196fe994SJoseph Huber       LLVM_DEBUG(dbgs() << TAG << "Folding runtime call: " << CB << " with "
3995196fe994SJoseph Huber                         << **SimplifiedValue << "\n");
3996196fe994SJoseph Huber 
3997ca662297SShilei Tian       Changed = ChangeStatus::CHANGED;
3998ca662297SShilei Tian     }
3999ca662297SShilei Tian 
4000ca662297SShilei Tian     return Changed;
4001ca662297SShilei Tian   }
4002ca662297SShilei Tian 
4003ca662297SShilei Tian   ChangeStatus indicatePessimisticFixpoint() override {
4004ca662297SShilei Tian     SimplifiedValue = nullptr;
4005ca662297SShilei Tian     return AAFoldRuntimeCall::indicatePessimisticFixpoint();
4006ca662297SShilei Tian   }
4007ca662297SShilei Tian 
4008ca662297SShilei Tian private:
4009ca662297SShilei Tian   /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
4010ca662297SShilei Tian   ChangeStatus foldIsSPMDExecMode(Attributor &A) {
4011ca662297SShilei Tian     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4012ca662297SShilei Tian 
4013ca662297SShilei Tian     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4014ca662297SShilei Tian     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4015ca662297SShilei Tian     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4016ca662297SShilei Tian         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4017ca662297SShilei Tian 
4018ca662297SShilei Tian     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4019ca662297SShilei Tian       return indicatePessimisticFixpoint();
4020ca662297SShilei Tian 
4021ca662297SShilei Tian     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4022ca662297SShilei Tian       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4023ca662297SShilei Tian                                           DepClassTy::REQUIRED);
4024ca662297SShilei Tian 
4025ca662297SShilei Tian       if (!AA.isValidState()) {
4026ca662297SShilei Tian         SimplifiedValue = nullptr;
4027ca662297SShilei Tian         return indicatePessimisticFixpoint();
4028ca662297SShilei Tian       }
4029ca662297SShilei Tian 
4030ca662297SShilei Tian       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4031ca662297SShilei Tian         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4032ca662297SShilei Tian           ++KnownSPMDCount;
4033ca662297SShilei Tian         else
4034ca662297SShilei Tian           ++AssumedSPMDCount;
4035ca662297SShilei Tian       } else {
4036ca662297SShilei Tian         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4037ca662297SShilei Tian           ++KnownNonSPMDCount;
4038ca662297SShilei Tian         else
4039ca662297SShilei Tian           ++AssumedNonSPMDCount;
4040ca662297SShilei Tian       }
4041ca662297SShilei Tian     }
4042ca662297SShilei Tian 
4043ae69f468SShilei Tian     if ((AssumedSPMDCount + KnownSPMDCount) &&
4044ae69f468SShilei Tian         (AssumedNonSPMDCount + KnownNonSPMDCount))
4045ca662297SShilei Tian       return indicatePessimisticFixpoint();
4046ca662297SShilei Tian 
4047ca662297SShilei Tian     auto &Ctx = getAnchorValue().getContext();
4048ca662297SShilei Tian     if (KnownSPMDCount || AssumedSPMDCount) {
4049ca662297SShilei Tian       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4050ca662297SShilei Tian              "Expected only SPMD kernels!");
4051ca662297SShilei Tian       // All reaching kernels are in SPMD mode. Update all function calls to
4052ca662297SShilei Tian       // __kmpc_is_spmd_exec_mode to 1.
4053ca662297SShilei Tian       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4054d3454ee8SShilei Tian     } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
4055ca662297SShilei Tian       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4056ca662297SShilei Tian              "Expected only non-SPMD kernels!");
4057ca662297SShilei Tian       // All reaching kernels are in non-SPMD mode. Update all function
4058ca662297SShilei Tian       // calls to __kmpc_is_spmd_exec_mode to 0.
4059ca662297SShilei Tian       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
4060d3454ee8SShilei Tian     } else {
4061d3454ee8SShilei Tian       // We have empty reaching kernels, therefore we cannot tell if the
4062d3454ee8SShilei Tian       // associated call site can be folded. At this moment, SimplifiedValue
4063d3454ee8SShilei Tian       // must be none.
4064d3454ee8SShilei Tian       assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none");
4065ca662297SShilei Tian     }
4066ca662297SShilei Tian 
4067ca662297SShilei Tian     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4068ca662297SShilei Tian                                                     : ChangeStatus::CHANGED;
4069ca662297SShilei Tian   }
4070ca662297SShilei Tian 
4071196fe994SJoseph Huber   /// Fold __kmpc_is_generic_main_thread_id into a constant if possible.
4072196fe994SJoseph Huber   ChangeStatus foldIsGenericMainThread(Attributor &A) {
4073196fe994SJoseph Huber     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4074196fe994SJoseph Huber 
4075196fe994SJoseph Huber     CallBase &CB = cast<CallBase>(getAssociatedValue());
4076196fe994SJoseph Huber     Function *F = CB.getFunction();
4077196fe994SJoseph Huber     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
4078196fe994SJoseph Huber         *this, IRPosition::function(*F), DepClassTy::REQUIRED);
4079196fe994SJoseph Huber 
4080196fe994SJoseph Huber     if (!ExecutionDomainAA.isValidState())
4081196fe994SJoseph Huber       return indicatePessimisticFixpoint();
4082196fe994SJoseph Huber 
4083196fe994SJoseph Huber     auto &Ctx = getAnchorValue().getContext();
4084196fe994SJoseph Huber     if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB))
4085196fe994SJoseph Huber       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4086196fe994SJoseph Huber     else
4087196fe994SJoseph Huber       return indicatePessimisticFixpoint();
4088196fe994SJoseph Huber 
4089196fe994SJoseph Huber     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4090196fe994SJoseph Huber                                                     : ChangeStatus::CHANGED;
4091196fe994SJoseph Huber   }
4092196fe994SJoseph Huber 
4093e97e0a4fSShilei Tian   /// Fold __kmpc_parallel_level into a constant if possible.
4094e97e0a4fSShilei Tian   ChangeStatus foldParallelLevel(Attributor &A) {
4095e97e0a4fSShilei Tian     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4096e97e0a4fSShilei Tian 
4097e97e0a4fSShilei Tian     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4098e97e0a4fSShilei Tian         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4099e97e0a4fSShilei Tian 
4100e97e0a4fSShilei Tian     if (!CallerKernelInfoAA.ParallelLevels.isValidState())
4101e97e0a4fSShilei Tian       return indicatePessimisticFixpoint();
4102e97e0a4fSShilei Tian 
4103e97e0a4fSShilei Tian     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4104e97e0a4fSShilei Tian       return indicatePessimisticFixpoint();
4105e97e0a4fSShilei Tian 
4106e97e0a4fSShilei Tian     if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
4107e97e0a4fSShilei Tian       assert(!SimplifiedValue.hasValue() &&
4108e97e0a4fSShilei Tian              "SimplifiedValue should keep none at this point");
4109e97e0a4fSShilei Tian       return ChangeStatus::UNCHANGED;
4110e97e0a4fSShilei Tian     }
4111e97e0a4fSShilei Tian 
4112e97e0a4fSShilei Tian     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4113e97e0a4fSShilei Tian     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4114e97e0a4fSShilei Tian     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4115e97e0a4fSShilei Tian       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4116e97e0a4fSShilei Tian                                           DepClassTy::REQUIRED);
4117e97e0a4fSShilei Tian       if (!AA.SPMDCompatibilityTracker.isValidState())
4118e97e0a4fSShilei Tian         return indicatePessimisticFixpoint();
4119e97e0a4fSShilei Tian 
4120e97e0a4fSShilei Tian       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4121e97e0a4fSShilei Tian         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4122e97e0a4fSShilei Tian           ++KnownSPMDCount;
4123e97e0a4fSShilei Tian         else
4124e97e0a4fSShilei Tian           ++AssumedSPMDCount;
4125e97e0a4fSShilei Tian       } else {
4126e97e0a4fSShilei Tian         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4127e97e0a4fSShilei Tian           ++KnownNonSPMDCount;
4128e97e0a4fSShilei Tian         else
4129e97e0a4fSShilei Tian           ++AssumedNonSPMDCount;
4130e97e0a4fSShilei Tian       }
4131e97e0a4fSShilei Tian     }
4132e97e0a4fSShilei Tian 
4133e97e0a4fSShilei Tian     if ((AssumedSPMDCount + KnownSPMDCount) &&
4134e97e0a4fSShilei Tian         (AssumedNonSPMDCount + KnownNonSPMDCount))
4135e97e0a4fSShilei Tian       return indicatePessimisticFixpoint();
4136e97e0a4fSShilei Tian 
4137e97e0a4fSShilei Tian     auto &Ctx = getAnchorValue().getContext();
4138e97e0a4fSShilei Tian     // If the caller can only be reached by SPMD kernel entries, the parallel
4139e97e0a4fSShilei Tian     // level is 1. Similarly, if the caller can only be reached by non-SPMD
4140e97e0a4fSShilei Tian     // kernel entries, it is 0.
4141e97e0a4fSShilei Tian     if (AssumedSPMDCount || KnownSPMDCount) {
4142e97e0a4fSShilei Tian       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4143e97e0a4fSShilei Tian              "Expected only SPMD kernels!");
4144e97e0a4fSShilei Tian       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
4145e97e0a4fSShilei Tian     } else {
4146e97e0a4fSShilei Tian       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4147e97e0a4fSShilei Tian              "Expected only non-SPMD kernels!");
4148e97e0a4fSShilei Tian       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
4149e97e0a4fSShilei Tian     }
41505ab6aeddSJose M Monsalve Diaz     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
41515ab6aeddSJose M Monsalve Diaz                                                     : ChangeStatus::CHANGED;
41525ab6aeddSJose M Monsalve Diaz   }
4153e97e0a4fSShilei Tian 
41545ab6aeddSJose M Monsalve Diaz   ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
41555ab6aeddSJose M Monsalve Diaz     // Specialize only if all the calls agree with the attribute constant value
41565ab6aeddSJose M Monsalve Diaz     int32_t CurrentAttrValue = -1;
41575ab6aeddSJose M Monsalve Diaz     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
41585ab6aeddSJose M Monsalve Diaz 
41595ab6aeddSJose M Monsalve Diaz     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
41605ab6aeddSJose M Monsalve Diaz         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
41615ab6aeddSJose M Monsalve Diaz 
41625ab6aeddSJose M Monsalve Diaz     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
41635ab6aeddSJose M Monsalve Diaz       return indicatePessimisticFixpoint();
41645ab6aeddSJose M Monsalve Diaz 
41655ab6aeddSJose M Monsalve Diaz     // Iterate over the kernels that reach this function
41665ab6aeddSJose M Monsalve Diaz     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
41675ab6aeddSJose M Monsalve Diaz       int32_t NextAttrVal = -1;
41685ab6aeddSJose M Monsalve Diaz       if (K->hasFnAttribute(Attr))
41695ab6aeddSJose M Monsalve Diaz         NextAttrVal =
41705ab6aeddSJose M Monsalve Diaz             std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
41715ab6aeddSJose M Monsalve Diaz 
41725ab6aeddSJose M Monsalve Diaz       if (NextAttrVal == -1 ||
41735ab6aeddSJose M Monsalve Diaz           (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
41745ab6aeddSJose M Monsalve Diaz         return indicatePessimisticFixpoint();
41755ab6aeddSJose M Monsalve Diaz       CurrentAttrValue = NextAttrVal;
41765ab6aeddSJose M Monsalve Diaz     }
41775ab6aeddSJose M Monsalve Diaz 
41785ab6aeddSJose M Monsalve Diaz     if (CurrentAttrValue != -1) {
41795ab6aeddSJose M Monsalve Diaz       auto &Ctx = getAnchorValue().getContext();
41805ab6aeddSJose M Monsalve Diaz       SimplifiedValue =
41815ab6aeddSJose M Monsalve Diaz           ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
41825ab6aeddSJose M Monsalve Diaz     }
4183e97e0a4fSShilei Tian     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4184e97e0a4fSShilei Tian                                                     : ChangeStatus::CHANGED;
4185e97e0a4fSShilei Tian   }
4186e97e0a4fSShilei Tian 
4187ca662297SShilei Tian   /// An optional value the associated value is assumed to fold to. That is, we
4188ca662297SShilei Tian   /// assume the associated value (which is a call) can be replaced by this
4189ca662297SShilei Tian   /// simplified value.
4190ca662297SShilei Tian   Optional<Value *> SimplifiedValue;
4191ca662297SShilei Tian 
4192ca662297SShilei Tian   /// The runtime function kind of the callee of the associated call site.
4193ca662297SShilei Tian   RuntimeFunction RFKind;
4194ca662297SShilei Tian };
4195ca662297SShilei Tian 
41969548b74aSJohannes Doerfert } // namespace
41979548b74aSJohannes Doerfert 
41985ab6aeddSJose M Monsalve Diaz /// Register folding callsite
41995ab6aeddSJose M Monsalve Diaz void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
42005ab6aeddSJose M Monsalve Diaz   auto &RFI = OMPInfoCache.RFIs[RF];
42015ab6aeddSJose M Monsalve Diaz   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
42025ab6aeddSJose M Monsalve Diaz     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
42035ab6aeddSJose M Monsalve Diaz     if (!CI)
42045ab6aeddSJose M Monsalve Diaz       return false;
42055ab6aeddSJose M Monsalve Diaz     A.getOrCreateAAFor<AAFoldRuntimeCall>(
42065ab6aeddSJose M Monsalve Diaz         IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
42075ab6aeddSJose M Monsalve Diaz         DepClassTy::NONE, /* ForceUpdate */ false,
42085ab6aeddSJose M Monsalve Diaz         /* UpdateAfterInit */ false);
42095ab6aeddSJose M Monsalve Diaz     return false;
42105ab6aeddSJose M Monsalve Diaz   });
42115ab6aeddSJose M Monsalve Diaz }
42125ab6aeddSJose M Monsalve Diaz 
4213d9659bf6SJohannes Doerfert void OpenMPOpt::registerAAs(bool IsModulePass) {
4214d9659bf6SJohannes Doerfert   if (SCC.empty())
4215d9659bf6SJohannes Doerfert 
4216d9659bf6SJohannes Doerfert     return;
4217d9659bf6SJohannes Doerfert   if (IsModulePass) {
4218d9659bf6SJohannes Doerfert     // Ensure we create the AAKernelInfo AAs first and without triggering an
4219d9659bf6SJohannes Doerfert     // update. This will make sure we register all value simplification
4220d9659bf6SJohannes Doerfert     // callbacks before any other AA has the chance to create an AAValueSimplify
4221d9659bf6SJohannes Doerfert     // or similar.
4222d9659bf6SJohannes Doerfert     for (Function *Kernel : OMPInfoCache.Kernels)
4223d9659bf6SJohannes Doerfert       A.getOrCreateAAFor<AAKernelInfo>(
4224d9659bf6SJohannes Doerfert           IRPosition::function(*Kernel), /* QueryingAA */ nullptr,
4225d9659bf6SJohannes Doerfert           DepClassTy::NONE, /* ForceUpdate */ false,
4226d9659bf6SJohannes Doerfert           /* UpdateAfterInit */ false);
4227ca662297SShilei Tian 
4228196fe994SJoseph Huber 
42295ab6aeddSJose M Monsalve Diaz     registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
42305ab6aeddSJose M Monsalve Diaz     registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
42315ab6aeddSJose M Monsalve Diaz     registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
42325ab6aeddSJose M Monsalve Diaz     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
42335ab6aeddSJose M Monsalve Diaz     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
4234d9659bf6SJohannes Doerfert   }
4235d9659bf6SJohannes Doerfert 
4236d9659bf6SJohannes Doerfert   // Create CallSite AA for all Getters.
4237d9659bf6SJohannes Doerfert   for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
4238d9659bf6SJohannes Doerfert     auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
4239d9659bf6SJohannes Doerfert 
4240d9659bf6SJohannes Doerfert     auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
4241d9659bf6SJohannes Doerfert 
4242d9659bf6SJohannes Doerfert     auto CreateAA = [&](Use &U, Function &Caller) {
4243d9659bf6SJohannes Doerfert       CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
4244d9659bf6SJohannes Doerfert       if (!CI)
4245d9659bf6SJohannes Doerfert         return false;
4246d9659bf6SJohannes Doerfert 
4247d9659bf6SJohannes Doerfert       auto &CB = cast<CallBase>(*CI);
4248d9659bf6SJohannes Doerfert 
4249d9659bf6SJohannes Doerfert       IRPosition CBPos = IRPosition::callsite_function(CB);
4250d9659bf6SJohannes Doerfert       A.getOrCreateAAFor<AAICVTracker>(CBPos);
4251d9659bf6SJohannes Doerfert       return false;
4252d9659bf6SJohannes Doerfert     };
4253d9659bf6SJohannes Doerfert 
4254d9659bf6SJohannes Doerfert     GetterRFI.foreachUse(SCC, CreateAA);
4255d9659bf6SJohannes Doerfert   }
4256d9659bf6SJohannes Doerfert   auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4257d9659bf6SJohannes Doerfert   auto CreateAA = [&](Use &U, Function &F) {
4258d9659bf6SJohannes Doerfert     A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
4259d9659bf6SJohannes Doerfert     return false;
4260d9659bf6SJohannes Doerfert   };
4261cd0dd8ecSJoseph Huber   if (!DisableOpenMPOptDeglobalization)
4262d9659bf6SJohannes Doerfert     GlobalizationRFI.foreachUse(SCC, CreateAA);
4263d9659bf6SJohannes Doerfert 
4264d9659bf6SJohannes Doerfert   // Create an ExecutionDomain AA for every function and a HeapToStack AA for
4265d9659bf6SJohannes Doerfert   // every function if there is a device kernel.
426670b75f62SJohannes Doerfert   if (!isOpenMPDevice(M))
426770b75f62SJohannes Doerfert     return;
426870b75f62SJohannes Doerfert 
4269d9659bf6SJohannes Doerfert   for (auto *F : SCC) {
427070b75f62SJohannes Doerfert     if (F->isDeclaration())
427170b75f62SJohannes Doerfert       continue;
427270b75f62SJohannes Doerfert 
4273d9659bf6SJohannes Doerfert     A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
4274cd0dd8ecSJoseph Huber     if (!DisableOpenMPOptDeglobalization)
4275d9659bf6SJohannes Doerfert       A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
427670b75f62SJohannes Doerfert 
427770b75f62SJohannes Doerfert     for (auto &I : instructions(*F)) {
427870b75f62SJohannes Doerfert       if (auto *LI = dyn_cast<LoadInst>(&I)) {
427970b75f62SJohannes Doerfert         bool UsedAssumedInformation = false;
428070b75f62SJohannes Doerfert         A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
428170b75f62SJohannes Doerfert                                UsedAssumedInformation);
428270b75f62SJohannes Doerfert       }
428370b75f62SJohannes Doerfert     }
4284d9659bf6SJohannes Doerfert   }
4285d9659bf6SJohannes Doerfert }
4286d9659bf6SJohannes Doerfert 
4287b8235d2bSsstefan1 const char AAICVTracker::ID = 0;
4288d9659bf6SJohannes Doerfert const char AAKernelInfo::ID = 0;
428918283125SJoseph Huber const char AAExecutionDomain::ID = 0;
42906fc51c9fSJoseph Huber const char AAHeapToShared::ID = 0;
4291ca662297SShilei Tian const char AAFoldRuntimeCall::ID = 0;
4292b8235d2bSsstefan1 
4293b8235d2bSsstefan1 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
4294b8235d2bSsstefan1                                               Attributor &A) {
4295b8235d2bSsstefan1   AAICVTracker *AA = nullptr;
4296b8235d2bSsstefan1   switch (IRP.getPositionKind()) {
4297b8235d2bSsstefan1   case IRPosition::IRP_INVALID:
4298b8235d2bSsstefan1   case IRPosition::IRP_FLOAT:
4299b8235d2bSsstefan1   case IRPosition::IRP_ARGUMENT:
4300b8235d2bSsstefan1   case IRPosition::IRP_CALL_SITE_ARGUMENT:
43011de70a72SJohannes Doerfert     llvm_unreachable("ICVTracker can only be created for function position!");
43025dfd7cc4Ssstefan1   case IRPosition::IRP_RETURNED:
43035dfd7cc4Ssstefan1     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
43045dfd7cc4Ssstefan1     break;
43055dfd7cc4Ssstefan1   case IRPosition::IRP_CALL_SITE_RETURNED:
43065dfd7cc4Ssstefan1     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
43075dfd7cc4Ssstefan1     break;
43085dfd7cc4Ssstefan1   case IRPosition::IRP_CALL_SITE:
43095dfd7cc4Ssstefan1     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
43105dfd7cc4Ssstefan1     break;
4311b8235d2bSsstefan1   case IRPosition::IRP_FUNCTION:
4312b8235d2bSsstefan1     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
4313b8235d2bSsstefan1     break;
4314b8235d2bSsstefan1   }
4315b8235d2bSsstefan1 
4316b8235d2bSsstefan1   return *AA;
4317b8235d2bSsstefan1 }
4318b8235d2bSsstefan1 
431918283125SJoseph Huber AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
432018283125SJoseph Huber                                                         Attributor &A) {
432118283125SJoseph Huber   AAExecutionDomainFunction *AA = nullptr;
432218283125SJoseph Huber   switch (IRP.getPositionKind()) {
432318283125SJoseph Huber   case IRPosition::IRP_INVALID:
432418283125SJoseph Huber   case IRPosition::IRP_FLOAT:
432518283125SJoseph Huber   case IRPosition::IRP_ARGUMENT:
432618283125SJoseph Huber   case IRPosition::IRP_CALL_SITE_ARGUMENT:
432718283125SJoseph Huber   case IRPosition::IRP_RETURNED:
432818283125SJoseph Huber   case IRPosition::IRP_CALL_SITE_RETURNED:
432918283125SJoseph Huber   case IRPosition::IRP_CALL_SITE:
433018283125SJoseph Huber     llvm_unreachable(
433118283125SJoseph Huber         "AAExecutionDomain can only be created for function position!");
433218283125SJoseph Huber   case IRPosition::IRP_FUNCTION:
433318283125SJoseph Huber     AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
433418283125SJoseph Huber     break;
433518283125SJoseph Huber   }
433618283125SJoseph Huber 
433718283125SJoseph Huber   return *AA;
433818283125SJoseph Huber }
433918283125SJoseph Huber 
43406fc51c9fSJoseph Huber AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
43416fc51c9fSJoseph Huber                                                   Attributor &A) {
43426fc51c9fSJoseph Huber   AAHeapToSharedFunction *AA = nullptr;
43436fc51c9fSJoseph Huber   switch (IRP.getPositionKind()) {
43446fc51c9fSJoseph Huber   case IRPosition::IRP_INVALID:
43456fc51c9fSJoseph Huber   case IRPosition::IRP_FLOAT:
43466fc51c9fSJoseph Huber   case IRPosition::IRP_ARGUMENT:
43476fc51c9fSJoseph Huber   case IRPosition::IRP_CALL_SITE_ARGUMENT:
43486fc51c9fSJoseph Huber   case IRPosition::IRP_RETURNED:
43496fc51c9fSJoseph Huber   case IRPosition::IRP_CALL_SITE_RETURNED:
43506fc51c9fSJoseph Huber   case IRPosition::IRP_CALL_SITE:
43516fc51c9fSJoseph Huber     llvm_unreachable(
43526fc51c9fSJoseph Huber         "AAHeapToShared can only be created for function position!");
43536fc51c9fSJoseph Huber   case IRPosition::IRP_FUNCTION:
43546fc51c9fSJoseph Huber     AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
43556fc51c9fSJoseph Huber     break;
43566fc51c9fSJoseph Huber   }
43576fc51c9fSJoseph Huber 
43586fc51c9fSJoseph Huber   return *AA;
43596fc51c9fSJoseph Huber }
43606fc51c9fSJoseph Huber 
4361d9659bf6SJohannes Doerfert AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
4362d9659bf6SJohannes Doerfert                                               Attributor &A) {
4363d9659bf6SJohannes Doerfert   AAKernelInfo *AA = nullptr;
4364d9659bf6SJohannes Doerfert   switch (IRP.getPositionKind()) {
4365d9659bf6SJohannes Doerfert   case IRPosition::IRP_INVALID:
4366d9659bf6SJohannes Doerfert   case IRPosition::IRP_FLOAT:
4367d9659bf6SJohannes Doerfert   case IRPosition::IRP_ARGUMENT:
4368d9659bf6SJohannes Doerfert   case IRPosition::IRP_RETURNED:
4369d9659bf6SJohannes Doerfert   case IRPosition::IRP_CALL_SITE_RETURNED:
4370d9659bf6SJohannes Doerfert   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4371d9659bf6SJohannes Doerfert     llvm_unreachable("KernelInfo can only be created for function position!");
4372d9659bf6SJohannes Doerfert   case IRPosition::IRP_CALL_SITE:
4373d9659bf6SJohannes Doerfert     AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
4374d9659bf6SJohannes Doerfert     break;
4375d9659bf6SJohannes Doerfert   case IRPosition::IRP_FUNCTION:
4376d9659bf6SJohannes Doerfert     AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
4377d9659bf6SJohannes Doerfert     break;
4378d9659bf6SJohannes Doerfert   }
4379d9659bf6SJohannes Doerfert 
4380d9659bf6SJohannes Doerfert   return *AA;
4381d9659bf6SJohannes Doerfert }
4382d9659bf6SJohannes Doerfert 
4383ca662297SShilei Tian AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
4384ca662297SShilei Tian                                                         Attributor &A) {
4385ca662297SShilei Tian   AAFoldRuntimeCall *AA = nullptr;
4386ca662297SShilei Tian   switch (IRP.getPositionKind()) {
4387ca662297SShilei Tian   case IRPosition::IRP_INVALID:
4388ca662297SShilei Tian   case IRPosition::IRP_FLOAT:
4389ca662297SShilei Tian   case IRPosition::IRP_ARGUMENT:
4390ca662297SShilei Tian   case IRPosition::IRP_RETURNED:
4391ca662297SShilei Tian   case IRPosition::IRP_FUNCTION:
4392ca662297SShilei Tian   case IRPosition::IRP_CALL_SITE:
4393ca662297SShilei Tian   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4394ca662297SShilei Tian     llvm_unreachable("KernelInfo can only be created for call site position!");
4395ca662297SShilei Tian   case IRPosition::IRP_CALL_SITE_RETURNED:
4396ca662297SShilei Tian     AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
4397ca662297SShilei Tian     break;
4398ca662297SShilei Tian   }
4399ca662297SShilei Tian 
4400ca662297SShilei Tian   return *AA;
4401ca662297SShilei Tian }
4402ca662297SShilei Tian 
4403b2ad63d3SJoseph Huber PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
44045ccb7424SJoseph Huber   if (!containsOpenMP(M))
4405b2ad63d3SJoseph Huber     return PreservedAnalyses::all();
4406b2ad63d3SJoseph Huber   if (DisableOpenMPOptimizations)
4407b2ad63d3SJoseph Huber     return PreservedAnalyses::all();
4408b2ad63d3SJoseph Huber 
44090edb8777SJoseph Huber   FunctionAnalysisManager &FAM =
44100edb8777SJoseph Huber       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
44115ccb7424SJoseph Huber   KernelSet Kernels = getDeviceKernels(M);
44125ccb7424SJoseph Huber 
441357ad2e10SJoseph Huber   auto IsCalled = [&](Function &F) {
441457ad2e10SJoseph Huber     if (Kernels.contains(&F))
441557ad2e10SJoseph Huber       return true;
441657ad2e10SJoseph Huber     for (const User *U : F.users())
441757ad2e10SJoseph Huber       if (!isa<BlockAddress>(U))
441857ad2e10SJoseph Huber         return true;
441957ad2e10SJoseph Huber     return false;
442057ad2e10SJoseph Huber   };
442157ad2e10SJoseph Huber 
44220edb8777SJoseph Huber   auto EmitRemark = [&](Function &F) {
44230edb8777SJoseph Huber     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
44240edb8777SJoseph Huber     ORE.emit([&]() {
44252c31d5ebSJoseph Huber       OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
4426ecabc668SJoseph Huber       return ORA << "Could not internalize function. "
4427adbaa39dSJoseph Huber                  << "Some optimizations may not be possible. [OMP140]";
44280edb8777SJoseph Huber     });
44290edb8777SJoseph Huber   };
44300edb8777SJoseph Huber 
443157ad2e10SJoseph Huber   // Create internal copies of each function if this is a kernel Module. This
443257ad2e10SJoseph Huber   // allows iterprocedural passes to see every call edge.
4433adbaa39dSJoseph Huber   DenseMap<Function *, Function *> InternalizedMap;
4434adbaa39dSJoseph Huber   if (isOpenMPDevice(M)) {
4435adbaa39dSJoseph Huber     SmallPtrSet<Function *, 16> InternalizeFns;
443603d7e61cSJoseph Huber     for (Function &F : M)
44374a668604SJoseph Huber       if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
44384a668604SJoseph Huber           !DisableInternalization) {
4439adbaa39dSJoseph Huber         if (Attributor::isInternalizable(F)) {
4440adbaa39dSJoseph Huber           InternalizeFns.insert(&F);
4441ecabc668SJoseph Huber         } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
44420edb8777SJoseph Huber           EmitRemark(F);
44430edb8777SJoseph Huber         }
44440edb8777SJoseph Huber       }
444503d7e61cSJoseph Huber 
4446adbaa39dSJoseph Huber     Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
4447adbaa39dSJoseph Huber   }
4448adbaa39dSJoseph Huber 
444957ad2e10SJoseph Huber   // Look at every function in the Module unless it was internalized.
4450b2ad63d3SJoseph Huber   SmallVector<Function *, 16> SCC;
445103d7e61cSJoseph Huber   for (Function &F : M)
4452adbaa39dSJoseph Huber     if (!F.isDeclaration() && !InternalizedMap.lookup(&F))
445303d7e61cSJoseph Huber       SCC.push_back(&F);
4454b2ad63d3SJoseph Huber 
4455b2ad63d3SJoseph Huber   if (SCC.empty())
4456b2ad63d3SJoseph Huber     return PreservedAnalyses::all();
4457b2ad63d3SJoseph Huber 
4458b2ad63d3SJoseph Huber   AnalysisGetter AG(FAM);
4459b2ad63d3SJoseph Huber 
4460b2ad63d3SJoseph Huber   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4461b2ad63d3SJoseph Huber     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4462b2ad63d3SJoseph Huber   };
4463b2ad63d3SJoseph Huber 
4464b2ad63d3SJoseph Huber   BumpPtrAllocator Allocator;
4465b2ad63d3SJoseph Huber   CallGraphUpdater CGUpdater;
4466b2ad63d3SJoseph Huber 
4467b2ad63d3SJoseph Huber   SetVector<Function *> Functions(SCC.begin(), SCC.end());
44685ccb7424SJoseph Huber   OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels);
4469b2ad63d3SJoseph Huber 
447013b2fba2SJoseph Huber   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
44714a6bd8e3SJoseph Huber   Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false,
447213b2fba2SJoseph Huber                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4473b2ad63d3SJoseph Huber 
4474b2ad63d3SJoseph Huber   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4475b2ad63d3SJoseph Huber   bool Changed = OMPOpt.run(true);
4476b2ad63d3SJoseph Huber   if (Changed)
4477b2ad63d3SJoseph Huber     return PreservedAnalyses::none();
4478b2ad63d3SJoseph Huber 
4479b2ad63d3SJoseph Huber   return PreservedAnalyses::all();
4480b2ad63d3SJoseph Huber }
4481b2ad63d3SJoseph Huber 
4482b2ad63d3SJoseph Huber PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
44839548b74aSJohannes Doerfert                                           CGSCCAnalysisManager &AM,
4484b2ad63d3SJoseph Huber                                           LazyCallGraph &CG,
4485b2ad63d3SJoseph Huber                                           CGSCCUpdateResult &UR) {
44865ccb7424SJoseph Huber   if (!containsOpenMP(*C.begin()->getFunction().getParent()))
44879548b74aSJohannes Doerfert     return PreservedAnalyses::all();
44889548b74aSJohannes Doerfert   if (DisableOpenMPOptimizations)
44899548b74aSJohannes Doerfert     return PreservedAnalyses::all();
44909548b74aSJohannes Doerfert 
4491ee17263aSJohannes Doerfert   SmallVector<Function *, 16> SCC;
4492351d234dSRoman Lebedev   // If there are kernels in the module, we have to run on all SCC's.
4493351d234dSRoman Lebedev   for (LazyCallGraph::Node &N : C) {
4494351d234dSRoman Lebedev     Function *Fn = &N.getFunction();
4495351d234dSRoman Lebedev     SCC.push_back(Fn);
4496351d234dSRoman Lebedev   }
4497351d234dSRoman Lebedev 
44985ccb7424SJoseph Huber   if (SCC.empty())
44999548b74aSJohannes Doerfert     return PreservedAnalyses::all();
45009548b74aSJohannes Doerfert 
45015ccb7424SJoseph Huber   Module &M = *C.begin()->getFunction().getParent();
45025ccb7424SJoseph Huber 
45035ccb7424SJoseph Huber   KernelSet Kernels = getDeviceKernels(M);
45045ccb7424SJoseph Huber 
45054d4ea9acSHuber, Joseph   FunctionAnalysisManager &FAM =
45064d4ea9acSHuber, Joseph       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
45077cfd267cSsstefan1 
45087cfd267cSsstefan1   AnalysisGetter AG(FAM);
45097cfd267cSsstefan1 
45107cfd267cSsstefan1   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
45114d4ea9acSHuber, Joseph     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
45124d4ea9acSHuber, Joseph   };
45134d4ea9acSHuber, Joseph 
4514b2ad63d3SJoseph Huber   BumpPtrAllocator Allocator;
45159548b74aSJohannes Doerfert   CallGraphUpdater CGUpdater;
45169548b74aSJohannes Doerfert   CGUpdater.initialize(CG, C, AM, UR);
45177cfd267cSsstefan1 
45187cfd267cSsstefan1   SetVector<Function *> Functions(SCC.begin(), SCC.end());
45197cfd267cSsstefan1   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
45205ccb7424SJoseph Huber                                 /*CGSCC*/ Functions, Kernels);
45217cfd267cSsstefan1 
452213b2fba2SJoseph Huber   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
45234a6bd8e3SJoseph Huber   Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
452413b2fba2SJoseph Huber                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4525b8235d2bSsstefan1 
4526b8235d2bSsstefan1   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4527b2ad63d3SJoseph Huber   bool Changed = OMPOpt.run(false);
4528694ded37SGiorgis Georgakoudis   if (Changed)
4529694ded37SGiorgis Georgakoudis     return PreservedAnalyses::none();
4530694ded37SGiorgis Georgakoudis 
45319548b74aSJohannes Doerfert   return PreservedAnalyses::all();
45329548b74aSJohannes Doerfert }
45338b57ed09SJoseph Huber 
45349548b74aSJohannes Doerfert namespace {
45359548b74aSJohannes Doerfert 
4536b2ad63d3SJoseph Huber struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
45379548b74aSJohannes Doerfert   CallGraphUpdater CGUpdater;
45389548b74aSJohannes Doerfert   static char ID;
45399548b74aSJohannes Doerfert 
4540b2ad63d3SJoseph Huber   OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
4541b2ad63d3SJoseph Huber     initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
45429548b74aSJohannes Doerfert   }
45439548b74aSJohannes Doerfert 
45449548b74aSJohannes Doerfert   void getAnalysisUsage(AnalysisUsage &AU) const override {
45459548b74aSJohannes Doerfert     CallGraphSCCPass::getAnalysisUsage(AU);
45469548b74aSJohannes Doerfert   }
45479548b74aSJohannes Doerfert 
45489548b74aSJohannes Doerfert   bool runOnSCC(CallGraphSCC &CGSCC) override {
45495ccb7424SJoseph Huber     if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
45509548b74aSJohannes Doerfert       return false;
45519548b74aSJohannes Doerfert     if (DisableOpenMPOptimizations || skipSCC(CGSCC))
45529548b74aSJohannes Doerfert       return false;
45539548b74aSJohannes Doerfert 
4554ee17263aSJohannes Doerfert     SmallVector<Function *, 16> SCC;
4555351d234dSRoman Lebedev     // If there are kernels in the module, we have to run on all SCC's.
4556351d234dSRoman Lebedev     for (CallGraphNode *CGN : CGSCC) {
4557351d234dSRoman Lebedev       Function *Fn = CGN->getFunction();
4558351d234dSRoman Lebedev       if (!Fn || Fn->isDeclaration())
4559351d234dSRoman Lebedev         continue;
4560ee17263aSJohannes Doerfert       SCC.push_back(Fn);
4561351d234dSRoman Lebedev     }
4562351d234dSRoman Lebedev 
45635ccb7424SJoseph Huber     if (SCC.empty())
45649548b74aSJohannes Doerfert       return false;
45659548b74aSJohannes Doerfert 
45665ccb7424SJoseph Huber     Module &M = CGSCC.getCallGraph().getModule();
45675ccb7424SJoseph Huber     KernelSet Kernels = getDeviceKernels(M);
45685ccb7424SJoseph Huber 
45699548b74aSJohannes Doerfert     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
45709548b74aSJohannes Doerfert     CGUpdater.initialize(CG, CGSCC);
45719548b74aSJohannes Doerfert 
45724d4ea9acSHuber, Joseph     // Maintain a map of functions to avoid rebuilding the ORE
45734d4ea9acSHuber, Joseph     DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
45744d4ea9acSHuber, Joseph     auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
45754d4ea9acSHuber, Joseph       std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
45764d4ea9acSHuber, Joseph       if (!ORE)
45774d4ea9acSHuber, Joseph         ORE = std::make_unique<OptimizationRemarkEmitter>(F);
45784d4ea9acSHuber, Joseph       return *ORE;
45794d4ea9acSHuber, Joseph     };
45804d4ea9acSHuber, Joseph 
45817cfd267cSsstefan1     AnalysisGetter AG;
45827cfd267cSsstefan1     SetVector<Function *> Functions(SCC.begin(), SCC.end());
45837cfd267cSsstefan1     BumpPtrAllocator Allocator;
45845ccb7424SJoseph Huber     OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
45855ccb7424SJoseph Huber                                   Allocator,
45865ccb7424SJoseph Huber                                   /*CGSCC*/ Functions, Kernels);
45877cfd267cSsstefan1 
458813b2fba2SJoseph Huber     unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
458930e36c9bSJoseph Huber     Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
459013b2fba2SJoseph Huber                  MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4591b8235d2bSsstefan1 
4592b8235d2bSsstefan1     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4593b2ad63d3SJoseph Huber     return OMPOpt.run(false);
45949548b74aSJohannes Doerfert   }
45959548b74aSJohannes Doerfert 
45969548b74aSJohannes Doerfert   bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
45979548b74aSJohannes Doerfert };
45989548b74aSJohannes Doerfert 
45999548b74aSJohannes Doerfert } // end anonymous namespace
46009548b74aSJohannes Doerfert 
46015ccb7424SJoseph Huber KernelSet llvm::omp::getDeviceKernels(Module &M) {
46025ccb7424SJoseph Huber   // TODO: Create a more cross-platform way of determining device kernels.
4603e8039ad4SJohannes Doerfert   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
46045ccb7424SJoseph Huber   KernelSet Kernels;
46055ccb7424SJoseph Huber 
4606e8039ad4SJohannes Doerfert   if (!MD)
46075ccb7424SJoseph Huber     return Kernels;
4608e8039ad4SJohannes Doerfert 
4609e8039ad4SJohannes Doerfert   for (auto *Op : MD->operands()) {
4610e8039ad4SJohannes Doerfert     if (Op->getNumOperands() < 2)
4611e8039ad4SJohannes Doerfert       continue;
4612e8039ad4SJohannes Doerfert     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
4613e8039ad4SJohannes Doerfert     if (!KindID || KindID->getString() != "kernel")
4614e8039ad4SJohannes Doerfert       continue;
4615e8039ad4SJohannes Doerfert 
4616e8039ad4SJohannes Doerfert     Function *KernelFn =
4617e8039ad4SJohannes Doerfert         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
4618e8039ad4SJohannes Doerfert     if (!KernelFn)
4619e8039ad4SJohannes Doerfert       continue;
4620e8039ad4SJohannes Doerfert 
4621e8039ad4SJohannes Doerfert     ++NumOpenMPTargetRegionKernels;
4622e8039ad4SJohannes Doerfert 
4623e8039ad4SJohannes Doerfert     Kernels.insert(KernelFn);
4624e8039ad4SJohannes Doerfert   }
46255ccb7424SJoseph Huber 
46265ccb7424SJoseph Huber   return Kernels;
4627e8039ad4SJohannes Doerfert }
4628e8039ad4SJohannes Doerfert 
46295ccb7424SJoseph Huber bool llvm::omp::containsOpenMP(Module &M) {
46305ccb7424SJoseph Huber   Metadata *MD = M.getModuleFlag("openmp");
46315ccb7424SJoseph Huber   if (!MD)
46325ccb7424SJoseph Huber     return false;
4633dce6bc18SJohannes Doerfert 
4634e8039ad4SJohannes Doerfert   return true;
4635e8039ad4SJohannes Doerfert }
4636e8039ad4SJohannes Doerfert 
46375ccb7424SJoseph Huber bool llvm::omp::isOpenMPDevice(Module &M) {
46385ccb7424SJoseph Huber   Metadata *MD = M.getModuleFlag("openmp-device");
46395ccb7424SJoseph Huber   if (!MD)
46405ccb7424SJoseph Huber     return false;
46415ccb7424SJoseph Huber 
46425ccb7424SJoseph Huber   return true;
46439548b74aSJohannes Doerfert }
46449548b74aSJohannes Doerfert 
4645b2ad63d3SJoseph Huber char OpenMPOptCGSCCLegacyPass::ID = 0;
46469548b74aSJohannes Doerfert 
4647b2ad63d3SJoseph Huber INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
46489548b74aSJohannes Doerfert                       "OpenMP specific optimizations", false, false)
46499548b74aSJohannes Doerfert INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
4650b2ad63d3SJoseph Huber INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
46519548b74aSJohannes Doerfert                     "OpenMP specific optimizations", false, false)
46529548b74aSJohannes Doerfert 
4653b2ad63d3SJoseph Huber Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
4654b2ad63d3SJoseph Huber   return new OpenMPOptCGSCCLegacyPass();
4655b2ad63d3SJoseph Huber }
4656