19548b74aSJohannes Doerfert //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
29548b74aSJohannes Doerfert //
39548b74aSJohannes Doerfert // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49548b74aSJohannes Doerfert // See https://llvm.org/LICENSE.txt for license information.
59548b74aSJohannes Doerfert // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69548b74aSJohannes Doerfert //
79548b74aSJohannes Doerfert //===----------------------------------------------------------------------===//
89548b74aSJohannes Doerfert //
99548b74aSJohannes Doerfert // OpenMP specific optimizations:
109548b74aSJohannes Doerfert //
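// - Deduplication of runtime calls, e.g., omp_get_thread_num.
// - Deletion of OpenMP parallel regions.
// - Replacement of parallel regions with IDs in GPU state machines.
// - [WIP, opt-in] Hiding the latency of host to device memory transfers.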
129548b74aSJohannes Doerfert //
139548b74aSJohannes Doerfert //===----------------------------------------------------------------------===//
149548b74aSJohannes Doerfert 
159548b74aSJohannes Doerfert #include "llvm/Transforms/IPO/OpenMPOpt.h"
169548b74aSJohannes Doerfert 
179548b74aSJohannes Doerfert #include "llvm/ADT/EnumeratedArray.h"
189548b74aSJohannes Doerfert #include "llvm/ADT/Statistic.h"
199548b74aSJohannes Doerfert #include "llvm/Analysis/CallGraph.h"
209548b74aSJohannes Doerfert #include "llvm/Analysis/CallGraphSCCPass.h"
214d4ea9acSHuber, Joseph #include "llvm/Analysis/OptimizationRemarkEmitter.h"
229548b74aSJohannes Doerfert #include "llvm/Frontend/OpenMP/OMPConstants.h"
23e28936f6SJohannes Doerfert #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
249548b74aSJohannes Doerfert #include "llvm/InitializePasses.h"
259548b74aSJohannes Doerfert #include "llvm/Support/CommandLine.h"
269548b74aSJohannes Doerfert #include "llvm/Transforms/IPO.h"
277cfd267cSsstefan1 #include "llvm/Transforms/IPO/Attributor.h"
289548b74aSJohannes Doerfert #include "llvm/Transforms/Utils/CallGraphUpdater.h"
299548b74aSJohannes Doerfert 
309548b74aSJohannes Doerfert using namespace llvm;
319548b74aSJohannes Doerfert using namespace omp;
329548b74aSJohannes Doerfert 
339548b74aSJohannes Doerfert #define DEBUG_TYPE "openmp-opt"
349548b74aSJohannes Doerfert 
359548b74aSJohannes Doerfert static cl::opt<bool> DisableOpenMPOptimizations(
369548b74aSJohannes Doerfert     "openmp-opt-disable", cl::ZeroOrMore,
379548b74aSJohannes Doerfert     cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
389548b74aSJohannes Doerfert     cl::init(false));
399548b74aSJohannes Doerfert 
400f426935Ssstefan1 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
410f426935Ssstefan1                                     cl::Hidden);
42e8039ad4SJohannes Doerfert static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
43e8039ad4SJohannes Doerfert                                         cl::init(false), cl::Hidden);
440f426935Ssstefan1 
45496f8e5bSHamilton Tobon Mosquera static cl::opt<bool> HideMemoryTransferLatency(
46496f8e5bSHamilton Tobon Mosquera     "openmp-hide-memory-transfer-latency",
47496f8e5bSHamilton Tobon Mosquera     cl::desc("[WIP] Tries to hide the latency of host to device memory"
48496f8e5bSHamilton Tobon Mosquera              " transfers"),
49496f8e5bSHamilton Tobon Mosquera     cl::Hidden, cl::init(false));
50496f8e5bSHamilton Tobon Mosquera 
51496f8e5bSHamilton Tobon Mosquera 
529548b74aSJohannes Doerfert STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
539548b74aSJohannes Doerfert           "Number of OpenMP runtime calls deduplicated");
5455eb714aSRoman Lebedev STATISTIC(NumOpenMPParallelRegionsDeleted,
5555eb714aSRoman Lebedev           "Number of OpenMP parallel regions deleted");
569548b74aSJohannes Doerfert STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
579548b74aSJohannes Doerfert           "Number of OpenMP runtime functions identified");
589548b74aSJohannes Doerfert STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
599548b74aSJohannes Doerfert           "Number of OpenMP runtime function uses identified");
60e8039ad4SJohannes Doerfert STATISTIC(NumOpenMPTargetRegionKernels,
61e8039ad4SJohannes Doerfert           "Number of OpenMP target region entry points (=kernels) identified");
625b0581aeSJohannes Doerfert STATISTIC(
635b0581aeSJohannes Doerfert     NumOpenMPParallelRegionsReplacedInGPUStateMachine,
645b0581aeSJohannes Doerfert     "Number of OpenMP parallel regions replaced with ID in GPU state machines");
659548b74aSJohannes Doerfert 
66263c4a3cSrathod-sahaab #if !defined(NDEBUG)
679548b74aSJohannes Doerfert static constexpr auto TAG = "[" DEBUG_TYPE "]";
68a50c0b0dSMikael Holmen #endif
699548b74aSJohannes Doerfert 
70624d34afSJohannes Doerfert /// Apply \p CB to all uses of \p F. If \p LookThroughConstantExprUses is
71624d34afSJohannes Doerfert /// true, constant expression users are not given to \p CB but their uses are
72624d34afSJohannes Doerfert /// traversed transitively.
73624d34afSJohannes Doerfert template <typename CBTy>
74624d34afSJohannes Doerfert static void foreachUse(Function &F, CBTy CB,
75624d34afSJohannes Doerfert                        bool LookThroughConstantExprUses = true) {
76624d34afSJohannes Doerfert   SmallVector<Use *, 8> Worklist(make_pointer_range(F.uses()));
77624d34afSJohannes Doerfert 
78624d34afSJohannes Doerfert   for (unsigned idx = 0; idx < Worklist.size(); ++idx) {
79624d34afSJohannes Doerfert     Use &U = *Worklist[idx];
80624d34afSJohannes Doerfert 
81624d34afSJohannes Doerfert     // Allow use in constant bitcasts and simply look through them.
82624d34afSJohannes Doerfert     if (LookThroughConstantExprUses && isa<ConstantExpr>(U.getUser())) {
83624d34afSJohannes Doerfert       for (Use &CEU : cast<ConstantExpr>(U.getUser())->uses())
84624d34afSJohannes Doerfert         Worklist.push_back(&CEU);
85624d34afSJohannes Doerfert       continue;
86624d34afSJohannes Doerfert     }
87624d34afSJohannes Doerfert 
88624d34afSJohannes Doerfert     CB(U);
89624d34afSJohannes Doerfert   }
90624d34afSJohannes Doerfert }
91624d34afSJohannes Doerfert 
929548b74aSJohannes Doerfert namespace {
939548b74aSJohannes Doerfert 
/// Forward declaration so the type can be named before its definition, which
/// is not part of the code shown here. Judging by the name, presumably the
/// Attributor abstract attribute that tracks ICV values — confirm against the
/// definition.
struct AAICVTracker;
95b8235d2bSsstefan1 
967cfd267cSsstefan1 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
977cfd267cSsstefan1 /// Attributor runs.
987cfd267cSsstefan1 struct OMPInformationCache : public InformationCache {
997cfd267cSsstefan1   OMPInformationCache(Module &M, AnalysisGetter &AG,
100624d34afSJohannes Doerfert                       BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
101e8039ad4SJohannes Doerfert                       SmallPtrSetImpl<Kernel> &Kernels)
102624d34afSJohannes Doerfert       : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
103624d34afSJohannes Doerfert         Kernels(Kernels) {
104624d34afSJohannes Doerfert     initializeModuleSlice(CGSCC);
105624d34afSJohannes Doerfert 
10661238d26Ssstefan1     OMPBuilder.initialize();
1079548b74aSJohannes Doerfert     initializeRuntimeFunctions();
1080f426935Ssstefan1     initializeInternalControlVars();
1099548b74aSJohannes Doerfert   }
1109548b74aSJohannes Doerfert 
1110f426935Ssstefan1   /// Generic information that describes an internal control variable.
1120f426935Ssstefan1   struct InternalControlVarInfo {
1130f426935Ssstefan1     /// The kind, as described by InternalControlVar enum.
1140f426935Ssstefan1     InternalControlVar Kind;
1150f426935Ssstefan1 
1160f426935Ssstefan1     /// The name of the ICV.
1170f426935Ssstefan1     StringRef Name;
1180f426935Ssstefan1 
1190f426935Ssstefan1     /// Environment variable associated with this ICV.
1200f426935Ssstefan1     StringRef EnvVarName;
1210f426935Ssstefan1 
1220f426935Ssstefan1     /// Initial value kind.
1230f426935Ssstefan1     ICVInitValue InitKind;
1240f426935Ssstefan1 
1250f426935Ssstefan1     /// Initial value.
1260f426935Ssstefan1     ConstantInt *InitValue;
1270f426935Ssstefan1 
1280f426935Ssstefan1     /// Setter RTL function associated with this ICV.
1290f426935Ssstefan1     RuntimeFunction Setter;
1300f426935Ssstefan1 
1310f426935Ssstefan1     /// Getter RTL function associated with this ICV.
1320f426935Ssstefan1     RuntimeFunction Getter;
1330f426935Ssstefan1 
1340f426935Ssstefan1     /// RTL Function corresponding to the override clause of this ICV
1350f426935Ssstefan1     RuntimeFunction Clause;
1360f426935Ssstefan1   };
1370f426935Ssstefan1 
1389548b74aSJohannes Doerfert   /// Generic information that describes a runtime function
1399548b74aSJohannes Doerfert   struct RuntimeFunctionInfo {
1408855fec3SJohannes Doerfert 
1419548b74aSJohannes Doerfert     /// The kind, as described by the RuntimeFunction enum.
1429548b74aSJohannes Doerfert     RuntimeFunction Kind;
1439548b74aSJohannes Doerfert 
1449548b74aSJohannes Doerfert     /// The name of the function.
1459548b74aSJohannes Doerfert     StringRef Name;
1469548b74aSJohannes Doerfert 
1479548b74aSJohannes Doerfert     /// Flag to indicate a variadic function.
1489548b74aSJohannes Doerfert     bool IsVarArg;
1499548b74aSJohannes Doerfert 
1509548b74aSJohannes Doerfert     /// The return type of the function.
1519548b74aSJohannes Doerfert     Type *ReturnType;
1529548b74aSJohannes Doerfert 
1539548b74aSJohannes Doerfert     /// The argument types of the function.
1549548b74aSJohannes Doerfert     SmallVector<Type *, 8> ArgumentTypes;
1559548b74aSJohannes Doerfert 
1569548b74aSJohannes Doerfert     /// The declaration if available.
157f09f4b26SJohannes Doerfert     Function *Declaration = nullptr;
1589548b74aSJohannes Doerfert 
1599548b74aSJohannes Doerfert     /// Uses of this runtime function per function containing the use.
1608855fec3SJohannes Doerfert     using UseVector = SmallVector<Use *, 16>;
1618855fec3SJohannes Doerfert 
162b8235d2bSsstefan1     /// Clear UsesMap for runtime function.
163b8235d2bSsstefan1     void clearUsesMap() { UsesMap.clear(); }
164b8235d2bSsstefan1 
16554bd3751SJohannes Doerfert     /// Boolean conversion that is true if the runtime function was found.
16654bd3751SJohannes Doerfert     operator bool() const { return Declaration; }
16754bd3751SJohannes Doerfert 
1688855fec3SJohannes Doerfert     /// Return the vector of uses in function \p F.
1698855fec3SJohannes Doerfert     UseVector &getOrCreateUseVector(Function *F) {
170b8235d2bSsstefan1       std::shared_ptr<UseVector> &UV = UsesMap[F];
1718855fec3SJohannes Doerfert       if (!UV)
172b8235d2bSsstefan1         UV = std::make_shared<UseVector>();
1738855fec3SJohannes Doerfert       return *UV;
1748855fec3SJohannes Doerfert     }
1758855fec3SJohannes Doerfert 
1768855fec3SJohannes Doerfert     /// Return the vector of uses in function \p F or `nullptr` if there are
1778855fec3SJohannes Doerfert     /// none.
1788855fec3SJohannes Doerfert     const UseVector *getUseVector(Function &F) const {
17995e57072SDavid Blaikie       auto I = UsesMap.find(&F);
18095e57072SDavid Blaikie       if (I != UsesMap.end())
18195e57072SDavid Blaikie         return I->second.get();
18295e57072SDavid Blaikie       return nullptr;
1838855fec3SJohannes Doerfert     }
1848855fec3SJohannes Doerfert 
1858855fec3SJohannes Doerfert     /// Return how many functions contain uses of this runtime function.
1868855fec3SJohannes Doerfert     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
1879548b74aSJohannes Doerfert 
1889548b74aSJohannes Doerfert     /// Return the number of arguments (or the minimal number for variadic
1899548b74aSJohannes Doerfert     /// functions).
1909548b74aSJohannes Doerfert     size_t getNumArgs() const { return ArgumentTypes.size(); }
1919548b74aSJohannes Doerfert 
1929548b74aSJohannes Doerfert     /// Run the callback \p CB on each use and forget the use if the result is
1939548b74aSJohannes Doerfert     /// true. The callback will be fed the function in which the use was
1949548b74aSJohannes Doerfert     /// encountered as second argument.
195624d34afSJohannes Doerfert     void foreachUse(SmallVectorImpl<Function *> &SCC,
196624d34afSJohannes Doerfert                     function_ref<bool(Use &, Function &)> CB) {
197624d34afSJohannes Doerfert       for (Function *F : SCC)
198624d34afSJohannes Doerfert         foreachUse(CB, F);
199e099c7b6Ssstefan1     }
200e099c7b6Ssstefan1 
201e099c7b6Ssstefan1     /// Run the callback \p CB on each use within the function \p F and forget
202e099c7b6Ssstefan1     /// the use if the result is true.
203624d34afSJohannes Doerfert     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
2048855fec3SJohannes Doerfert       SmallVector<unsigned, 8> ToBeDeleted;
2059548b74aSJohannes Doerfert       ToBeDeleted.clear();
206e099c7b6Ssstefan1 
2078855fec3SJohannes Doerfert       unsigned Idx = 0;
208624d34afSJohannes Doerfert       UseVector &UV = getOrCreateUseVector(F);
209e099c7b6Ssstefan1 
2108855fec3SJohannes Doerfert       for (Use *U : UV) {
211e099c7b6Ssstefan1         if (CB(*U, *F))
2128855fec3SJohannes Doerfert           ToBeDeleted.push_back(Idx);
2138855fec3SJohannes Doerfert         ++Idx;
2148855fec3SJohannes Doerfert       }
2158855fec3SJohannes Doerfert 
2168855fec3SJohannes Doerfert       // Remove the to-be-deleted indices in reverse order as prior
217b726c557SJohannes Doerfert       // modifications will not modify the smaller indices.
2188855fec3SJohannes Doerfert       while (!ToBeDeleted.empty()) {
2198855fec3SJohannes Doerfert         unsigned Idx = ToBeDeleted.pop_back_val();
2208855fec3SJohannes Doerfert         UV[Idx] = UV.back();
2218855fec3SJohannes Doerfert         UV.pop_back();
2229548b74aSJohannes Doerfert       }
2239548b74aSJohannes Doerfert     }
2248855fec3SJohannes Doerfert 
2258855fec3SJohannes Doerfert   private:
2268855fec3SJohannes Doerfert     /// Map from functions to all uses of this runtime function contained in
2278855fec3SJohannes Doerfert     /// them.
228b8235d2bSsstefan1     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
2299548b74aSJohannes Doerfert   };
2309548b74aSJohannes Doerfert 
231624d34afSJohannes Doerfert   /// Initialize the ModuleSlice member based on \p SCC. ModuleSlices contains
232624d34afSJohannes Doerfert   /// (a subset of) all functions that we can look at during this SCC traversal.
233624d34afSJohannes Doerfert   /// This includes functions (transitively) called from the SCC and the
234624d34afSJohannes Doerfert   /// (transitive) callers of SCC functions. We also can look at a function if
235624d34afSJohannes Doerfert   /// there is a "reference edge", i.a., if the function somehow uses (!=calls)
236624d34afSJohannes Doerfert   /// a function in the SCC or a caller of a function in the SCC.
237624d34afSJohannes Doerfert   void initializeModuleSlice(SetVector<Function *> &SCC) {
238624d34afSJohannes Doerfert     ModuleSlice.insert(SCC.begin(), SCC.end());
239624d34afSJohannes Doerfert 
240624d34afSJohannes Doerfert     SmallPtrSet<Function *, 16> Seen;
241624d34afSJohannes Doerfert     SmallVector<Function *, 16> Worklist(SCC.begin(), SCC.end());
242624d34afSJohannes Doerfert     while (!Worklist.empty()) {
243624d34afSJohannes Doerfert       Function *F = Worklist.pop_back_val();
244624d34afSJohannes Doerfert       ModuleSlice.insert(F);
245624d34afSJohannes Doerfert 
246624d34afSJohannes Doerfert       for (Instruction &I : instructions(*F))
247624d34afSJohannes Doerfert         if (auto *CB = dyn_cast<CallBase>(&I))
248624d34afSJohannes Doerfert           if (Function *Callee = CB->getCalledFunction())
249624d34afSJohannes Doerfert             if (Seen.insert(Callee).second)
250624d34afSJohannes Doerfert               Worklist.push_back(Callee);
251624d34afSJohannes Doerfert     }
252624d34afSJohannes Doerfert 
253624d34afSJohannes Doerfert     Seen.clear();
254624d34afSJohannes Doerfert     Worklist.append(SCC.begin(), SCC.end());
255624d34afSJohannes Doerfert     while (!Worklist.empty()) {
256624d34afSJohannes Doerfert       Function *F = Worklist.pop_back_val();
257624d34afSJohannes Doerfert       ModuleSlice.insert(F);
258624d34afSJohannes Doerfert 
259624d34afSJohannes Doerfert       // Traverse all transitive uses.
260624d34afSJohannes Doerfert       foreachUse(*F, [&](Use &U) {
261624d34afSJohannes Doerfert         if (auto *UsrI = dyn_cast<Instruction>(U.getUser()))
262624d34afSJohannes Doerfert           if (Seen.insert(UsrI->getFunction()).second)
263624d34afSJohannes Doerfert             Worklist.push_back(UsrI->getFunction());
264624d34afSJohannes Doerfert       });
265624d34afSJohannes Doerfert     }
266624d34afSJohannes Doerfert   }
267624d34afSJohannes Doerfert 
2687cfd267cSsstefan1   /// The slice of the module we are allowed to look at.
269624d34afSJohannes Doerfert   SmallPtrSet<Function *, 8> ModuleSlice;
2707cfd267cSsstefan1 
2717cfd267cSsstefan1   /// An OpenMP-IR-Builder instance
2727cfd267cSsstefan1   OpenMPIRBuilder OMPBuilder;
2737cfd267cSsstefan1 
2747cfd267cSsstefan1   /// Map from runtime function kind to the runtime function description.
2757cfd267cSsstefan1   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
2767cfd267cSsstefan1                   RuntimeFunction::OMPRTL___last>
2777cfd267cSsstefan1       RFIs;
2787cfd267cSsstefan1 
2790f426935Ssstefan1   /// Map from ICV kind to the ICV description.
2800f426935Ssstefan1   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
2810f426935Ssstefan1                   InternalControlVar::ICV___last>
2820f426935Ssstefan1       ICVs;
2830f426935Ssstefan1 
2840f426935Ssstefan1   /// Helper to initialize all internal control variable information for those
2850f426935Ssstefan1   /// defined in OMPKinds.def.
2860f426935Ssstefan1   void initializeInternalControlVars() {
2870f426935Ssstefan1 #define ICV_RT_SET(_Name, RTL)                                                 \
2880f426935Ssstefan1   {                                                                            \
2890f426935Ssstefan1     auto &ICV = ICVs[_Name];                                                   \
2900f426935Ssstefan1     ICV.Setter = RTL;                                                          \
2910f426935Ssstefan1   }
2920f426935Ssstefan1 #define ICV_RT_GET(Name, RTL)                                                  \
2930f426935Ssstefan1   {                                                                            \
2940f426935Ssstefan1     auto &ICV = ICVs[Name];                                                    \
2950f426935Ssstefan1     ICV.Getter = RTL;                                                          \
2960f426935Ssstefan1   }
2970f426935Ssstefan1 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
2980f426935Ssstefan1   {                                                                            \
2990f426935Ssstefan1     auto &ICV = ICVs[Enum];                                                    \
3000f426935Ssstefan1     ICV.Name = _Name;                                                          \
3010f426935Ssstefan1     ICV.Kind = Enum;                                                           \
3020f426935Ssstefan1     ICV.InitKind = Init;                                                       \
3030f426935Ssstefan1     ICV.EnvVarName = _EnvVarName;                                              \
3040f426935Ssstefan1     switch (ICV.InitKind) {                                                    \
305951e43f3Ssstefan1     case ICV_IMPLEMENTATION_DEFINED:                                           \
3060f426935Ssstefan1       ICV.InitValue = nullptr;                                                 \
3070f426935Ssstefan1       break;                                                                   \
308951e43f3Ssstefan1     case ICV_ZERO:                                                             \
3096aab27baSsstefan1       ICV.InitValue = ConstantInt::get(                                        \
3106aab27baSsstefan1           Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
3110f426935Ssstefan1       break;                                                                   \
312951e43f3Ssstefan1     case ICV_FALSE:                                                            \
3136aab27baSsstefan1       ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
3140f426935Ssstefan1       break;                                                                   \
315951e43f3Ssstefan1     case ICV_LAST:                                                             \
3160f426935Ssstefan1       break;                                                                   \
3170f426935Ssstefan1     }                                                                          \
3180f426935Ssstefan1   }
3190f426935Ssstefan1 #include "llvm/Frontend/OpenMP/OMPKinds.def"
3200f426935Ssstefan1   }
3210f426935Ssstefan1 
3227cfd267cSsstefan1   /// Returns true if the function declaration \p F matches the runtime
3237cfd267cSsstefan1   /// function types, that is, return type \p RTFRetType, and argument types
3247cfd267cSsstefan1   /// \p RTFArgTypes.
3257cfd267cSsstefan1   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
3267cfd267cSsstefan1                                   SmallVector<Type *, 8> &RTFArgTypes) {
3277cfd267cSsstefan1     // TODO: We should output information to the user (under debug output
3287cfd267cSsstefan1     //       and via remarks).
3297cfd267cSsstefan1 
3307cfd267cSsstefan1     if (!F)
3317cfd267cSsstefan1       return false;
3327cfd267cSsstefan1     if (F->getReturnType() != RTFRetType)
3337cfd267cSsstefan1       return false;
3347cfd267cSsstefan1     if (F->arg_size() != RTFArgTypes.size())
3357cfd267cSsstefan1       return false;
3367cfd267cSsstefan1 
3377cfd267cSsstefan1     auto RTFTyIt = RTFArgTypes.begin();
3387cfd267cSsstefan1     for (Argument &Arg : F->args()) {
3397cfd267cSsstefan1       if (Arg.getType() != *RTFTyIt)
3407cfd267cSsstefan1         return false;
3417cfd267cSsstefan1 
3427cfd267cSsstefan1       ++RTFTyIt;
3437cfd267cSsstefan1     }
3447cfd267cSsstefan1 
3457cfd267cSsstefan1     return true;
3467cfd267cSsstefan1   }
3477cfd267cSsstefan1 
348b726c557SJohannes Doerfert   // Helper to collect all uses of the declaration in the UsesMap.
349b8235d2bSsstefan1   unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
3507cfd267cSsstefan1     unsigned NumUses = 0;
3517cfd267cSsstefan1     if (!RFI.Declaration)
3527cfd267cSsstefan1       return NumUses;
3537cfd267cSsstefan1     OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
3547cfd267cSsstefan1 
355b8235d2bSsstefan1     if (CollectStats) {
3567cfd267cSsstefan1       NumOpenMPRuntimeFunctionsIdentified += 1;
3577cfd267cSsstefan1       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
358b8235d2bSsstefan1     }
3597cfd267cSsstefan1 
3607cfd267cSsstefan1     // TODO: We directly convert uses into proper calls and unknown uses.
3617cfd267cSsstefan1     for (Use &U : RFI.Declaration->uses()) {
3627cfd267cSsstefan1       if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
3637cfd267cSsstefan1         if (ModuleSlice.count(UserI->getFunction())) {
3647cfd267cSsstefan1           RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
3657cfd267cSsstefan1           ++NumUses;
3667cfd267cSsstefan1         }
3677cfd267cSsstefan1       } else {
3687cfd267cSsstefan1         RFI.getOrCreateUseVector(nullptr).push_back(&U);
3697cfd267cSsstefan1         ++NumUses;
3707cfd267cSsstefan1       }
3717cfd267cSsstefan1     }
3727cfd267cSsstefan1     return NumUses;
373b8235d2bSsstefan1   }
3747cfd267cSsstefan1 
375b8235d2bSsstefan1   // Helper function to recollect uses of all runtime functions.
376b8235d2bSsstefan1   void recollectUses() {
377b8235d2bSsstefan1     for (int Idx = 0; Idx < RFIs.size(); ++Idx) {
378b8235d2bSsstefan1       auto &RFI = RFIs[static_cast<RuntimeFunction>(Idx)];
379b8235d2bSsstefan1       RFI.clearUsesMap();
380b8235d2bSsstefan1       collectUses(RFI, /*CollectStats*/ false);
381b8235d2bSsstefan1     }
382b8235d2bSsstefan1   }
383b8235d2bSsstefan1 
384b8235d2bSsstefan1   /// Helper to initialize all runtime function information for those defined
385b8235d2bSsstefan1   /// in OpenMPKinds.def.
386b8235d2bSsstefan1   void initializeRuntimeFunctions() {
3877cfd267cSsstefan1     Module &M = *((*ModuleSlice.begin())->getParent());
3887cfd267cSsstefan1 
3896aab27baSsstefan1     // Helper macros for handling __VA_ARGS__ in OMP_RTL
3906aab27baSsstefan1 #define OMP_TYPE(VarName, ...)                                                 \
3916aab27baSsstefan1   Type *VarName = OMPBuilder.VarName;                                          \
3926aab27baSsstefan1   (void)VarName;
3936aab27baSsstefan1 
3946aab27baSsstefan1 #define OMP_ARRAY_TYPE(VarName, ...)                                           \
3956aab27baSsstefan1   ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
3966aab27baSsstefan1   (void)VarName##Ty;                                                           \
3976aab27baSsstefan1   PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
3986aab27baSsstefan1   (void)VarName##PtrTy;
3996aab27baSsstefan1 
4006aab27baSsstefan1 #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
4016aab27baSsstefan1   FunctionType *VarName = OMPBuilder.VarName;                                  \
4026aab27baSsstefan1   (void)VarName;                                                               \
4036aab27baSsstefan1   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
4046aab27baSsstefan1   (void)VarName##Ptr;
4056aab27baSsstefan1 
4066aab27baSsstefan1 #define OMP_STRUCT_TYPE(VarName, ...)                                          \
4076aab27baSsstefan1   StructType *VarName = OMPBuilder.VarName;                                    \
4086aab27baSsstefan1   (void)VarName;                                                               \
4096aab27baSsstefan1   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
4106aab27baSsstefan1   (void)VarName##Ptr;
4116aab27baSsstefan1 
4127cfd267cSsstefan1 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
4137cfd267cSsstefan1   {                                                                            \
4147cfd267cSsstefan1     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
4157cfd267cSsstefan1     Function *F = M.getFunction(_Name);                                        \
4166aab27baSsstefan1     if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
4177cfd267cSsstefan1       auto &RFI = RFIs[_Enum];                                                 \
4187cfd267cSsstefan1       RFI.Kind = _Enum;                                                        \
4197cfd267cSsstefan1       RFI.Name = _Name;                                                        \
4207cfd267cSsstefan1       RFI.IsVarArg = _IsVarArg;                                                \
4216aab27baSsstefan1       RFI.ReturnType = OMPBuilder._ReturnType;                                 \
4227cfd267cSsstefan1       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
4237cfd267cSsstefan1       RFI.Declaration = F;                                                     \
424b8235d2bSsstefan1       unsigned NumUses = collectUses(RFI);                                     \
4257cfd267cSsstefan1       (void)NumUses;                                                           \
4267cfd267cSsstefan1       LLVM_DEBUG({                                                             \
4277cfd267cSsstefan1         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
4287cfd267cSsstefan1                << " found\n";                                                  \
4297cfd267cSsstefan1         if (RFI.Declaration)                                                   \
4307cfd267cSsstefan1           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
4317cfd267cSsstefan1                  << RFI.getNumFunctionsWithUses()                              \
4327cfd267cSsstefan1                  << " different functions.\n";                                 \
4337cfd267cSsstefan1       });                                                                      \
4347cfd267cSsstefan1     }                                                                          \
4357cfd267cSsstefan1   }
4367cfd267cSsstefan1 #include "llvm/Frontend/OpenMP/OMPKinds.def"
4377cfd267cSsstefan1 
4387cfd267cSsstefan1     // TODO: We should attach the attributes defined in OMPKinds.def.
4397cfd267cSsstefan1   }
440e8039ad4SJohannes Doerfert 
441e8039ad4SJohannes Doerfert   /// Collection of known kernels (\see Kernel) in the module.
442e8039ad4SJohannes Doerfert   SmallPtrSetImpl<Kernel> &Kernels;
4437cfd267cSsstefan1 };
4447cfd267cSsstefan1 
4457cfd267cSsstefan1 struct OpenMPOpt {
4467cfd267cSsstefan1 
4477cfd267cSsstefan1   using OptimizationRemarkGetter =
4487cfd267cSsstefan1       function_ref<OptimizationRemarkEmitter &(Function *)>;
4497cfd267cSsstefan1 
4507cfd267cSsstefan1   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
4517cfd267cSsstefan1             OptimizationRemarkGetter OREGetter,
452b8235d2bSsstefan1             OMPInformationCache &OMPInfoCache, Attributor &A)
45377b79d79SMehdi Amini       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
454b8235d2bSsstefan1         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
4557cfd267cSsstefan1 
4569548b74aSJohannes Doerfert   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
4579548b74aSJohannes Doerfert   bool run() {
45854bd3751SJohannes Doerfert     if (SCC.empty())
45954bd3751SJohannes Doerfert       return false;
46054bd3751SJohannes Doerfert 
4619548b74aSJohannes Doerfert     bool Changed = false;
4629548b74aSJohannes Doerfert 
4639548b74aSJohannes Doerfert     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
46477b79d79SMehdi Amini                       << " functions in a slice with "
46577b79d79SMehdi Amini                       << OMPInfoCache.ModuleSlice.size() << " functions\n");
4669548b74aSJohannes Doerfert 
467e8039ad4SJohannes Doerfert     if (PrintICVValues)
468e8039ad4SJohannes Doerfert       printICVs();
469e8039ad4SJohannes Doerfert     if (PrintOpenMPKernels)
470e8039ad4SJohannes Doerfert       printKernels();
471e8039ad4SJohannes Doerfert 
4725b0581aeSJohannes Doerfert     Changed |= rewriteDeviceCodeStateMachine();
4735b0581aeSJohannes Doerfert 
474e8039ad4SJohannes Doerfert     Changed |= runAttributor();
475e8039ad4SJohannes Doerfert 
476e8039ad4SJohannes Doerfert     // Recollect uses, in case Attributor deleted any.
477e8039ad4SJohannes Doerfert     OMPInfoCache.recollectUses();
478e8039ad4SJohannes Doerfert 
479e8039ad4SJohannes Doerfert     Changed |= deduplicateRuntimeCalls();
480e8039ad4SJohannes Doerfert     Changed |= deleteParallelRegions();
481496f8e5bSHamilton Tobon Mosquera     if (HideMemoryTransferLatency)
482496f8e5bSHamilton Tobon Mosquera       Changed |= hideMemTransfersLatency();
483e8039ad4SJohannes Doerfert 
484e8039ad4SJohannes Doerfert     return Changed;
485e8039ad4SJohannes Doerfert   }
486e8039ad4SJohannes Doerfert 
4870f426935Ssstefan1   /// Print initial ICV values for testing.
4880f426935Ssstefan1   /// FIXME: This should be done from the Attributor once it is added.
489e8039ad4SJohannes Doerfert   void printICVs() const {
4900f426935Ssstefan1     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel};
4910f426935Ssstefan1 
4920f426935Ssstefan1     for (Function *F : OMPInfoCache.ModuleSlice) {
4930f426935Ssstefan1       for (auto ICV : ICVs) {
4940f426935Ssstefan1         auto ICVInfo = OMPInfoCache.ICVs[ICV];
4950f426935Ssstefan1         auto Remark = [&](OptimizationRemark OR) {
4960f426935Ssstefan1           return OR << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
4970f426935Ssstefan1                     << " Value: "
4980f426935Ssstefan1                     << (ICVInfo.InitValue
4990f426935Ssstefan1                             ? ICVInfo.InitValue->getValue().toString(10, true)
5000f426935Ssstefan1                             : "IMPLEMENTATION_DEFINED");
5010f426935Ssstefan1         };
5020f426935Ssstefan1 
5030f426935Ssstefan1         emitRemarkOnFunction(F, "OpenMPICVTracker", Remark);
5040f426935Ssstefan1       }
5050f426935Ssstefan1     }
5060f426935Ssstefan1   }
5070f426935Ssstefan1 
508e8039ad4SJohannes Doerfert   /// Print OpenMP GPU kernels for testing.
509e8039ad4SJohannes Doerfert   void printKernels() const {
510e8039ad4SJohannes Doerfert     for (Function *F : SCC) {
511e8039ad4SJohannes Doerfert       if (!OMPInfoCache.Kernels.count(F))
512e8039ad4SJohannes Doerfert         continue;
513b8235d2bSsstefan1 
514e8039ad4SJohannes Doerfert       auto Remark = [&](OptimizationRemark OR) {
515e8039ad4SJohannes Doerfert         return OR << "OpenMP GPU kernel "
516e8039ad4SJohannes Doerfert                   << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
517e8039ad4SJohannes Doerfert       };
518b8235d2bSsstefan1 
519e8039ad4SJohannes Doerfert       emitRemarkOnFunction(F, "OpenMPGPU", Remark);
520e8039ad4SJohannes Doerfert     }
5219548b74aSJohannes Doerfert   }
5229548b74aSJohannes Doerfert 
5237cfd267cSsstefan1   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
5247cfd267cSsstefan1   /// given it has to be the callee or a nullptr is returned.
5257cfd267cSsstefan1   static CallInst *getCallIfRegularCall(
5267cfd267cSsstefan1       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
5277cfd267cSsstefan1     CallInst *CI = dyn_cast<CallInst>(U.getUser());
5287cfd267cSsstefan1     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
5297cfd267cSsstefan1         (!RFI || CI->getCalledFunction() == RFI->Declaration))
5307cfd267cSsstefan1       return CI;
5317cfd267cSsstefan1     return nullptr;
5327cfd267cSsstefan1   }
5337cfd267cSsstefan1 
5347cfd267cSsstefan1   /// Return the call if \p V is a regular call. If \p RFI is given it has to be
5357cfd267cSsstefan1   /// the callee or a nullptr is returned.
5367cfd267cSsstefan1   static CallInst *getCallIfRegularCall(
5377cfd267cSsstefan1       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
5387cfd267cSsstefan1     CallInst *CI = dyn_cast<CallInst>(&V);
5397cfd267cSsstefan1     if (CI && !CI->hasOperandBundles() &&
5407cfd267cSsstefan1         (!RFI || CI->getCalledFunction() == RFI->Declaration))
5417cfd267cSsstefan1       return CI;
5427cfd267cSsstefan1     return nullptr;
5437cfd267cSsstefan1   }
5447cfd267cSsstefan1 
5459548b74aSJohannes Doerfert private:
5469d38f98dSJohannes Doerfert   /// Try to delete parallel regions if possible.
547e565db49SJohannes Doerfert   bool deleteParallelRegions() {
548e565db49SJohannes Doerfert     const unsigned CallbackCalleeOperand = 2;
549e565db49SJohannes Doerfert 
5507cfd267cSsstefan1     OMPInformationCache::RuntimeFunctionInfo &RFI =
5517cfd267cSsstefan1         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
5527cfd267cSsstefan1 
553e565db49SJohannes Doerfert     if (!RFI.Declaration)
554e565db49SJohannes Doerfert       return false;
555e565db49SJohannes Doerfert 
556e565db49SJohannes Doerfert     bool Changed = false;
557e565db49SJohannes Doerfert     auto DeleteCallCB = [&](Use &U, Function &) {
558e565db49SJohannes Doerfert       CallInst *CI = getCallIfRegularCall(U);
559e565db49SJohannes Doerfert       if (!CI)
560e565db49SJohannes Doerfert         return false;
561e565db49SJohannes Doerfert       auto *Fn = dyn_cast<Function>(
562e565db49SJohannes Doerfert           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
563e565db49SJohannes Doerfert       if (!Fn)
564e565db49SJohannes Doerfert         return false;
565e565db49SJohannes Doerfert       if (!Fn->onlyReadsMemory())
566e565db49SJohannes Doerfert         return false;
567e565db49SJohannes Doerfert       if (!Fn->hasFnAttribute(Attribute::WillReturn))
568e565db49SJohannes Doerfert         return false;
569e565db49SJohannes Doerfert 
570e565db49SJohannes Doerfert       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
571e565db49SJohannes Doerfert                         << CI->getCaller()->getName() << "\n");
5724d4ea9acSHuber, Joseph 
5734d4ea9acSHuber, Joseph       auto Remark = [&](OptimizationRemark OR) {
5744d4ea9acSHuber, Joseph         return OR << "Parallel region in "
5754d4ea9acSHuber, Joseph                   << ore::NV("OpenMPParallelDelete", CI->getCaller()->getName())
5764d4ea9acSHuber, Joseph                   << " deleted";
5774d4ea9acSHuber, Joseph       };
5784d4ea9acSHuber, Joseph       emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionDeletion",
5794d4ea9acSHuber, Joseph                                      Remark);
5804d4ea9acSHuber, Joseph 
581e565db49SJohannes Doerfert       CGUpdater.removeCallSite(*CI);
582e565db49SJohannes Doerfert       CI->eraseFromParent();
583e565db49SJohannes Doerfert       Changed = true;
58455eb714aSRoman Lebedev       ++NumOpenMPParallelRegionsDeleted;
585e565db49SJohannes Doerfert       return true;
586e565db49SJohannes Doerfert     };
587e565db49SJohannes Doerfert 
588624d34afSJohannes Doerfert     RFI.foreachUse(SCC, DeleteCallCB);
589e565db49SJohannes Doerfert 
590e565db49SJohannes Doerfert     return Changed;
591e565db49SJohannes Doerfert   }
592e565db49SJohannes Doerfert 
593b726c557SJohannes Doerfert   /// Try to eliminate runtime calls by reusing existing ones.
5949548b74aSJohannes Doerfert   bool deduplicateRuntimeCalls() {
5959548b74aSJohannes Doerfert     bool Changed = false;
5969548b74aSJohannes Doerfert 
597e28936f6SJohannes Doerfert     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
598e28936f6SJohannes Doerfert         OMPRTL_omp_get_num_threads,
599e28936f6SJohannes Doerfert         OMPRTL_omp_in_parallel,
600e28936f6SJohannes Doerfert         OMPRTL_omp_get_cancellation,
601e28936f6SJohannes Doerfert         OMPRTL_omp_get_thread_limit,
602e28936f6SJohannes Doerfert         OMPRTL_omp_get_supported_active_levels,
603e28936f6SJohannes Doerfert         OMPRTL_omp_get_level,
604e28936f6SJohannes Doerfert         OMPRTL_omp_get_ancestor_thread_num,
605e28936f6SJohannes Doerfert         OMPRTL_omp_get_team_size,
606e28936f6SJohannes Doerfert         OMPRTL_omp_get_active_level,
607e28936f6SJohannes Doerfert         OMPRTL_omp_in_final,
608e28936f6SJohannes Doerfert         OMPRTL_omp_get_proc_bind,
609e28936f6SJohannes Doerfert         OMPRTL_omp_get_num_places,
610e28936f6SJohannes Doerfert         OMPRTL_omp_get_num_procs,
611e28936f6SJohannes Doerfert         OMPRTL_omp_get_place_num,
612e28936f6SJohannes Doerfert         OMPRTL_omp_get_partition_num_places,
613e28936f6SJohannes Doerfert         OMPRTL_omp_get_partition_place_nums};
614e28936f6SJohannes Doerfert 
615bc93c2d7SMarek Kurdej     // Global-tid is handled separately.
6169548b74aSJohannes Doerfert     SmallSetVector<Value *, 16> GTIdArgs;
6179548b74aSJohannes Doerfert     collectGlobalThreadIdArguments(GTIdArgs);
6189548b74aSJohannes Doerfert     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
6199548b74aSJohannes Doerfert                       << " global thread ID arguments\n");
6209548b74aSJohannes Doerfert 
6219548b74aSJohannes Doerfert     for (Function *F : SCC) {
622e28936f6SJohannes Doerfert       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
6237cfd267cSsstefan1         deduplicateRuntimeCalls(*F,
6247cfd267cSsstefan1                                 OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
625e28936f6SJohannes Doerfert 
626e28936f6SJohannes Doerfert       // __kmpc_global_thread_num is special as we can replace it with an
627e28936f6SJohannes Doerfert       // argument in enough cases to make it worth trying.
6289548b74aSJohannes Doerfert       Value *GTIdArg = nullptr;
6299548b74aSJohannes Doerfert       for (Argument &Arg : F->args())
6309548b74aSJohannes Doerfert         if (GTIdArgs.count(&Arg)) {
6319548b74aSJohannes Doerfert           GTIdArg = &Arg;
6329548b74aSJohannes Doerfert           break;
6339548b74aSJohannes Doerfert         }
6349548b74aSJohannes Doerfert       Changed |= deduplicateRuntimeCalls(
6357cfd267cSsstefan1           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
6369548b74aSJohannes Doerfert     }
6379548b74aSJohannes Doerfert 
6389548b74aSJohannes Doerfert     return Changed;
6399548b74aSJohannes Doerfert   }
6409548b74aSJohannes Doerfert 
641496f8e5bSHamilton Tobon Mosquera   /// Tries to hide the latency of runtime calls that involve host to
642496f8e5bSHamilton Tobon Mosquera   /// device memory transfers by splitting them into their "issue" and "wait"
643496f8e5bSHamilton Tobon Mosquera   /// versions. The "issue" is moved upwards as much as possible. The "wait" is
644496f8e5bSHamilton Tobon Mosquera   /// moved downards as much as possible. The "issue" issues the memory transfer
645496f8e5bSHamilton Tobon Mosquera   /// asynchronously, returning a handle. The "wait" waits in the returned
646496f8e5bSHamilton Tobon Mosquera   /// handle for the memory transfer to finish.
647496f8e5bSHamilton Tobon Mosquera   bool hideMemTransfersLatency() {
648496f8e5bSHamilton Tobon Mosquera     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
649496f8e5bSHamilton Tobon Mosquera     bool Changed = false;
650496f8e5bSHamilton Tobon Mosquera     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
651496f8e5bSHamilton Tobon Mosquera       auto *RTCall = getCallIfRegularCall(U, &RFI);
652496f8e5bSHamilton Tobon Mosquera       if (!RTCall)
653496f8e5bSHamilton Tobon Mosquera         return false;
654496f8e5bSHamilton Tobon Mosquera 
655*bd2fa181SHamilton Tobon Mosquera       // TODO: Check if can be moved upwards.
656*bd2fa181SHamilton Tobon Mosquera       bool WasSplit = false;
657*bd2fa181SHamilton Tobon Mosquera       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
658*bd2fa181SHamilton Tobon Mosquera       if (WaitMovementPoint)
659*bd2fa181SHamilton Tobon Mosquera         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
660*bd2fa181SHamilton Tobon Mosquera 
661496f8e5bSHamilton Tobon Mosquera       Changed |= WasSplit;
662496f8e5bSHamilton Tobon Mosquera       return WasSplit;
663496f8e5bSHamilton Tobon Mosquera     };
664496f8e5bSHamilton Tobon Mosquera     RFI.foreachUse(SCC, SplitMemTransfers);
665496f8e5bSHamilton Tobon Mosquera 
666496f8e5bSHamilton Tobon Mosquera     return Changed;
667496f8e5bSHamilton Tobon Mosquera   }
668496f8e5bSHamilton Tobon Mosquera 
669*bd2fa181SHamilton Tobon Mosquera   /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
670*bd2fa181SHamilton Tobon Mosquera   /// moved. Returns nullptr if the movement is not possible, or not worth it.
671*bd2fa181SHamilton Tobon Mosquera   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
672*bd2fa181SHamilton Tobon Mosquera     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
673*bd2fa181SHamilton Tobon Mosquera     //  Make it traverse the CFG.
674*bd2fa181SHamilton Tobon Mosquera 
675*bd2fa181SHamilton Tobon Mosquera     Instruction *CurrentI = &RuntimeCall;
676*bd2fa181SHamilton Tobon Mosquera     bool IsWorthIt = false;
677*bd2fa181SHamilton Tobon Mosquera     while ((CurrentI = CurrentI->getNextNode())) {
678*bd2fa181SHamilton Tobon Mosquera 
679*bd2fa181SHamilton Tobon Mosquera       // TODO: Once we detect the regions to be offloaded we should use the
680*bd2fa181SHamilton Tobon Mosquera       //  alias analysis manager to check if CurrentI may modify one of
681*bd2fa181SHamilton Tobon Mosquera       //  the offloaded regions.
682*bd2fa181SHamilton Tobon Mosquera       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
683*bd2fa181SHamilton Tobon Mosquera         if (IsWorthIt)
684*bd2fa181SHamilton Tobon Mosquera           return CurrentI;
685*bd2fa181SHamilton Tobon Mosquera 
686*bd2fa181SHamilton Tobon Mosquera         return nullptr;
687*bd2fa181SHamilton Tobon Mosquera       }
688*bd2fa181SHamilton Tobon Mosquera 
689*bd2fa181SHamilton Tobon Mosquera       // FIXME: For now if we move it over anything without side effect
690*bd2fa181SHamilton Tobon Mosquera       //  is worth it.
691*bd2fa181SHamilton Tobon Mosquera       IsWorthIt = true;
692*bd2fa181SHamilton Tobon Mosquera     }
693*bd2fa181SHamilton Tobon Mosquera 
694*bd2fa181SHamilton Tobon Mosquera     // Return end of BasicBlock.
695*bd2fa181SHamilton Tobon Mosquera     return RuntimeCall.getParent()->getTerminator();
696*bd2fa181SHamilton Tobon Mosquera   }
697*bd2fa181SHamilton Tobon Mosquera 
698496f8e5bSHamilton Tobon Mosquera   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
699*bd2fa181SHamilton Tobon Mosquera   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
700*bd2fa181SHamilton Tobon Mosquera                                Instruction &WaitMovementPoint) {
701496f8e5bSHamilton Tobon Mosquera     auto &IRBuilder = OMPInfoCache.OMPBuilder;
702496f8e5bSHamilton Tobon Mosquera     // Add "issue" runtime call declaration:
703496f8e5bSHamilton Tobon Mosquera     // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
704496f8e5bSHamilton Tobon Mosquera     //   i8**, i8**, i64*, i64*)
705496f8e5bSHamilton Tobon Mosquera     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
706496f8e5bSHamilton Tobon Mosquera         M, OMPRTL___tgt_target_data_begin_mapper_issue);
707496f8e5bSHamilton Tobon Mosquera 
708496f8e5bSHamilton Tobon Mosquera     // Change RuntimeCall call site for its asynchronous version.
709496f8e5bSHamilton Tobon Mosquera     SmallVector<Value *, 8> Args;
710*bd2fa181SHamilton Tobon Mosquera     for (auto &Arg : RuntimeCall.args())
711496f8e5bSHamilton Tobon Mosquera       Args.push_back(Arg.get());
712496f8e5bSHamilton Tobon Mosquera 
713496f8e5bSHamilton Tobon Mosquera     CallInst *IssueCallsite =
714*bd2fa181SHamilton Tobon Mosquera         CallInst::Create(IssueDecl, Args, "handle", &RuntimeCall);
715*bd2fa181SHamilton Tobon Mosquera     RuntimeCall.eraseFromParent();
716496f8e5bSHamilton Tobon Mosquera 
717496f8e5bSHamilton Tobon Mosquera     // Add "wait" runtime call declaration:
718496f8e5bSHamilton Tobon Mosquera     // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
719496f8e5bSHamilton Tobon Mosquera     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
720496f8e5bSHamilton Tobon Mosquera         M, OMPRTL___tgt_target_data_begin_mapper_wait);
721496f8e5bSHamilton Tobon Mosquera 
722496f8e5bSHamilton Tobon Mosquera     // Add call site to WaitDecl.
723496f8e5bSHamilton Tobon Mosquera     Value *WaitParams[2] = {
724496f8e5bSHamilton Tobon Mosquera         IssueCallsite->getArgOperand(0), // device_id.
725496f8e5bSHamilton Tobon Mosquera         IssueCallsite // returned handle.
726496f8e5bSHamilton Tobon Mosquera     };
727*bd2fa181SHamilton Tobon Mosquera     CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
728496f8e5bSHamilton Tobon Mosquera 
729496f8e5bSHamilton Tobon Mosquera     return true;
730496f8e5bSHamilton Tobon Mosquera   }
731496f8e5bSHamilton Tobon Mosquera 
732dc3b5b00SJohannes Doerfert   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
733dc3b5b00SJohannes Doerfert                                     bool GlobalOnly, bool &SingleChoice) {
734dc3b5b00SJohannes Doerfert     if (CurrentIdent == NextIdent)
735dc3b5b00SJohannes Doerfert       return CurrentIdent;
736dc3b5b00SJohannes Doerfert 
737396b7253SJohannes Doerfert     // TODO: Figure out how to actually combine multiple debug locations. For
738dc3b5b00SJohannes Doerfert     //       now we just keep an existing one if there is a single choice.
739dc3b5b00SJohannes Doerfert     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
740dc3b5b00SJohannes Doerfert       SingleChoice = !CurrentIdent;
741dc3b5b00SJohannes Doerfert       return NextIdent;
742dc3b5b00SJohannes Doerfert     }
743396b7253SJohannes Doerfert     return nullptr;
744396b7253SJohannes Doerfert   }
745396b7253SJohannes Doerfert 
746396b7253SJohannes Doerfert   /// Return an `struct ident_t*` value that represents the ones used in the
747396b7253SJohannes Doerfert   /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
748396b7253SJohannes Doerfert   /// return a local `struct ident_t*`. For now, if we cannot find a suitable
749396b7253SJohannes Doerfert   /// return value we create one from scratch. We also do not yet combine
750396b7253SJohannes Doerfert   /// information, e.g., the source locations, see combinedIdentStruct.
7517cfd267cSsstefan1   Value *
7527cfd267cSsstefan1   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
7537cfd267cSsstefan1                                  Function &F, bool GlobalOnly) {
754dc3b5b00SJohannes Doerfert     bool SingleChoice = true;
755396b7253SJohannes Doerfert     Value *Ident = nullptr;
756396b7253SJohannes Doerfert     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
757396b7253SJohannes Doerfert       CallInst *CI = getCallIfRegularCall(U, &RFI);
758396b7253SJohannes Doerfert       if (!CI || &F != &Caller)
759396b7253SJohannes Doerfert         return false;
760396b7253SJohannes Doerfert       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
761dc3b5b00SJohannes Doerfert                                   /* GlobalOnly */ true, SingleChoice);
762396b7253SJohannes Doerfert       return false;
763396b7253SJohannes Doerfert     };
764624d34afSJohannes Doerfert     RFI.foreachUse(SCC, CombineIdentStruct);
765396b7253SJohannes Doerfert 
766dc3b5b00SJohannes Doerfert     if (!Ident || !SingleChoice) {
767396b7253SJohannes Doerfert       // The IRBuilder uses the insertion block to get to the module, this is
768396b7253SJohannes Doerfert       // unfortunate but we work around it for now.
7697cfd267cSsstefan1       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
7707cfd267cSsstefan1         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
771396b7253SJohannes Doerfert             &F.getEntryBlock(), F.getEntryBlock().begin()));
772396b7253SJohannes Doerfert       // Create a fallback location if non was found.
773396b7253SJohannes Doerfert       // TODO: Use the debug locations of the calls instead.
7747cfd267cSsstefan1       Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
7757cfd267cSsstefan1       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
776396b7253SJohannes Doerfert     }
777396b7253SJohannes Doerfert     return Ident;
778396b7253SJohannes Doerfert   }
779396b7253SJohannes Doerfert 
780b726c557SJohannes Doerfert   /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
7819548b74aSJohannes Doerfert   /// \p ReplVal if given.
7827cfd267cSsstefan1   bool deduplicateRuntimeCalls(Function &F,
7837cfd267cSsstefan1                                OMPInformationCache::RuntimeFunctionInfo &RFI,
7849548b74aSJohannes Doerfert                                Value *ReplVal = nullptr) {
7858855fec3SJohannes Doerfert     auto *UV = RFI.getUseVector(F);
7868855fec3SJohannes Doerfert     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
787b1fbf438SRoman Lebedev       return false;
788b1fbf438SRoman Lebedev 
7897cfd267cSsstefan1     LLVM_DEBUG(
7907cfd267cSsstefan1         dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
7917cfd267cSsstefan1                << (ReplVal ? " with an existing value\n" : "\n") << "\n");
7927cfd267cSsstefan1 
793ab3da5ddSMichael Liao     assert((!ReplVal || (isa<Argument>(ReplVal) &&
794ab3da5ddSMichael Liao                          cast<Argument>(ReplVal)->getParent() == &F)) &&
7959548b74aSJohannes Doerfert            "Unexpected replacement value!");
796396b7253SJohannes Doerfert 
797396b7253SJohannes Doerfert     // TODO: Use dominance to find a good position instead.
7986aab27baSsstefan1     auto CanBeMoved = [this](CallBase &CB) {
799396b7253SJohannes Doerfert       unsigned NumArgs = CB.getNumArgOperands();
800396b7253SJohannes Doerfert       if (NumArgs == 0)
801396b7253SJohannes Doerfert         return true;
8026aab27baSsstefan1       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
803396b7253SJohannes Doerfert         return false;
804396b7253SJohannes Doerfert       for (unsigned u = 1; u < NumArgs; ++u)
805396b7253SJohannes Doerfert         if (isa<Instruction>(CB.getArgOperand(u)))
806396b7253SJohannes Doerfert           return false;
807396b7253SJohannes Doerfert       return true;
808396b7253SJohannes Doerfert     };
809396b7253SJohannes Doerfert 
8109548b74aSJohannes Doerfert     if (!ReplVal) {
8118855fec3SJohannes Doerfert       for (Use *U : *UV)
8129548b74aSJohannes Doerfert         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
813396b7253SJohannes Doerfert           if (!CanBeMoved(*CI))
814396b7253SJohannes Doerfert             continue;
8154d4ea9acSHuber, Joseph 
8164d4ea9acSHuber, Joseph           auto Remark = [&](OptimizationRemark OR) {
8174d4ea9acSHuber, Joseph             auto newLoc = &*F.getEntryBlock().getFirstInsertionPt();
8184d4ea9acSHuber, Joseph             return OR << "OpenMP runtime call "
8194d4ea9acSHuber, Joseph                       << ore::NV("OpenMPOptRuntime", RFI.Name) << " moved to "
8204d4ea9acSHuber, Joseph                       << ore::NV("OpenMPRuntimeMoves", newLoc->getDebugLoc());
8214d4ea9acSHuber, Joseph           };
8224d4ea9acSHuber, Joseph           emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeCodeMotion", Remark);
8234d4ea9acSHuber, Joseph 
8249548b74aSJohannes Doerfert           CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
8259548b74aSJohannes Doerfert           ReplVal = CI;
8269548b74aSJohannes Doerfert           break;
8279548b74aSJohannes Doerfert         }
8289548b74aSJohannes Doerfert       if (!ReplVal)
8299548b74aSJohannes Doerfert         return false;
8309548b74aSJohannes Doerfert     }
8319548b74aSJohannes Doerfert 
832396b7253SJohannes Doerfert     // If we use a call as a replacement value we need to make sure the ident is
833396b7253SJohannes Doerfert     // valid at the new location. For now we just pick a global one, either
834396b7253SJohannes Doerfert     // existing and used by one of the calls, or created from scratch.
835396b7253SJohannes Doerfert     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
836396b7253SJohannes Doerfert       if (CI->getNumArgOperands() > 0 &&
8376aab27baSsstefan1           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
838396b7253SJohannes Doerfert         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
839396b7253SJohannes Doerfert                                                       /* GlobalOnly */ true);
840396b7253SJohannes Doerfert         CI->setArgOperand(0, Ident);
841396b7253SJohannes Doerfert       }
842396b7253SJohannes Doerfert     }
843396b7253SJohannes Doerfert 
8449548b74aSJohannes Doerfert     bool Changed = false;
8459548b74aSJohannes Doerfert     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
8469548b74aSJohannes Doerfert       CallInst *CI = getCallIfRegularCall(U, &RFI);
8479548b74aSJohannes Doerfert       if (!CI || CI == ReplVal || &F != &Caller)
8489548b74aSJohannes Doerfert         return false;
8499548b74aSJohannes Doerfert       assert(CI->getCaller() == &F && "Unexpected call!");
8504d4ea9acSHuber, Joseph 
8514d4ea9acSHuber, Joseph       auto Remark = [&](OptimizationRemark OR) {
8524d4ea9acSHuber, Joseph         return OR << "OpenMP runtime call "
8534d4ea9acSHuber, Joseph                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated";
8544d4ea9acSHuber, Joseph       };
8554d4ea9acSHuber, Joseph       emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeDeduplicated", Remark);
8564d4ea9acSHuber, Joseph 
8579548b74aSJohannes Doerfert       CGUpdater.removeCallSite(*CI);
8589548b74aSJohannes Doerfert       CI->replaceAllUsesWith(ReplVal);
8599548b74aSJohannes Doerfert       CI->eraseFromParent();
8609548b74aSJohannes Doerfert       ++NumOpenMPRuntimeCallsDeduplicated;
8619548b74aSJohannes Doerfert       Changed = true;
8629548b74aSJohannes Doerfert       return true;
8639548b74aSJohannes Doerfert     };
864624d34afSJohannes Doerfert     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
8659548b74aSJohannes Doerfert 
8669548b74aSJohannes Doerfert     return Changed;
8679548b74aSJohannes Doerfert   }
8689548b74aSJohannes Doerfert 
8699548b74aSJohannes Doerfert   /// Collect arguments that represent the global thread id in \p GTIdArgs.
8709548b74aSJohannes Doerfert   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
8719548b74aSJohannes Doerfert     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
8729548b74aSJohannes Doerfert     //       initialization. We could define an AbstractAttribute instead and
8739548b74aSJohannes Doerfert     //       run the Attributor here once it can be run as an SCC pass.
8749548b74aSJohannes Doerfert 
8759548b74aSJohannes Doerfert     // Helper to check the argument \p ArgNo at all call sites of \p F for
8769548b74aSJohannes Doerfert     // a GTId.
8779548b74aSJohannes Doerfert     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
8789548b74aSJohannes Doerfert       if (!F.hasLocalLinkage())
8799548b74aSJohannes Doerfert         return false;
8809548b74aSJohannes Doerfert       for (Use &U : F.uses()) {
8819548b74aSJohannes Doerfert         if (CallInst *CI = getCallIfRegularCall(U)) {
8829548b74aSJohannes Doerfert           Value *ArgOp = CI->getArgOperand(ArgNo);
8839548b74aSJohannes Doerfert           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
8847cfd267cSsstefan1               getCallIfRegularCall(
8857cfd267cSsstefan1                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
8869548b74aSJohannes Doerfert             continue;
8879548b74aSJohannes Doerfert         }
8889548b74aSJohannes Doerfert         return false;
8899548b74aSJohannes Doerfert       }
8909548b74aSJohannes Doerfert       return true;
8919548b74aSJohannes Doerfert     };
8929548b74aSJohannes Doerfert 
8939548b74aSJohannes Doerfert     // Helper to identify uses of a GTId as GTId arguments.
8949548b74aSJohannes Doerfert     auto AddUserArgs = [&](Value &GTId) {
8959548b74aSJohannes Doerfert       for (Use &U : GTId.uses())
8969548b74aSJohannes Doerfert         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
8979548b74aSJohannes Doerfert           if (CI->isArgOperand(&U))
8989548b74aSJohannes Doerfert             if (Function *Callee = CI->getCalledFunction())
8999548b74aSJohannes Doerfert               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
9009548b74aSJohannes Doerfert                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
9019548b74aSJohannes Doerfert     };
9029548b74aSJohannes Doerfert 
9039548b74aSJohannes Doerfert     // The argument users of __kmpc_global_thread_num calls are GTIds.
9047cfd267cSsstefan1     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
9057cfd267cSsstefan1         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
9067cfd267cSsstefan1 
907624d34afSJohannes Doerfert     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
9088855fec3SJohannes Doerfert       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
9099548b74aSJohannes Doerfert         AddUserArgs(*CI);
9108855fec3SJohannes Doerfert       return false;
9118855fec3SJohannes Doerfert     });
9129548b74aSJohannes Doerfert 
9139548b74aSJohannes Doerfert     // Transitively search for more arguments by looking at the users of the
9149548b74aSJohannes Doerfert     // ones we know already. During the search the GTIdArgs vector is extended
9159548b74aSJohannes Doerfert     // so we cannot cache the size nor can we use a range based for.
9169548b74aSJohannes Doerfert     for (unsigned u = 0; u < GTIdArgs.size(); ++u)
9179548b74aSJohannes Doerfert       AddUserArgs(*GTIdArgs[u]);
9189548b74aSJohannes Doerfert   }
9199548b74aSJohannes Doerfert 
9205b0581aeSJohannes Doerfert   /// Kernel (=GPU) optimizations and utility functions
9215b0581aeSJohannes Doerfert   ///
9225b0581aeSJohannes Doerfert   ///{{
9235b0581aeSJohannes Doerfert 
9245b0581aeSJohannes Doerfert   /// Check if \p F is a kernel, hence entry point for target offloading.
9255b0581aeSJohannes Doerfert   bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
9265b0581aeSJohannes Doerfert 
9275b0581aeSJohannes Doerfert   /// Cache to remember the unique kernel for a function.
9285b0581aeSJohannes Doerfert   DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
9295b0581aeSJohannes Doerfert 
9305b0581aeSJohannes Doerfert   /// Find the unique kernel that will execute \p F, if any.
9315b0581aeSJohannes Doerfert   Kernel getUniqueKernelFor(Function &F);
9325b0581aeSJohannes Doerfert 
9335b0581aeSJohannes Doerfert   /// Find the unique kernel that will execute \p I, if any.
9345b0581aeSJohannes Doerfert   Kernel getUniqueKernelFor(Instruction &I) {
9355b0581aeSJohannes Doerfert     return getUniqueKernelFor(*I.getFunction());
9365b0581aeSJohannes Doerfert   }
9375b0581aeSJohannes Doerfert 
9385b0581aeSJohannes Doerfert   /// Rewrite the device (=GPU) code state machine create in non-SPMD mode in
9395b0581aeSJohannes Doerfert   /// the cases we can avoid taking the address of a function.
9405b0581aeSJohannes Doerfert   bool rewriteDeviceCodeStateMachine();
9415b0581aeSJohannes Doerfert 
9425b0581aeSJohannes Doerfert   ///
9435b0581aeSJohannes Doerfert   ///}}
9445b0581aeSJohannes Doerfert 
9454d4ea9acSHuber, Joseph   /// Emit a remark generically
9464d4ea9acSHuber, Joseph   ///
9474d4ea9acSHuber, Joseph   /// This template function can be used to generically emit a remark. The
9484d4ea9acSHuber, Joseph   /// RemarkKind should be one of the following:
9494d4ea9acSHuber, Joseph   ///   - OptimizationRemark to indicate a successful optimization attempt
9504d4ea9acSHuber, Joseph   ///   - OptimizationRemarkMissed to report a failed optimization attempt
9514d4ea9acSHuber, Joseph   ///   - OptimizationRemarkAnalysis to provide additional information about an
9524d4ea9acSHuber, Joseph   ///     optimization attempt
9534d4ea9acSHuber, Joseph   ///
9544d4ea9acSHuber, Joseph   /// The remark is built using a callback function provided by the caller that
9554d4ea9acSHuber, Joseph   /// takes a RemarkKind as input and returns a RemarkKind.
9564d4ea9acSHuber, Joseph   template <typename RemarkKind,
9574d4ea9acSHuber, Joseph             typename RemarkCallBack = function_ref<RemarkKind(RemarkKind &&)>>
9584d4ea9acSHuber, Joseph   void emitRemark(Instruction *Inst, StringRef RemarkName,
959e8039ad4SJohannes Doerfert                   RemarkCallBack &&RemarkCB) const {
9604d4ea9acSHuber, Joseph     Function *F = Inst->getParent()->getParent();
9614d4ea9acSHuber, Joseph     auto &ORE = OREGetter(F);
9624d4ea9acSHuber, Joseph 
9637cfd267cSsstefan1     ORE.emit(
9647cfd267cSsstefan1         [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, Inst)); });
9654d4ea9acSHuber, Joseph   }
9664d4ea9acSHuber, Joseph 
9670f426935Ssstefan1   /// Emit a remark on a function. Since only OptimizationRemark is supporting
9680f426935Ssstefan1   /// this, it can't be made generic.
969e8039ad4SJohannes Doerfert   void
970e8039ad4SJohannes Doerfert   emitRemarkOnFunction(Function *F, StringRef RemarkName,
971e8039ad4SJohannes Doerfert                        function_ref<OptimizationRemark(OptimizationRemark &&)>
972e8039ad4SJohannes Doerfert                            &&RemarkCB) const {
9730f426935Ssstefan1     auto &ORE = OREGetter(F);
9740f426935Ssstefan1 
9750f426935Ssstefan1     ORE.emit([&]() {
9760f426935Ssstefan1       return RemarkCB(OptimizationRemark(DEBUG_TYPE, RemarkName, F));
9770f426935Ssstefan1     });
9780f426935Ssstefan1   }
9790f426935Ssstefan1 
980b726c557SJohannes Doerfert   /// The underlying module.
9819548b74aSJohannes Doerfert   Module &M;
9829548b74aSJohannes Doerfert 
9839548b74aSJohannes Doerfert   /// The SCC we are operating on.
984ee17263aSJohannes Doerfert   SmallVectorImpl<Function *> &SCC;
9859548b74aSJohannes Doerfert 
9869548b74aSJohannes Doerfert   /// Callback to update the call graph, the first argument is a removed call,
9879548b74aSJohannes Doerfert   /// the second an optional replacement call.
9889548b74aSJohannes Doerfert   CallGraphUpdater &CGUpdater;
9899548b74aSJohannes Doerfert 
9904d4ea9acSHuber, Joseph   /// Callback to get an OptimizationRemarkEmitter from a Function *
9914d4ea9acSHuber, Joseph   OptimizationRemarkGetter OREGetter;
9924d4ea9acSHuber, Joseph 
9937cfd267cSsstefan1   /// OpenMP-specific information cache. Also Used for Attributor runs.
9947cfd267cSsstefan1   OMPInformationCache &OMPInfoCache;
995b8235d2bSsstefan1 
996b8235d2bSsstefan1   /// Attributor instance.
997b8235d2bSsstefan1   Attributor &A;
998b8235d2bSsstefan1 
999b8235d2bSsstefan1   /// Helper function to run Attributor on SCC.
1000b8235d2bSsstefan1   bool runAttributor() {
1001b8235d2bSsstefan1     if (SCC.empty())
1002b8235d2bSsstefan1       return false;
1003b8235d2bSsstefan1 
1004b8235d2bSsstefan1     registerAAs();
1005b8235d2bSsstefan1 
1006b8235d2bSsstefan1     ChangeStatus Changed = A.run();
1007b8235d2bSsstefan1 
1008b8235d2bSsstefan1     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
1009b8235d2bSsstefan1                       << " functions, result: " << Changed << ".\n");
1010b8235d2bSsstefan1 
1011b8235d2bSsstefan1     return Changed == ChangeStatus::CHANGED;
1012b8235d2bSsstefan1   }
1013b8235d2bSsstefan1 
1014b8235d2bSsstefan1   /// Populate the Attributor with abstract attribute opportunities in the
1015b8235d2bSsstefan1   /// function.
1016b8235d2bSsstefan1   void registerAAs() {
1017b0b32e64Ssstefan1     if (SCC.empty())
1018b0b32e64Ssstefan1       return;
1019b8235d2bSsstefan1 
1020b0b32e64Ssstefan1     // Create CallSite AA for all Getters.
1021b0b32e64Ssstefan1     for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
1022b0b32e64Ssstefan1       auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
1023b0b32e64Ssstefan1 
1024b0b32e64Ssstefan1       auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
1025b0b32e64Ssstefan1 
1026b0b32e64Ssstefan1       auto CreateAA = [&](Use &U, Function &Caller) {
1027b0b32e64Ssstefan1         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
1028b0b32e64Ssstefan1         if (!CI)
1029b0b32e64Ssstefan1           return false;
1030b0b32e64Ssstefan1 
1031b0b32e64Ssstefan1         auto &CB = cast<CallBase>(*CI);
1032b0b32e64Ssstefan1 
1033b0b32e64Ssstefan1         IRPosition CBPos = IRPosition::callsite_function(CB);
1034b0b32e64Ssstefan1         A.getOrCreateAAFor<AAICVTracker>(CBPos);
1035b0b32e64Ssstefan1         return false;
1036b0b32e64Ssstefan1       };
1037b0b32e64Ssstefan1 
1038b0b32e64Ssstefan1       GetterRFI.foreachUse(SCC, CreateAA);
1039b8235d2bSsstefan1     }
1040b8235d2bSsstefan1   }
1041b8235d2bSsstefan1 };
1042b8235d2bSsstefan1 
10435b0581aeSJohannes Doerfert Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
10445b0581aeSJohannes Doerfert   if (!OMPInfoCache.ModuleSlice.count(&F))
10455b0581aeSJohannes Doerfert     return nullptr;
10465b0581aeSJohannes Doerfert 
10475b0581aeSJohannes Doerfert   // Use a scope to keep the lifetime of the CachedKernel short.
10485b0581aeSJohannes Doerfert   {
10495b0581aeSJohannes Doerfert     Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
10505b0581aeSJohannes Doerfert     if (CachedKernel)
10515b0581aeSJohannes Doerfert       return *CachedKernel;
10525b0581aeSJohannes Doerfert 
10535b0581aeSJohannes Doerfert     // TODO: We should use an AA to create an (optimistic and callback
10545b0581aeSJohannes Doerfert     //       call-aware) call graph. For now we stick to simple patterns that
10555b0581aeSJohannes Doerfert     //       are less powerful, basically the worst fixpoint.
10565b0581aeSJohannes Doerfert     if (isKernel(F)) {
10575b0581aeSJohannes Doerfert       CachedKernel = Kernel(&F);
10585b0581aeSJohannes Doerfert       return *CachedKernel;
10595b0581aeSJohannes Doerfert     }
10605b0581aeSJohannes Doerfert 
10615b0581aeSJohannes Doerfert     CachedKernel = nullptr;
10625b0581aeSJohannes Doerfert     if (!F.hasLocalLinkage())
10635b0581aeSJohannes Doerfert       return nullptr;
10645b0581aeSJohannes Doerfert   }
10655b0581aeSJohannes Doerfert 
10665b0581aeSJohannes Doerfert   auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
10675b0581aeSJohannes Doerfert     if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
10685b0581aeSJohannes Doerfert       // Allow use in equality comparisons.
10695b0581aeSJohannes Doerfert       if (Cmp->isEquality())
10705b0581aeSJohannes Doerfert         return getUniqueKernelFor(*Cmp);
10715b0581aeSJohannes Doerfert       return nullptr;
10725b0581aeSJohannes Doerfert     }
10735b0581aeSJohannes Doerfert     if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
10745b0581aeSJohannes Doerfert       // Allow direct calls.
10755b0581aeSJohannes Doerfert       if (CB->isCallee(&U))
10765b0581aeSJohannes Doerfert         return getUniqueKernelFor(*CB);
10775b0581aeSJohannes Doerfert       // Allow the use in __kmpc_kernel_prepare_parallel calls.
10785b0581aeSJohannes Doerfert       if (Function *Callee = CB->getCalledFunction())
10795b0581aeSJohannes Doerfert         if (Callee->getName() == "__kmpc_kernel_prepare_parallel")
10805b0581aeSJohannes Doerfert           return getUniqueKernelFor(*CB);
10815b0581aeSJohannes Doerfert       return nullptr;
10825b0581aeSJohannes Doerfert     }
10835b0581aeSJohannes Doerfert     // Disallow every other use.
10845b0581aeSJohannes Doerfert     return nullptr;
10855b0581aeSJohannes Doerfert   };
10865b0581aeSJohannes Doerfert 
10875b0581aeSJohannes Doerfert   // TODO: In the future we want to track more than just a unique kernel.
10885b0581aeSJohannes Doerfert   SmallPtrSet<Kernel, 2> PotentialKernels;
10895b0581aeSJohannes Doerfert   foreachUse(F, [&](const Use &U) {
10905b0581aeSJohannes Doerfert     PotentialKernels.insert(GetUniqueKernelForUse(U));
10915b0581aeSJohannes Doerfert   });
10925b0581aeSJohannes Doerfert 
10935b0581aeSJohannes Doerfert   Kernel K = nullptr;
10945b0581aeSJohannes Doerfert   if (PotentialKernels.size() == 1)
10955b0581aeSJohannes Doerfert     K = *PotentialKernels.begin();
10965b0581aeSJohannes Doerfert 
10975b0581aeSJohannes Doerfert   // Cache the result.
10985b0581aeSJohannes Doerfert   UniqueKernelMap[&F] = K;
10995b0581aeSJohannes Doerfert 
11005b0581aeSJohannes Doerfert   return K;
11015b0581aeSJohannes Doerfert }
11025b0581aeSJohannes Doerfert 
11035b0581aeSJohannes Doerfert bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
11045b0581aeSJohannes Doerfert   OMPInformationCache::RuntimeFunctionInfo &KernelPrepareParallelRFI =
11055b0581aeSJohannes Doerfert       OMPInfoCache.RFIs[OMPRTL___kmpc_kernel_prepare_parallel];
11065b0581aeSJohannes Doerfert 
11075b0581aeSJohannes Doerfert   bool Changed = false;
11085b0581aeSJohannes Doerfert   if (!KernelPrepareParallelRFI)
11095b0581aeSJohannes Doerfert     return Changed;
11105b0581aeSJohannes Doerfert 
11115b0581aeSJohannes Doerfert   for (Function *F : SCC) {
11125b0581aeSJohannes Doerfert 
11135b0581aeSJohannes Doerfert     // Check if the function is uses in a __kmpc_kernel_prepare_parallel call at
11145b0581aeSJohannes Doerfert     // all.
11155b0581aeSJohannes Doerfert     bool UnknownUse = false;
1116fec1f210SJohannes Doerfert     bool KernelPrepareUse = false;
11175b0581aeSJohannes Doerfert     unsigned NumDirectCalls = 0;
11185b0581aeSJohannes Doerfert 
11195b0581aeSJohannes Doerfert     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
11205b0581aeSJohannes Doerfert     foreachUse(*F, [&](Use &U) {
11215b0581aeSJohannes Doerfert       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
11225b0581aeSJohannes Doerfert         if (CB->isCallee(&U)) {
11235b0581aeSJohannes Doerfert           ++NumDirectCalls;
11245b0581aeSJohannes Doerfert           return;
11255b0581aeSJohannes Doerfert         }
11265b0581aeSJohannes Doerfert 
112781db6144SMichael Liao       if (isa<ICmpInst>(U.getUser())) {
11285b0581aeSJohannes Doerfert         ToBeReplacedStateMachineUses.push_back(&U);
11295b0581aeSJohannes Doerfert         return;
11305b0581aeSJohannes Doerfert       }
1131fec1f210SJohannes Doerfert       if (!KernelPrepareUse && OpenMPOpt::getCallIfRegularCall(
1132fec1f210SJohannes Doerfert                                    *U.getUser(), &KernelPrepareParallelRFI)) {
1133fec1f210SJohannes Doerfert         KernelPrepareUse = true;
11345b0581aeSJohannes Doerfert         ToBeReplacedStateMachineUses.push_back(&U);
11355b0581aeSJohannes Doerfert         return;
11365b0581aeSJohannes Doerfert       }
11375b0581aeSJohannes Doerfert       UnknownUse = true;
11385b0581aeSJohannes Doerfert     });
11395b0581aeSJohannes Doerfert 
1140fec1f210SJohannes Doerfert     // Do not emit a remark if we haven't seen a __kmpc_kernel_prepare_parallel
1141fec1f210SJohannes Doerfert     // use.
1142fec1f210SJohannes Doerfert     if (!KernelPrepareUse)
11435b0581aeSJohannes Doerfert       continue;
11445b0581aeSJohannes Doerfert 
1145fec1f210SJohannes Doerfert     {
1146fec1f210SJohannes Doerfert       auto Remark = [&](OptimizationRemark OR) {
1147fec1f210SJohannes Doerfert         return OR << "Found a parallel region that is called in a target "
1148fec1f210SJohannes Doerfert                      "region but not part of a combined target construct nor "
1149fec1f210SJohannes Doerfert                      "nesed inside a target construct without intermediate "
1150fec1f210SJohannes Doerfert                      "code. This can lead to excessive register usage for "
1151fec1f210SJohannes Doerfert                      "unrelated target regions in the same translation unit "
1152fec1f210SJohannes Doerfert                      "due to spurious call edges assumed by ptxas.";
1153fec1f210SJohannes Doerfert       };
1154fec1f210SJohannes Doerfert       emitRemarkOnFunction(F, "OpenMPParallelRegionInNonSPMD", Remark);
1155fec1f210SJohannes Doerfert     }
1156fec1f210SJohannes Doerfert 
1157fec1f210SJohannes Doerfert     // If this ever hits, we should investigate.
1158fec1f210SJohannes Doerfert     // TODO: Checking the number of uses is not a necessary restriction and
1159fec1f210SJohannes Doerfert     // should be lifted.
1160fec1f210SJohannes Doerfert     if (UnknownUse || NumDirectCalls != 1 ||
1161fec1f210SJohannes Doerfert         ToBeReplacedStateMachineUses.size() != 2) {
1162fec1f210SJohannes Doerfert       {
1163fec1f210SJohannes Doerfert         auto Remark = [&](OptimizationRemark OR) {
1164fec1f210SJohannes Doerfert           return OR << "Parallel region is used in "
1165fec1f210SJohannes Doerfert                     << (UnknownUse ? "unknown" : "unexpected")
1166fec1f210SJohannes Doerfert                     << " ways; will not attempt to rewrite the state machine.";
1167fec1f210SJohannes Doerfert         };
1168fec1f210SJohannes Doerfert         emitRemarkOnFunction(F, "OpenMPParallelRegionInNonSPMD", Remark);
1169fec1f210SJohannes Doerfert       }
11705b0581aeSJohannes Doerfert       continue;
1171fec1f210SJohannes Doerfert     }
11725b0581aeSJohannes Doerfert 
11735b0581aeSJohannes Doerfert     // Even if we have __kmpc_kernel_prepare_parallel calls, we (for now) give
11745b0581aeSJohannes Doerfert     // up if the function is not called from a unique kernel.
11755b0581aeSJohannes Doerfert     Kernel K = getUniqueKernelFor(*F);
1176fec1f210SJohannes Doerfert     if (!K) {
1177fec1f210SJohannes Doerfert       {
1178fec1f210SJohannes Doerfert         auto Remark = [&](OptimizationRemark OR) {
1179fec1f210SJohannes Doerfert           return OR << "Parallel region is not known to be called from a "
1180fec1f210SJohannes Doerfert                        "unique single target region, maybe the surrounding "
1181fec1f210SJohannes Doerfert                        "function has external linkage?; will not attempt to "
1182fec1f210SJohannes Doerfert                        "rewrite the state machine use.";
1183fec1f210SJohannes Doerfert         };
1184fec1f210SJohannes Doerfert         emitRemarkOnFunction(F, "OpenMPParallelRegionInMultipleKernesl",
1185fec1f210SJohannes Doerfert                              Remark);
1186fec1f210SJohannes Doerfert       }
11875b0581aeSJohannes Doerfert       continue;
1188fec1f210SJohannes Doerfert     }
11895b0581aeSJohannes Doerfert 
11905b0581aeSJohannes Doerfert     // We now know F is a parallel body function called only from the kernel K.
11915b0581aeSJohannes Doerfert     // We also identified the state machine uses in which we replace the
11925b0581aeSJohannes Doerfert     // function pointer by a new global symbol for identification purposes. This
11935b0581aeSJohannes Doerfert     // ensures only direct calls to the function are left.
11945b0581aeSJohannes Doerfert 
1195fec1f210SJohannes Doerfert     {
1196fec1f210SJohannes Doerfert       auto RemarkParalleRegion = [&](OptimizationRemark OR) {
1197fec1f210SJohannes Doerfert         return OR << "Specialize parallel region that is only reached from a "
1198fec1f210SJohannes Doerfert                      "single target region to avoid spurious call edges and "
1199fec1f210SJohannes Doerfert                      "excessive register usage in other target regions. "
1200fec1f210SJohannes Doerfert                      "(parallel region ID: "
1201fec1f210SJohannes Doerfert                   << ore::NV("OpenMPParallelRegion", F->getName())
1202fec1f210SJohannes Doerfert                   << ", kernel ID: "
1203fec1f210SJohannes Doerfert                   << ore::NV("OpenMPTargetRegion", K->getName()) << ")";
1204fec1f210SJohannes Doerfert       };
1205fec1f210SJohannes Doerfert       emitRemarkOnFunction(F, "OpenMPParallelRegionInNonSPMD",
1206fec1f210SJohannes Doerfert                            RemarkParalleRegion);
1207fec1f210SJohannes Doerfert       auto RemarkKernel = [&](OptimizationRemark OR) {
1208fec1f210SJohannes Doerfert         return OR << "Target region containing the parallel region that is "
1209fec1f210SJohannes Doerfert                      "specialized. (parallel region ID: "
1210fec1f210SJohannes Doerfert                   << ore::NV("OpenMPParallelRegion", F->getName())
1211fec1f210SJohannes Doerfert                   << ", kernel ID: "
1212fec1f210SJohannes Doerfert                   << ore::NV("OpenMPTargetRegion", K->getName()) << ")";
1213fec1f210SJohannes Doerfert       };
1214fec1f210SJohannes Doerfert       emitRemarkOnFunction(K, "OpenMPParallelRegionInNonSPMD", RemarkKernel);
1215fec1f210SJohannes Doerfert     }
1216fec1f210SJohannes Doerfert 
12175b0581aeSJohannes Doerfert     Module &M = *F->getParent();
12185b0581aeSJohannes Doerfert     Type *Int8Ty = Type::getInt8Ty(M.getContext());
12195b0581aeSJohannes Doerfert 
12205b0581aeSJohannes Doerfert     auto *ID = new GlobalVariable(
12215b0581aeSJohannes Doerfert         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
12225b0581aeSJohannes Doerfert         UndefValue::get(Int8Ty), F->getName() + ".ID");
12235b0581aeSJohannes Doerfert 
12245b0581aeSJohannes Doerfert     for (Use *U : ToBeReplacedStateMachineUses)
12255b0581aeSJohannes Doerfert       U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));
12265b0581aeSJohannes Doerfert 
12275b0581aeSJohannes Doerfert     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
12285b0581aeSJohannes Doerfert 
12295b0581aeSJohannes Doerfert     Changed = true;
12305b0581aeSJohannes Doerfert   }
12315b0581aeSJohannes Doerfert 
12325b0581aeSJohannes Doerfert   return Changed;
12335b0581aeSJohannes Doerfert }
12345b0581aeSJohannes Doerfert 
1235b8235d2bSsstefan1 /// Abstract Attribute for tracking ICV values.
1236b8235d2bSsstefan1 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
1237b8235d2bSsstefan1   using Base = StateWrapper<BooleanState, AbstractAttribute>;
1238b8235d2bSsstefan1   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
1239b8235d2bSsstefan1 
1240b0b32e64Ssstefan1   void initialize(Attributor &A) override {
1241b0b32e64Ssstefan1     Function *F = getAnchorScope();
1242b0b32e64Ssstefan1     if (!F || !A.isFunctionIPOAmendable(*F))
1243b0b32e64Ssstefan1       indicatePessimisticFixpoint();
1244b0b32e64Ssstefan1   }
1245b0b32e64Ssstefan1 
1246b8235d2bSsstefan1   /// Returns true if value is assumed to be tracked.
1247b8235d2bSsstefan1   bool isAssumedTracked() const { return getAssumed(); }
1248b8235d2bSsstefan1 
1249b8235d2bSsstefan1   /// Returns true if value is known to be tracked.
1250b8235d2bSsstefan1   bool isKnownTracked() const { return getAssumed(); }
1251b8235d2bSsstefan1 
1252b8235d2bSsstefan1   /// Create an abstract attribute biew for the position \p IRP.
1253b8235d2bSsstefan1   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
1254b8235d2bSsstefan1 
1255b8235d2bSsstefan1   /// Return the value with which \p I can be replaced for specific \p ICV.
1256b0b32e64Ssstefan1   virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
1257b0b32e64Ssstefan1                                                 const Instruction *I,
1258b0b32e64Ssstefan1                                                 Attributor &A) const {
1259b0b32e64Ssstefan1     return None;
1260b0b32e64Ssstefan1   }
1261b0b32e64Ssstefan1 
1262b0b32e64Ssstefan1   /// Return an assumed unique ICV value if a single candidate is found. If
1263b0b32e64Ssstefan1   /// there cannot be one, return a nullptr. If it is not clear yet, return the
1264b0b32e64Ssstefan1   /// Optional::NoneType.
1265b0b32e64Ssstefan1   virtual Optional<Value *>
1266b0b32e64Ssstefan1   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
1267b0b32e64Ssstefan1 
1268b0b32e64Ssstefan1   // Currently only nthreads is being tracked.
1269b0b32e64Ssstefan1   // this array will only grow with time.
1270b0b32e64Ssstefan1   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
1271b8235d2bSsstefan1 
1272b8235d2bSsstefan1   /// See AbstractAttribute::getName()
1273b8235d2bSsstefan1   const std::string getName() const override { return "AAICVTracker"; }
1274b8235d2bSsstefan1 
1275233af895SLuofan Chen   /// See AbstractAttribute::getIdAddr()
1276233af895SLuofan Chen   const char *getIdAddr() const override { return &ID; }
1277233af895SLuofan Chen 
1278233af895SLuofan Chen   /// This function should return true if the type of the \p AA is AAICVTracker
1279233af895SLuofan Chen   static bool classof(const AbstractAttribute *AA) {
1280233af895SLuofan Chen     return (AA->getIdAddr() == &ID);
1281233af895SLuofan Chen   }
1282233af895SLuofan Chen 
1283b8235d2bSsstefan1   static const char ID;
1284b8235d2bSsstefan1 };
1285b8235d2bSsstefan1 
1286b8235d2bSsstefan1 struct AAICVTrackerFunction : public AAICVTracker {
1287b8235d2bSsstefan1   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
1288b8235d2bSsstefan1       : AAICVTracker(IRP, A) {}
1289b8235d2bSsstefan1 
1290b8235d2bSsstefan1   // FIXME: come up with better string.
1291b0b32e64Ssstefan1   const std::string getAsStr() const override { return "ICVTrackerFunction"; }
1292b8235d2bSsstefan1 
1293b8235d2bSsstefan1   // FIXME: come up with some stats.
1294b8235d2bSsstefan1   void trackStatistics() const override {}
1295b8235d2bSsstefan1 
1296b0b32e64Ssstefan1   /// We don't manifest anything for this AA.
1297b8235d2bSsstefan1   ChangeStatus manifest(Attributor &A) override {
1298b0b32e64Ssstefan1     return ChangeStatus::UNCHANGED;
1299b8235d2bSsstefan1   }
1300b8235d2bSsstefan1 
1301b8235d2bSsstefan1   // Map of ICV to their values at specific program point.
1302b0b32e64Ssstefan1   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
1303b8235d2bSsstefan1                   InternalControlVar::ICV___last>
1304b0b32e64Ssstefan1       ICVReplacementValuesMap;
1305b8235d2bSsstefan1 
1306b8235d2bSsstefan1   ChangeStatus updateImpl(Attributor &A) override {
1307b8235d2bSsstefan1     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
1308b8235d2bSsstefan1 
1309b8235d2bSsstefan1     Function *F = getAnchorScope();
1310b8235d2bSsstefan1 
1311b8235d2bSsstefan1     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
1312b8235d2bSsstefan1 
1313b8235d2bSsstefan1     for (InternalControlVar ICV : TrackableICVs) {
1314b8235d2bSsstefan1       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
1315b8235d2bSsstefan1 
1316b0b32e64Ssstefan1       auto &ValuesMap = ICVReplacementValuesMap[ICV];
1317b8235d2bSsstefan1       auto TrackValues = [&](Use &U, Function &) {
1318b8235d2bSsstefan1         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
1319b8235d2bSsstefan1         if (!CI)
1320b8235d2bSsstefan1           return false;
1321b8235d2bSsstefan1 
1322b8235d2bSsstefan1         // FIXME: handle setters with more that 1 arguments.
1323b8235d2bSsstefan1         /// Track new value.
1324b0b32e64Ssstefan1         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
1325b8235d2bSsstefan1           HasChanged = ChangeStatus::CHANGED;
1326b8235d2bSsstefan1 
1327b8235d2bSsstefan1         return false;
1328b8235d2bSsstefan1       };
1329b8235d2bSsstefan1 
1330b0b32e64Ssstefan1       auto CallCheck = [&](Instruction &I) {
1331b0b32e64Ssstefan1         Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
1332b0b32e64Ssstefan1         if (ReplVal.hasValue() &&
1333b0b32e64Ssstefan1             ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
1334b0b32e64Ssstefan1           HasChanged = ChangeStatus::CHANGED;
1335b0b32e64Ssstefan1 
1336b0b32e64Ssstefan1         return true;
1337b0b32e64Ssstefan1       };
1338b0b32e64Ssstefan1 
1339b0b32e64Ssstefan1       // Track all changes of an ICV.
1340b8235d2bSsstefan1       SetterRFI.foreachUse(TrackValues, F);
1341b0b32e64Ssstefan1 
1342b0b32e64Ssstefan1       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
1343b0b32e64Ssstefan1                                 /* CheckBBLivenessOnly */ true);
1344b0b32e64Ssstefan1 
1345b0b32e64Ssstefan1       /// TODO: Figure out a way to avoid adding entry in
1346b0b32e64Ssstefan1       /// ICVReplacementValuesMap
1347b0b32e64Ssstefan1       Instruction *Entry = &F->getEntryBlock().front();
1348b0b32e64Ssstefan1       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
1349b0b32e64Ssstefan1         ValuesMap.insert(std::make_pair(Entry, nullptr));
1350b8235d2bSsstefan1     }
1351b8235d2bSsstefan1 
1352b8235d2bSsstefan1     return HasChanged;
1353b8235d2bSsstefan1   }
1354b8235d2bSsstefan1 
1355b0b32e64Ssstefan1   /// Hepler to check if \p I is a call and get the value for it if it is
1356b0b32e64Ssstefan1   /// unique.
1357b0b32e64Ssstefan1   Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
1358b0b32e64Ssstefan1                                     InternalControlVar &ICV) const {
1359b8235d2bSsstefan1 
1360b0b32e64Ssstefan1     const auto *CB = dyn_cast<CallBase>(I);
1361b0b32e64Ssstefan1     if (!CB)
1362b0b32e64Ssstefan1       return None;
1363b0b32e64Ssstefan1 
1364b8235d2bSsstefan1     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
1365b8235d2bSsstefan1     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
1366b0b32e64Ssstefan1     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
1367b0b32e64Ssstefan1     Function *CalledFunction = CB->getCalledFunction();
1368b8235d2bSsstefan1 
1369b0b32e64Ssstefan1     if (CalledFunction == GetterRFI.Declaration)
1370b0b32e64Ssstefan1       return None;
1371b0b32e64Ssstefan1     if (CalledFunction == SetterRFI.Declaration) {
1372b0b32e64Ssstefan1       if (ICVReplacementValuesMap[ICV].count(I))
1373b0b32e64Ssstefan1         return ICVReplacementValuesMap[ICV].lookup(I);
1374b0b32e64Ssstefan1 
1375b0b32e64Ssstefan1       return nullptr;
1376b0b32e64Ssstefan1     }
1377b0b32e64Ssstefan1 
1378b0b32e64Ssstefan1     // Since we don't know, assume it changes the ICV.
1379b0b32e64Ssstefan1     if (CalledFunction->isDeclaration())
1380b0b32e64Ssstefan1       return nullptr;
1381b0b32e64Ssstefan1 
1382b0b32e64Ssstefan1     const auto &ICVTrackingAA =
1383b0b32e64Ssstefan1         A.getAAFor<AAICVTracker>(*this, IRPosition::callsite_returned(*CB));
1384b0b32e64Ssstefan1 
1385b0b32e64Ssstefan1     if (ICVTrackingAA.isAssumedTracked())
1386b0b32e64Ssstefan1       return ICVTrackingAA.getUniqueReplacementValue(ICV);
1387b0b32e64Ssstefan1 
1388b0b32e64Ssstefan1     // If we don't know, assume it changes.
1389b0b32e64Ssstefan1     return nullptr;
1390b0b32e64Ssstefan1   }
1391b0b32e64Ssstefan1 
1392b0b32e64Ssstefan1   // We don't check unique value for a function, so return None.
1393b0b32e64Ssstefan1   Optional<Value *>
1394b0b32e64Ssstefan1   getUniqueReplacementValue(InternalControlVar ICV) const override {
1395b0b32e64Ssstefan1     return None;
1396b0b32e64Ssstefan1   }
1397b0b32e64Ssstefan1 
1398b0b32e64Ssstefan1   /// Return the value with which \p I can be replaced for specific \p ICV.
1399b0b32e64Ssstefan1   Optional<Value *> getReplacementValue(InternalControlVar ICV,
1400b0b32e64Ssstefan1                                         const Instruction *I,
1401b0b32e64Ssstefan1                                         Attributor &A) const override {
1402b0b32e64Ssstefan1     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
1403b0b32e64Ssstefan1     if (ValuesMap.count(I))
1404b0b32e64Ssstefan1       return ValuesMap.lookup(I);
1405b0b32e64Ssstefan1 
1406b0b32e64Ssstefan1     SmallVector<const Instruction *, 16> Worklist;
1407b0b32e64Ssstefan1     SmallPtrSet<const Instruction *, 16> Visited;
1408b0b32e64Ssstefan1     Worklist.push_back(I);
1409b0b32e64Ssstefan1 
1410b0b32e64Ssstefan1     Optional<Value *> ReplVal;
1411b0b32e64Ssstefan1 
1412b0b32e64Ssstefan1     while (!Worklist.empty()) {
1413b0b32e64Ssstefan1       const Instruction *CurrInst = Worklist.pop_back_val();
1414b0b32e64Ssstefan1       if (!Visited.insert(CurrInst).second)
1415b8235d2bSsstefan1         continue;
1416b8235d2bSsstefan1 
1417b0b32e64Ssstefan1       const BasicBlock *CurrBB = CurrInst->getParent();
1418b0b32e64Ssstefan1 
1419b0b32e64Ssstefan1       // Go up and look for all potential setters/calls that might change the
1420b0b32e64Ssstefan1       // ICV.
1421b0b32e64Ssstefan1       while ((CurrInst = CurrInst->getPrevNode())) {
1422b0b32e64Ssstefan1         if (ValuesMap.count(CurrInst)) {
1423b0b32e64Ssstefan1           Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
1424b0b32e64Ssstefan1           // Unknown value, track new.
1425b0b32e64Ssstefan1           if (!ReplVal.hasValue()) {
1426b0b32e64Ssstefan1             ReplVal = NewReplVal;
1427b0b32e64Ssstefan1             break;
1428b0b32e64Ssstefan1           }
1429b0b32e64Ssstefan1 
1430b0b32e64Ssstefan1           // If we found a new value, we can't know the icv value anymore.
1431b0b32e64Ssstefan1           if (NewReplVal.hasValue())
1432b0b32e64Ssstefan1             if (ReplVal != NewReplVal)
1433b8235d2bSsstefan1               return nullptr;
1434b8235d2bSsstefan1 
1435b0b32e64Ssstefan1           break;
1436b8235d2bSsstefan1         }
1437b8235d2bSsstefan1 
1438b0b32e64Ssstefan1         Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
1439b0b32e64Ssstefan1         if (!NewReplVal.hasValue())
1440b0b32e64Ssstefan1           continue;
1441b0b32e64Ssstefan1 
1442b0b32e64Ssstefan1         // Unknown value, track new.
1443b0b32e64Ssstefan1         if (!ReplVal.hasValue()) {
1444b0b32e64Ssstefan1           ReplVal = NewReplVal;
1445b0b32e64Ssstefan1           break;
1446b8235d2bSsstefan1         }
1447b8235d2bSsstefan1 
1448b0b32e64Ssstefan1         // if (NewReplVal.hasValue())
1449b0b32e64Ssstefan1         // We found a new value, we can't know the icv value anymore.
1450b0b32e64Ssstefan1         if (ReplVal != NewReplVal)
1451b8235d2bSsstefan1           return nullptr;
1452b8235d2bSsstefan1       }
1453b0b32e64Ssstefan1 
1454b0b32e64Ssstefan1       // If we are in the same BB and we have a value, we are done.
1455b0b32e64Ssstefan1       if (CurrBB == I->getParent() && ReplVal.hasValue())
1456b0b32e64Ssstefan1         return ReplVal;
1457b0b32e64Ssstefan1 
1458b0b32e64Ssstefan1       // Go through all predecessors and add terminators for analysis.
1459b0b32e64Ssstefan1       for (const BasicBlock *Pred : predecessors(CurrBB))
1460b0b32e64Ssstefan1         if (const Instruction *Terminator = Pred->getTerminator())
1461b0b32e64Ssstefan1           Worklist.push_back(Terminator);
1462b0b32e64Ssstefan1     }
1463b0b32e64Ssstefan1 
1464b0b32e64Ssstefan1     return ReplVal;
1465b0b32e64Ssstefan1   }
1466b0b32e64Ssstefan1 };
1467b0b32e64Ssstefan1 
1468b0b32e64Ssstefan1 struct AAICVTrackerFunctionReturned : AAICVTracker {
1469b0b32e64Ssstefan1   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
1470b0b32e64Ssstefan1       : AAICVTracker(IRP, A) {}
1471b0b32e64Ssstefan1 
1472b0b32e64Ssstefan1   // FIXME: come up with better string.
1473b0b32e64Ssstefan1   const std::string getAsStr() const override {
1474b0b32e64Ssstefan1     return "ICVTrackerFunctionReturned";
1475b0b32e64Ssstefan1   }
1476b0b32e64Ssstefan1 
1477b0b32e64Ssstefan1   // FIXME: come up with some stats.
1478b0b32e64Ssstefan1   void trackStatistics() const override {}
1479b0b32e64Ssstefan1 
1480b0b32e64Ssstefan1   /// We don't manifest anything for this AA.
1481b0b32e64Ssstefan1   ChangeStatus manifest(Attributor &A) override {
1482b0b32e64Ssstefan1     return ChangeStatus::UNCHANGED;
1483b0b32e64Ssstefan1   }
1484b0b32e64Ssstefan1 
1485b0b32e64Ssstefan1   // Map of ICV to their values at specific program point.
1486b0b32e64Ssstefan1   EnumeratedArray<Optional<Value *>, InternalControlVar,
1487b0b32e64Ssstefan1                   InternalControlVar::ICV___last>
1488b0b32e64Ssstefan1       ICVReplacementValuesMap;
1489b0b32e64Ssstefan1 
1490b0b32e64Ssstefan1   /// Return the value with which \p I can be replaced for specific \p ICV.
1491b0b32e64Ssstefan1   Optional<Value *>
1492b0b32e64Ssstefan1   getUniqueReplacementValue(InternalControlVar ICV) const override {
1493b0b32e64Ssstefan1     return ICVReplacementValuesMap[ICV];
1494b0b32e64Ssstefan1   }
1495b0b32e64Ssstefan1 
1496b0b32e64Ssstefan1   ChangeStatus updateImpl(Attributor &A) override {
1497b0b32e64Ssstefan1     ChangeStatus Changed = ChangeStatus::UNCHANGED;
1498b0b32e64Ssstefan1     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
1499b0b32e64Ssstefan1         *this, IRPosition::function(*getAnchorScope()));
1500b0b32e64Ssstefan1 
1501b0b32e64Ssstefan1     if (!ICVTrackingAA.isAssumedTracked())
1502b0b32e64Ssstefan1       return indicatePessimisticFixpoint();
1503b0b32e64Ssstefan1 
1504b0b32e64Ssstefan1     for (InternalControlVar ICV : TrackableICVs) {
1505b0b32e64Ssstefan1       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
1506b0b32e64Ssstefan1       Optional<Value *> UniqueICVValue;
1507b0b32e64Ssstefan1 
1508b0b32e64Ssstefan1       auto CheckReturnInst = [&](Instruction &I) {
1509b0b32e64Ssstefan1         Optional<Value *> NewReplVal =
1510b0b32e64Ssstefan1             ICVTrackingAA.getReplacementValue(ICV, &I, A);
1511b0b32e64Ssstefan1 
1512b0b32e64Ssstefan1         // If we found a second ICV value there is no unique returned value.
1513b0b32e64Ssstefan1         if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
1514b0b32e64Ssstefan1           return false;
1515b0b32e64Ssstefan1 
1516b0b32e64Ssstefan1         UniqueICVValue = NewReplVal;
1517b0b32e64Ssstefan1 
1518b0b32e64Ssstefan1         return true;
1519b0b32e64Ssstefan1       };
1520b0b32e64Ssstefan1 
1521b0b32e64Ssstefan1       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
1522b0b32e64Ssstefan1                                      /* CheckBBLivenessOnly */ true))
1523b0b32e64Ssstefan1         UniqueICVValue = nullptr;
1524b0b32e64Ssstefan1 
1525b0b32e64Ssstefan1       if (UniqueICVValue == ReplVal)
1526b0b32e64Ssstefan1         continue;
1527b0b32e64Ssstefan1 
1528b0b32e64Ssstefan1       ReplVal = UniqueICVValue;
1529b0b32e64Ssstefan1       Changed = ChangeStatus::CHANGED;
1530b0b32e64Ssstefan1     }
1531b0b32e64Ssstefan1 
1532b0b32e64Ssstefan1     return Changed;
1533b0b32e64Ssstefan1   }
1534b0b32e64Ssstefan1 };
1535b0b32e64Ssstefan1 
1536b0b32e64Ssstefan1 struct AAICVTrackerCallSite : AAICVTracker {
1537b0b32e64Ssstefan1   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
1538b0b32e64Ssstefan1       : AAICVTracker(IRP, A) {}
1539b0b32e64Ssstefan1 
1540b0b32e64Ssstefan1   void initialize(Attributor &A) override {
1541b0b32e64Ssstefan1     Function *F = getAnchorScope();
1542b0b32e64Ssstefan1     if (!F || !A.isFunctionIPOAmendable(*F))
1543b0b32e64Ssstefan1       indicatePessimisticFixpoint();
1544b0b32e64Ssstefan1 
1545b0b32e64Ssstefan1     // We only initialize this AA for getters, so we need to know which ICV it
1546b0b32e64Ssstefan1     // gets.
1547b0b32e64Ssstefan1     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
1548b0b32e64Ssstefan1     for (InternalControlVar ICV : TrackableICVs) {
1549b0b32e64Ssstefan1       auto ICVInfo = OMPInfoCache.ICVs[ICV];
1550b0b32e64Ssstefan1       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
1551b0b32e64Ssstefan1       if (Getter.Declaration == getAssociatedFunction()) {
1552b0b32e64Ssstefan1         AssociatedICV = ICVInfo.Kind;
1553b0b32e64Ssstefan1         return;
1554b0b32e64Ssstefan1       }
1555b0b32e64Ssstefan1     }
1556b0b32e64Ssstefan1 
1557b0b32e64Ssstefan1     /// Unknown ICV.
1558b0b32e64Ssstefan1     indicatePessimisticFixpoint();
1559b0b32e64Ssstefan1   }
1560b0b32e64Ssstefan1 
1561b0b32e64Ssstefan1   ChangeStatus manifest(Attributor &A) override {
1562b0b32e64Ssstefan1     if (!ReplVal.hasValue() || !ReplVal.getValue())
1563b0b32e64Ssstefan1       return ChangeStatus::UNCHANGED;
1564b0b32e64Ssstefan1 
1565b0b32e64Ssstefan1     A.changeValueAfterManifest(*getCtxI(), **ReplVal);
1566b0b32e64Ssstefan1     A.deleteAfterManifest(*getCtxI());
1567b0b32e64Ssstefan1 
1568b0b32e64Ssstefan1     return ChangeStatus::CHANGED;
1569b0b32e64Ssstefan1   }
1570b0b32e64Ssstefan1 
1571b0b32e64Ssstefan1   // FIXME: come up with better string.
1572b0b32e64Ssstefan1   const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
1573b0b32e64Ssstefan1 
1574b0b32e64Ssstefan1   // FIXME: come up with some stats.
1575b0b32e64Ssstefan1   void trackStatistics() const override {}
1576b0b32e64Ssstefan1 
1577b0b32e64Ssstefan1   InternalControlVar AssociatedICV;
1578b0b32e64Ssstefan1   Optional<Value *> ReplVal;
1579b0b32e64Ssstefan1 
1580b0b32e64Ssstefan1   ChangeStatus updateImpl(Attributor &A) override {
1581b0b32e64Ssstefan1     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
1582b0b32e64Ssstefan1         *this, IRPosition::function(*getAnchorScope()));
1583b0b32e64Ssstefan1 
1584b0b32e64Ssstefan1     // We don't have any information, so we assume it changes the ICV.
1585b0b32e64Ssstefan1     if (!ICVTrackingAA.isAssumedTracked())
1586b0b32e64Ssstefan1       return indicatePessimisticFixpoint();
1587b0b32e64Ssstefan1 
1588b0b32e64Ssstefan1     Optional<Value *> NewReplVal =
1589b0b32e64Ssstefan1         ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
1590b0b32e64Ssstefan1 
1591b0b32e64Ssstefan1     if (ReplVal == NewReplVal)
1592b0b32e64Ssstefan1       return ChangeStatus::UNCHANGED;
1593b0b32e64Ssstefan1 
1594b0b32e64Ssstefan1     ReplVal = NewReplVal;
1595b0b32e64Ssstefan1     return ChangeStatus::CHANGED;
1596b0b32e64Ssstefan1   }
1597b0b32e64Ssstefan1 
1598b0b32e64Ssstefan1   // Return the value with which associated value can be replaced for specific
1599b0b32e64Ssstefan1   // \p ICV.
1600b0b32e64Ssstefan1   Optional<Value *>
1601b0b32e64Ssstefan1   getUniqueReplacementValue(InternalControlVar ICV) const override {
1602b0b32e64Ssstefan1     return ReplVal;
1603b0b32e64Ssstefan1   }
1604b0b32e64Ssstefan1 };
1605b0b32e64Ssstefan1 
1606b0b32e64Ssstefan1 struct AAICVTrackerCallSiteReturned : AAICVTracker {
1607b0b32e64Ssstefan1   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
1608b0b32e64Ssstefan1       : AAICVTracker(IRP, A) {}
1609b0b32e64Ssstefan1 
1610b0b32e64Ssstefan1   // FIXME: come up with better string.
1611b0b32e64Ssstefan1   const std::string getAsStr() const override {
1612b0b32e64Ssstefan1     return "ICVTrackerCallSiteReturned";
1613b0b32e64Ssstefan1   }
1614b0b32e64Ssstefan1 
1615b0b32e64Ssstefan1   // FIXME: come up with some stats.
1616b0b32e64Ssstefan1   void trackStatistics() const override {}
1617b0b32e64Ssstefan1 
1618b0b32e64Ssstefan1   /// We don't manifest anything for this AA.
1619b0b32e64Ssstefan1   ChangeStatus manifest(Attributor &A) override {
1620b0b32e64Ssstefan1     return ChangeStatus::UNCHANGED;
1621b0b32e64Ssstefan1   }
1622b0b32e64Ssstefan1 
1623b0b32e64Ssstefan1   // Map of ICV to their values at specific program point.
1624b0b32e64Ssstefan1   EnumeratedArray<Optional<Value *>, InternalControlVar,
1625b0b32e64Ssstefan1                   InternalControlVar::ICV___last>
1626b0b32e64Ssstefan1       ICVReplacementValuesMap;
1627b0b32e64Ssstefan1 
1628b0b32e64Ssstefan1   /// Return the value with which associated value can be replaced for specific
1629b0b32e64Ssstefan1   /// \p ICV.
1630b0b32e64Ssstefan1   Optional<Value *>
1631b0b32e64Ssstefan1   getUniqueReplacementValue(InternalControlVar ICV) const override {
1632b0b32e64Ssstefan1     return ICVReplacementValuesMap[ICV];
1633b0b32e64Ssstefan1   }
1634b0b32e64Ssstefan1 
1635b0b32e64Ssstefan1   ChangeStatus updateImpl(Attributor &A) override {
1636b0b32e64Ssstefan1     ChangeStatus Changed = ChangeStatus::UNCHANGED;
1637b0b32e64Ssstefan1     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
1638b0b32e64Ssstefan1         *this, IRPosition::returned(*getAssociatedFunction()));
1639b0b32e64Ssstefan1 
1640b0b32e64Ssstefan1     // We don't have any information, so we assume it changes the ICV.
1641b0b32e64Ssstefan1     if (!ICVTrackingAA.isAssumedTracked())
1642b0b32e64Ssstefan1       return indicatePessimisticFixpoint();
1643b0b32e64Ssstefan1 
1644b0b32e64Ssstefan1     for (InternalControlVar ICV : TrackableICVs) {
1645b0b32e64Ssstefan1       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
1646b0b32e64Ssstefan1       Optional<Value *> NewReplVal =
1647b0b32e64Ssstefan1           ICVTrackingAA.getUniqueReplacementValue(ICV);
1648b0b32e64Ssstefan1 
1649b0b32e64Ssstefan1       if (ReplVal == NewReplVal)
1650b0b32e64Ssstefan1         continue;
1651b0b32e64Ssstefan1 
1652b0b32e64Ssstefan1       ReplVal = NewReplVal;
1653b0b32e64Ssstefan1       Changed = ChangeStatus::CHANGED;
1654b0b32e64Ssstefan1     }
1655b0b32e64Ssstefan1     return Changed;
1656b0b32e64Ssstefan1   }
16579548b74aSJohannes Doerfert };
16589548b74aSJohannes Doerfert } // namespace
16599548b74aSJohannes Doerfert 
1660b8235d2bSsstefan1 const char AAICVTracker::ID = 0;
1661b8235d2bSsstefan1 
1662b8235d2bSsstefan1 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
1663b8235d2bSsstefan1                                               Attributor &A) {
1664b8235d2bSsstefan1   AAICVTracker *AA = nullptr;
1665b8235d2bSsstefan1   switch (IRP.getPositionKind()) {
1666b8235d2bSsstefan1   case IRPosition::IRP_INVALID:
1667b8235d2bSsstefan1   case IRPosition::IRP_FLOAT:
1668b8235d2bSsstefan1   case IRPosition::IRP_ARGUMENT:
1669b8235d2bSsstefan1   case IRPosition::IRP_CALL_SITE_ARGUMENT:
1670b0b32e64Ssstefan1     llvm_unreachable("ICVTracker: invalid IRPosition!");
1671b8235d2bSsstefan1   case IRPosition::IRP_FUNCTION:
1672b8235d2bSsstefan1     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
1673b8235d2bSsstefan1     break;
1674b0b32e64Ssstefan1   case IRPosition::IRP_RETURNED:
1675b0b32e64Ssstefan1     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
1676b0b32e64Ssstefan1     break;
1677b0b32e64Ssstefan1   case IRPosition::IRP_CALL_SITE_RETURNED:
1678b0b32e64Ssstefan1     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
1679b0b32e64Ssstefan1     break;
1680b0b32e64Ssstefan1   case IRPosition::IRP_CALL_SITE:
1681b0b32e64Ssstefan1     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
1682b0b32e64Ssstefan1     break;
1683b8235d2bSsstefan1   }
1684b8235d2bSsstefan1 
1685b8235d2bSsstefan1   return *AA;
1686b8235d2bSsstefan1 }
1687b8235d2bSsstefan1 
16889548b74aSJohannes Doerfert PreservedAnalyses OpenMPOptPass::run(LazyCallGraph::SCC &C,
16899548b74aSJohannes Doerfert                                      CGSCCAnalysisManager &AM,
16909548b74aSJohannes Doerfert                                      LazyCallGraph &CG, CGSCCUpdateResult &UR) {
16919548b74aSJohannes Doerfert   if (!containsOpenMP(*C.begin()->getFunction().getParent(), OMPInModule))
16929548b74aSJohannes Doerfert     return PreservedAnalyses::all();
16939548b74aSJohannes Doerfert 
16949548b74aSJohannes Doerfert   if (DisableOpenMPOptimizations)
16959548b74aSJohannes Doerfert     return PreservedAnalyses::all();
16969548b74aSJohannes Doerfert 
1697ee17263aSJohannes Doerfert   SmallVector<Function *, 16> SCC;
1698351d234dSRoman Lebedev   // If there are kernels in the module, we have to run on all SCC's.
1699351d234dSRoman Lebedev   bool SCCIsInteresting = !OMPInModule.getKernels().empty();
1700351d234dSRoman Lebedev   for (LazyCallGraph::Node &N : C) {
1701351d234dSRoman Lebedev     Function *Fn = &N.getFunction();
1702351d234dSRoman Lebedev     SCC.push_back(Fn);
17039548b74aSJohannes Doerfert 
1704351d234dSRoman Lebedev     // Do we already know that the SCC contains kernels,
1705351d234dSRoman Lebedev     // or that OpenMP functions are called from this SCC?
1706351d234dSRoman Lebedev     if (SCCIsInteresting)
1707351d234dSRoman Lebedev       continue;
1708351d234dSRoman Lebedev     // If not, let's check that.
1709351d234dSRoman Lebedev     SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn);
1710351d234dSRoman Lebedev   }
1711351d234dSRoman Lebedev 
1712351d234dSRoman Lebedev   if (!SCCIsInteresting || SCC.empty())
17139548b74aSJohannes Doerfert     return PreservedAnalyses::all();
17149548b74aSJohannes Doerfert 
17154d4ea9acSHuber, Joseph   FunctionAnalysisManager &FAM =
17164d4ea9acSHuber, Joseph       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
17177cfd267cSsstefan1 
17187cfd267cSsstefan1   AnalysisGetter AG(FAM);
17197cfd267cSsstefan1 
17207cfd267cSsstefan1   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
17214d4ea9acSHuber, Joseph     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
17224d4ea9acSHuber, Joseph   };
17234d4ea9acSHuber, Joseph 
17249548b74aSJohannes Doerfert   CallGraphUpdater CGUpdater;
17259548b74aSJohannes Doerfert   CGUpdater.initialize(CG, C, AM, UR);
17267cfd267cSsstefan1 
17277cfd267cSsstefan1   SetVector<Function *> Functions(SCC.begin(), SCC.end());
17287cfd267cSsstefan1   BumpPtrAllocator Allocator;
17297cfd267cSsstefan1   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
1730624d34afSJohannes Doerfert                                 /*CGSCC*/ Functions, OMPInModule.getKernels());
17317cfd267cSsstefan1 
1732b0b32e64Ssstefan1   SetVector<Function *> ModuleSlice(InfoCache.ModuleSlice.begin(),
1733b0b32e64Ssstefan1                                     InfoCache.ModuleSlice.end());
1734b0b32e64Ssstefan1   Attributor A(ModuleSlice, InfoCache, CGUpdater);
1735b8235d2bSsstefan1 
17369548b74aSJohannes Doerfert   // TODO: Compute the module slice we are allowed to look at.
1737b8235d2bSsstefan1   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
17389548b74aSJohannes Doerfert   bool Changed = OMPOpt.run();
1739694ded37SGiorgis Georgakoudis   if (Changed)
1740694ded37SGiorgis Georgakoudis     return PreservedAnalyses::none();
1741694ded37SGiorgis Georgakoudis 
17429548b74aSJohannes Doerfert   return PreservedAnalyses::all();
17439548b74aSJohannes Doerfert }
17449548b74aSJohannes Doerfert 
17459548b74aSJohannes Doerfert namespace {
17469548b74aSJohannes Doerfert 
17479548b74aSJohannes Doerfert struct OpenMPOptLegacyPass : public CallGraphSCCPass {
17489548b74aSJohannes Doerfert   CallGraphUpdater CGUpdater;
17499548b74aSJohannes Doerfert   OpenMPInModule OMPInModule;
17509548b74aSJohannes Doerfert   static char ID;
17519548b74aSJohannes Doerfert 
17529548b74aSJohannes Doerfert   OpenMPOptLegacyPass() : CallGraphSCCPass(ID) {
17539548b74aSJohannes Doerfert     initializeOpenMPOptLegacyPassPass(*PassRegistry::getPassRegistry());
17549548b74aSJohannes Doerfert   }
17559548b74aSJohannes Doerfert 
17569548b74aSJohannes Doerfert   void getAnalysisUsage(AnalysisUsage &AU) const override {
17579548b74aSJohannes Doerfert     CallGraphSCCPass::getAnalysisUsage(AU);
17589548b74aSJohannes Doerfert   }
17599548b74aSJohannes Doerfert 
17609548b74aSJohannes Doerfert   bool doInitialization(CallGraph &CG) override {
17619548b74aSJohannes Doerfert     // Disable the pass if there is no OpenMP (runtime call) in the module.
17629548b74aSJohannes Doerfert     containsOpenMP(CG.getModule(), OMPInModule);
17639548b74aSJohannes Doerfert     return false;
17649548b74aSJohannes Doerfert   }
17659548b74aSJohannes Doerfert 
17669548b74aSJohannes Doerfert   bool runOnSCC(CallGraphSCC &CGSCC) override {
17679548b74aSJohannes Doerfert     if (!containsOpenMP(CGSCC.getCallGraph().getModule(), OMPInModule))
17689548b74aSJohannes Doerfert       return false;
17699548b74aSJohannes Doerfert     if (DisableOpenMPOptimizations || skipSCC(CGSCC))
17709548b74aSJohannes Doerfert       return false;
17719548b74aSJohannes Doerfert 
1772ee17263aSJohannes Doerfert     SmallVector<Function *, 16> SCC;
1773351d234dSRoman Lebedev     // If there are kernels in the module, we have to run on all SCC's.
1774351d234dSRoman Lebedev     bool SCCIsInteresting = !OMPInModule.getKernels().empty();
1775351d234dSRoman Lebedev     for (CallGraphNode *CGN : CGSCC) {
1776351d234dSRoman Lebedev       Function *Fn = CGN->getFunction();
1777351d234dSRoman Lebedev       if (!Fn || Fn->isDeclaration())
1778351d234dSRoman Lebedev         continue;
1779ee17263aSJohannes Doerfert       SCC.push_back(Fn);
17809548b74aSJohannes Doerfert 
1781351d234dSRoman Lebedev       // Do we already know that the SCC contains kernels,
1782351d234dSRoman Lebedev       // or that OpenMP functions are called from this SCC?
1783351d234dSRoman Lebedev       if (SCCIsInteresting)
1784351d234dSRoman Lebedev         continue;
1785351d234dSRoman Lebedev       // If not, let's check that.
1786351d234dSRoman Lebedev       SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn);
1787351d234dSRoman Lebedev     }
1788351d234dSRoman Lebedev 
1789351d234dSRoman Lebedev     if (!SCCIsInteresting || SCC.empty())
17909548b74aSJohannes Doerfert       return false;
17919548b74aSJohannes Doerfert 
17929548b74aSJohannes Doerfert     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
17939548b74aSJohannes Doerfert     CGUpdater.initialize(CG, CGSCC);
17949548b74aSJohannes Doerfert 
17954d4ea9acSHuber, Joseph     // Maintain a map of functions to avoid rebuilding the ORE
17964d4ea9acSHuber, Joseph     DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
17974d4ea9acSHuber, Joseph     auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
17984d4ea9acSHuber, Joseph       std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
17994d4ea9acSHuber, Joseph       if (!ORE)
18004d4ea9acSHuber, Joseph         ORE = std::make_unique<OptimizationRemarkEmitter>(F);
18014d4ea9acSHuber, Joseph       return *ORE;
18024d4ea9acSHuber, Joseph     };
18034d4ea9acSHuber, Joseph 
18047cfd267cSsstefan1     AnalysisGetter AG;
18057cfd267cSsstefan1     SetVector<Function *> Functions(SCC.begin(), SCC.end());
18067cfd267cSsstefan1     BumpPtrAllocator Allocator;
1807e8039ad4SJohannes Doerfert     OMPInformationCache InfoCache(
1808e8039ad4SJohannes Doerfert         *(Functions.back()->getParent()), AG, Allocator,
1809624d34afSJohannes Doerfert         /*CGSCC*/ Functions, OMPInModule.getKernels());
18107cfd267cSsstefan1 
1811b0b32e64Ssstefan1     SetVector<Function *> ModuleSlice(InfoCache.ModuleSlice.begin(),
1812b0b32e64Ssstefan1                                       InfoCache.ModuleSlice.end());
1813b0b32e64Ssstefan1     Attributor A(ModuleSlice, InfoCache, CGUpdater);
1814b8235d2bSsstefan1 
18159548b74aSJohannes Doerfert     // TODO: Compute the module slice we are allowed to look at.
1816b8235d2bSsstefan1     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
18179548b74aSJohannes Doerfert     return OMPOpt.run();
18189548b74aSJohannes Doerfert   }
18199548b74aSJohannes Doerfert 
18209548b74aSJohannes Doerfert   bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
18219548b74aSJohannes Doerfert };
18229548b74aSJohannes Doerfert 
18239548b74aSJohannes Doerfert } // end anonymous namespace
18249548b74aSJohannes Doerfert 
1825e8039ad4SJohannes Doerfert void OpenMPInModule::identifyKernels(Module &M) {
1826e8039ad4SJohannes Doerfert 
1827e8039ad4SJohannes Doerfert   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
1828e8039ad4SJohannes Doerfert   if (!MD)
1829e8039ad4SJohannes Doerfert     return;
1830e8039ad4SJohannes Doerfert 
1831e8039ad4SJohannes Doerfert   for (auto *Op : MD->operands()) {
1832e8039ad4SJohannes Doerfert     if (Op->getNumOperands() < 2)
1833e8039ad4SJohannes Doerfert       continue;
1834e8039ad4SJohannes Doerfert     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
1835e8039ad4SJohannes Doerfert     if (!KindID || KindID->getString() != "kernel")
1836e8039ad4SJohannes Doerfert       continue;
1837e8039ad4SJohannes Doerfert 
1838e8039ad4SJohannes Doerfert     Function *KernelFn =
1839e8039ad4SJohannes Doerfert         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
1840e8039ad4SJohannes Doerfert     if (!KernelFn)
1841e8039ad4SJohannes Doerfert       continue;
1842e8039ad4SJohannes Doerfert 
1843e8039ad4SJohannes Doerfert     ++NumOpenMPTargetRegionKernels;
1844e8039ad4SJohannes Doerfert 
1845e8039ad4SJohannes Doerfert     Kernels.insert(KernelFn);
1846e8039ad4SJohannes Doerfert   }
1847e8039ad4SJohannes Doerfert }
1848e8039ad4SJohannes Doerfert 
18499548b74aSJohannes Doerfert bool llvm::omp::containsOpenMP(Module &M, OpenMPInModule &OMPInModule) {
18509548b74aSJohannes Doerfert   if (OMPInModule.isKnown())
18519548b74aSJohannes Doerfert     return OMPInModule;
1852dce6bc18SJohannes Doerfert 
1853351d234dSRoman Lebedev   auto RecordFunctionsContainingUsesOf = [&](Function *F) {
1854351d234dSRoman Lebedev     for (User *U : F->users())
1855351d234dSRoman Lebedev       if (auto *I = dyn_cast<Instruction>(U))
1856351d234dSRoman Lebedev         OMPInModule.FuncsWithOMPRuntimeCalls.insert(I->getFunction());
1857351d234dSRoman Lebedev   };
1858351d234dSRoman Lebedev 
1859dce6bc18SJohannes Doerfert   // MSVC doesn't like long if-else chains for some reason and instead just
1860dce6bc18SJohannes Doerfert   // issues an error. Work around it..
1861dce6bc18SJohannes Doerfert   do {
18629548b74aSJohannes Doerfert #define OMP_RTL(_Enum, _Name, ...)                                             \
1863351d234dSRoman Lebedev   if (Function *F = M.getFunction(_Name)) {                                    \
1864351d234dSRoman Lebedev     RecordFunctionsContainingUsesOf(F);                                        \
1865dce6bc18SJohannes Doerfert     OMPInModule = true;                                                        \
1866dce6bc18SJohannes Doerfert   }
18679548b74aSJohannes Doerfert #include "llvm/Frontend/OpenMP/OMPKinds.def"
1868dce6bc18SJohannes Doerfert   } while (false);
1869e8039ad4SJohannes Doerfert 
1870e8039ad4SJohannes Doerfert   // Identify kernels once. TODO: We should split the OMPInformationCache into a
1871e8039ad4SJohannes Doerfert   // module and an SCC part. The kernel information, among other things, could
1872e8039ad4SJohannes Doerfert   // go into the module part.
1873e8039ad4SJohannes Doerfert   if (OMPInModule.isKnown() && OMPInModule) {
1874e8039ad4SJohannes Doerfert     OMPInModule.identifyKernels(M);
1875e8039ad4SJohannes Doerfert     return true;
1876e8039ad4SJohannes Doerfert   }
1877e8039ad4SJohannes Doerfert 
18789548b74aSJohannes Doerfert   return OMPInModule = false;
18799548b74aSJohannes Doerfert }
18809548b74aSJohannes Doerfert 
18819548b74aSJohannes Doerfert char OpenMPOptLegacyPass::ID = 0;
18829548b74aSJohannes Doerfert 
18839548b74aSJohannes Doerfert INITIALIZE_PASS_BEGIN(OpenMPOptLegacyPass, "openmpopt",
18849548b74aSJohannes Doerfert                       "OpenMP specific optimizations", false, false)
18859548b74aSJohannes Doerfert INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
18869548b74aSJohannes Doerfert INITIALIZE_PASS_END(OpenMPOptLegacyPass, "openmpopt",
18879548b74aSJohannes Doerfert                     "OpenMP specific optimizations", false, false)
18889548b74aSJohannes Doerfert 
18899548b74aSJohannes Doerfert Pass *llvm::createOpenMPOptLegacyPass() { return new OpenMPOptLegacyPass(); }
1890