19548b74aSJohannes Doerfert //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
29548b74aSJohannes Doerfert //
39548b74aSJohannes Doerfert // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49548b74aSJohannes Doerfert // See https://llvm.org/LICENSE.txt for license information.
59548b74aSJohannes Doerfert // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69548b74aSJohannes Doerfert //
79548b74aSJohannes Doerfert //===----------------------------------------------------------------------===//
89548b74aSJohannes Doerfert //
99548b74aSJohannes Doerfert // OpenMP specific optimizations:
109548b74aSJohannes Doerfert //
119548b74aSJohannes Doerfert // - Deduplication of runtime calls, e.g., omp_get_thread_num.
129548b74aSJohannes Doerfert //
139548b74aSJohannes Doerfert //===----------------------------------------------------------------------===//
149548b74aSJohannes Doerfert 
159548b74aSJohannes Doerfert #include "llvm/Transforms/IPO/OpenMPOpt.h"
169548b74aSJohannes Doerfert 
179548b74aSJohannes Doerfert #include "llvm/ADT/EnumeratedArray.h"
1818283125SJoseph Huber #include "llvm/ADT/PostOrderIterator.h"
199548b74aSJohannes Doerfert #include "llvm/ADT/Statistic.h"
209548b74aSJohannes Doerfert #include "llvm/Analysis/CallGraph.h"
219548b74aSJohannes Doerfert #include "llvm/Analysis/CallGraphSCCPass.h"
224d4ea9acSHuber, Joseph #include "llvm/Analysis/OptimizationRemarkEmitter.h"
233a6bfcf2SGiorgis Georgakoudis #include "llvm/Analysis/ValueTracking.h"
249548b74aSJohannes Doerfert #include "llvm/Frontend/OpenMP/OMPConstants.h"
25e28936f6SJohannes Doerfert #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
2668abc3d2SJoseph Huber #include "llvm/IR/IntrinsicInst.h"
2768abc3d2SJoseph Huber #include "llvm/IR/IntrinsicsAMDGPU.h"
2868abc3d2SJoseph Huber #include "llvm/IR/IntrinsicsNVPTX.h"
299548b74aSJohannes Doerfert #include "llvm/InitializePasses.h"
309548b74aSJohannes Doerfert #include "llvm/Support/CommandLine.h"
319548b74aSJohannes Doerfert #include "llvm/Transforms/IPO.h"
327cfd267cSsstefan1 #include "llvm/Transforms/IPO/Attributor.h"
333a6bfcf2SGiorgis Georgakoudis #include "llvm/Transforms/Utils/BasicBlockUtils.h"
349548b74aSJohannes Doerfert #include "llvm/Transforms/Utils/CallGraphUpdater.h"
3597517055SGiorgis Georgakoudis #include "llvm/Transforms/Utils/CodeExtractor.h"
369548b74aSJohannes Doerfert 
379548b74aSJohannes Doerfert using namespace llvm;
389548b74aSJohannes Doerfert using namespace omp;
399548b74aSJohannes Doerfert 
409548b74aSJohannes Doerfert #define DEBUG_TYPE "openmp-opt"
419548b74aSJohannes Doerfert 
429548b74aSJohannes Doerfert static cl::opt<bool> DisableOpenMPOptimizations(
439548b74aSJohannes Doerfert     "openmp-opt-disable", cl::ZeroOrMore,
449548b74aSJohannes Doerfert     cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
459548b74aSJohannes Doerfert     cl::init(false));
469548b74aSJohannes Doerfert 
473a6bfcf2SGiorgis Georgakoudis static cl::opt<bool> EnableParallelRegionMerging(
483a6bfcf2SGiorgis Georgakoudis     "openmp-opt-enable-merging", cl::ZeroOrMore,
493a6bfcf2SGiorgis Georgakoudis     cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
503a6bfcf2SGiorgis Georgakoudis     cl::init(false));
513a6bfcf2SGiorgis Georgakoudis 
520f426935Ssstefan1 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
530f426935Ssstefan1                                     cl::Hidden);
54e8039ad4SJohannes Doerfert static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
55e8039ad4SJohannes Doerfert                                         cl::init(false), cl::Hidden);
560f426935Ssstefan1 
57496f8e5bSHamilton Tobon Mosquera static cl::opt<bool> HideMemoryTransferLatency(
58496f8e5bSHamilton Tobon Mosquera     "openmp-hide-memory-transfer-latency",
59496f8e5bSHamilton Tobon Mosquera     cl::desc("[WIP] Tries to hide the latency of host to device memory"
60496f8e5bSHamilton Tobon Mosquera              " transfers"),
61496f8e5bSHamilton Tobon Mosquera     cl::Hidden, cl::init(false));
62496f8e5bSHamilton Tobon Mosquera 
// Counters for this pass, declared with the STATISTIC macro from
// llvm/ADT/Statistic.h. Each one is grouped under DEBUG_TYPE ("openmp-opt").
STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
// The next two are incremented in OMPInformationCache::collectUses when
// statistics collection is requested.
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");
STATISTIC(NumOpenMPTargetRegionKernels,
          "Number of OpenMP target region entry points (=kernels) identified");
STATISTIC(
    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
STATISTIC(NumOpenMPParallelRegionsMerged,
          "Number of OpenMP parallel regions merged");
789548b74aSJohannes Doerfert 
79263c4a3cSrathod-sahaab #if !defined(NDEBUG)
809548b74aSJohannes Doerfert static constexpr auto TAG = "[" DEBUG_TYPE "]";
81a50c0b0dSMikael Holmen #endif
829548b74aSJohannes Doerfert 
839548b74aSJohannes Doerfert namespace {
849548b74aSJohannes Doerfert 
8518283125SJoseph Huber struct AAExecutionDomain
8618283125SJoseph Huber     : public StateWrapper<BooleanState, AbstractAttribute> {
8718283125SJoseph Huber   using Base = StateWrapper<BooleanState, AbstractAttribute>;
8818283125SJoseph Huber   AAExecutionDomain(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
8918283125SJoseph Huber 
9018283125SJoseph Huber   /// Create an abstract attribute view for the position \p IRP.
9118283125SJoseph Huber   static AAExecutionDomain &createForPosition(const IRPosition &IRP,
9218283125SJoseph Huber                                               Attributor &A);
9318283125SJoseph Huber 
9418283125SJoseph Huber   /// See AbstractAttribute::getName().
9518283125SJoseph Huber   const std::string getName() const override { return "AAExecutionDomain"; }
9618283125SJoseph Huber 
9718283125SJoseph Huber   /// See AbstractAttribute::getIdAddr().
9818283125SJoseph Huber   const char *getIdAddr() const override { return &ID; }
9918283125SJoseph Huber 
10018283125SJoseph Huber   /// Check if an instruction is executed by a single thread.
10118283125SJoseph Huber   virtual bool isSingleThreadExecution(const Instruction &) const = 0;
10218283125SJoseph Huber 
10318283125SJoseph Huber   virtual bool isSingleThreadExecution(const BasicBlock &) const = 0;
10418283125SJoseph Huber 
10518283125SJoseph Huber   /// This function should return true if the type of the \p AA is
10618283125SJoseph Huber   /// AAExecutionDomain.
10718283125SJoseph Huber   static bool classof(const AbstractAttribute *AA) {
10818283125SJoseph Huber     return (AA->getIdAddr() == &ID);
10918283125SJoseph Huber   }
11018283125SJoseph Huber 
11118283125SJoseph Huber   /// Unique ID (due to the unique address)
11218283125SJoseph Huber   static const char ID;
11318283125SJoseph Huber };
11418283125SJoseph Huber 
// Forward declaration so AAICVTracker can be referred to before its
// definition.
struct AAICVTracker;
116b8235d2bSsstefan1 
1177cfd267cSsstefan1 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
1187cfd267cSsstefan1 /// Attributor runs.
1197cfd267cSsstefan1 struct OMPInformationCache : public InformationCache {
1207cfd267cSsstefan1   OMPInformationCache(Module &M, AnalysisGetter &AG,
121624d34afSJohannes Doerfert                       BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
122e8039ad4SJohannes Doerfert                       SmallPtrSetImpl<Kernel> &Kernels)
123624d34afSJohannes Doerfert       : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
124624d34afSJohannes Doerfert         Kernels(Kernels) {
125624d34afSJohannes Doerfert 
12661238d26Ssstefan1     OMPBuilder.initialize();
1279548b74aSJohannes Doerfert     initializeRuntimeFunctions();
1280f426935Ssstefan1     initializeInternalControlVars();
1299548b74aSJohannes Doerfert   }
1309548b74aSJohannes Doerfert 
1310f426935Ssstefan1   /// Generic information that describes an internal control variable.
1320f426935Ssstefan1   struct InternalControlVarInfo {
1330f426935Ssstefan1     /// The kind, as described by InternalControlVar enum.
1340f426935Ssstefan1     InternalControlVar Kind;
1350f426935Ssstefan1 
1360f426935Ssstefan1     /// The name of the ICV.
1370f426935Ssstefan1     StringRef Name;
1380f426935Ssstefan1 
1390f426935Ssstefan1     /// Environment variable associated with this ICV.
1400f426935Ssstefan1     StringRef EnvVarName;
1410f426935Ssstefan1 
1420f426935Ssstefan1     /// Initial value kind.
1430f426935Ssstefan1     ICVInitValue InitKind;
1440f426935Ssstefan1 
1450f426935Ssstefan1     /// Initial value.
1460f426935Ssstefan1     ConstantInt *InitValue;
1470f426935Ssstefan1 
1480f426935Ssstefan1     /// Setter RTL function associated with this ICV.
1490f426935Ssstefan1     RuntimeFunction Setter;
1500f426935Ssstefan1 
1510f426935Ssstefan1     /// Getter RTL function associated with this ICV.
1520f426935Ssstefan1     RuntimeFunction Getter;
1530f426935Ssstefan1 
1540f426935Ssstefan1     /// RTL Function corresponding to the override clause of this ICV
1550f426935Ssstefan1     RuntimeFunction Clause;
1560f426935Ssstefan1   };
1570f426935Ssstefan1 
1589548b74aSJohannes Doerfert   /// Generic information that describes a runtime function
1599548b74aSJohannes Doerfert   struct RuntimeFunctionInfo {
1608855fec3SJohannes Doerfert 
1619548b74aSJohannes Doerfert     /// The kind, as described by the RuntimeFunction enum.
1629548b74aSJohannes Doerfert     RuntimeFunction Kind;
1639548b74aSJohannes Doerfert 
1649548b74aSJohannes Doerfert     /// The name of the function.
1659548b74aSJohannes Doerfert     StringRef Name;
1669548b74aSJohannes Doerfert 
1679548b74aSJohannes Doerfert     /// Flag to indicate a variadic function.
1689548b74aSJohannes Doerfert     bool IsVarArg;
1699548b74aSJohannes Doerfert 
1709548b74aSJohannes Doerfert     /// The return type of the function.
1719548b74aSJohannes Doerfert     Type *ReturnType;
1729548b74aSJohannes Doerfert 
1739548b74aSJohannes Doerfert     /// The argument types of the function.
1749548b74aSJohannes Doerfert     SmallVector<Type *, 8> ArgumentTypes;
1759548b74aSJohannes Doerfert 
1769548b74aSJohannes Doerfert     /// The declaration if available.
177f09f4b26SJohannes Doerfert     Function *Declaration = nullptr;
1789548b74aSJohannes Doerfert 
1799548b74aSJohannes Doerfert     /// Uses of this runtime function per function containing the use.
1808855fec3SJohannes Doerfert     using UseVector = SmallVector<Use *, 16>;
1818855fec3SJohannes Doerfert 
182b8235d2bSsstefan1     /// Clear UsesMap for runtime function.
183b8235d2bSsstefan1     void clearUsesMap() { UsesMap.clear(); }
184b8235d2bSsstefan1 
18554bd3751SJohannes Doerfert     /// Boolean conversion that is true if the runtime function was found.
18654bd3751SJohannes Doerfert     operator bool() const { return Declaration; }
18754bd3751SJohannes Doerfert 
1888855fec3SJohannes Doerfert     /// Return the vector of uses in function \p F.
1898855fec3SJohannes Doerfert     UseVector &getOrCreateUseVector(Function *F) {
190b8235d2bSsstefan1       std::shared_ptr<UseVector> &UV = UsesMap[F];
1918855fec3SJohannes Doerfert       if (!UV)
192b8235d2bSsstefan1         UV = std::make_shared<UseVector>();
1938855fec3SJohannes Doerfert       return *UV;
1948855fec3SJohannes Doerfert     }
1958855fec3SJohannes Doerfert 
1968855fec3SJohannes Doerfert     /// Return the vector of uses in function \p F or `nullptr` if there are
1978855fec3SJohannes Doerfert     /// none.
1988855fec3SJohannes Doerfert     const UseVector *getUseVector(Function &F) const {
19995e57072SDavid Blaikie       auto I = UsesMap.find(&F);
20095e57072SDavid Blaikie       if (I != UsesMap.end())
20195e57072SDavid Blaikie         return I->second.get();
20295e57072SDavid Blaikie       return nullptr;
2038855fec3SJohannes Doerfert     }
2048855fec3SJohannes Doerfert 
2058855fec3SJohannes Doerfert     /// Return how many functions contain uses of this runtime function.
2068855fec3SJohannes Doerfert     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
2079548b74aSJohannes Doerfert 
2089548b74aSJohannes Doerfert     /// Return the number of arguments (or the minimal number for variadic
2099548b74aSJohannes Doerfert     /// functions).
2109548b74aSJohannes Doerfert     size_t getNumArgs() const { return ArgumentTypes.size(); }
2119548b74aSJohannes Doerfert 
2129548b74aSJohannes Doerfert     /// Run the callback \p CB on each use and forget the use if the result is
2139548b74aSJohannes Doerfert     /// true. The callback will be fed the function in which the use was
2149548b74aSJohannes Doerfert     /// encountered as second argument.
215624d34afSJohannes Doerfert     void foreachUse(SmallVectorImpl<Function *> &SCC,
216624d34afSJohannes Doerfert                     function_ref<bool(Use &, Function &)> CB) {
217624d34afSJohannes Doerfert       for (Function *F : SCC)
218624d34afSJohannes Doerfert         foreachUse(CB, F);
219e099c7b6Ssstefan1     }
220e099c7b6Ssstefan1 
221e099c7b6Ssstefan1     /// Run the callback \p CB on each use within the function \p F and forget
222e099c7b6Ssstefan1     /// the use if the result is true.
223624d34afSJohannes Doerfert     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
2248855fec3SJohannes Doerfert       SmallVector<unsigned, 8> ToBeDeleted;
2259548b74aSJohannes Doerfert       ToBeDeleted.clear();
226e099c7b6Ssstefan1 
2278855fec3SJohannes Doerfert       unsigned Idx = 0;
228624d34afSJohannes Doerfert       UseVector &UV = getOrCreateUseVector(F);
229e099c7b6Ssstefan1 
2308855fec3SJohannes Doerfert       for (Use *U : UV) {
231e099c7b6Ssstefan1         if (CB(*U, *F))
2328855fec3SJohannes Doerfert           ToBeDeleted.push_back(Idx);
2338855fec3SJohannes Doerfert         ++Idx;
2348855fec3SJohannes Doerfert       }
2358855fec3SJohannes Doerfert 
2368855fec3SJohannes Doerfert       // Remove the to-be-deleted indices in reverse order as prior
237b726c557SJohannes Doerfert       // modifications will not modify the smaller indices.
2388855fec3SJohannes Doerfert       while (!ToBeDeleted.empty()) {
2398855fec3SJohannes Doerfert         unsigned Idx = ToBeDeleted.pop_back_val();
2408855fec3SJohannes Doerfert         UV[Idx] = UV.back();
2418855fec3SJohannes Doerfert         UV.pop_back();
2429548b74aSJohannes Doerfert       }
2439548b74aSJohannes Doerfert     }
2448855fec3SJohannes Doerfert 
2458855fec3SJohannes Doerfert   private:
2468855fec3SJohannes Doerfert     /// Map from functions to all uses of this runtime function contained in
2478855fec3SJohannes Doerfert     /// them.
248b8235d2bSsstefan1     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
2499548b74aSJohannes Doerfert   };
2509548b74aSJohannes Doerfert 
2517cfd267cSsstefan1   /// An OpenMP-IR-Builder instance
2527cfd267cSsstefan1   OpenMPIRBuilder OMPBuilder;
2537cfd267cSsstefan1 
2547cfd267cSsstefan1   /// Map from runtime function kind to the runtime function description.
2557cfd267cSsstefan1   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
2567cfd267cSsstefan1                   RuntimeFunction::OMPRTL___last>
2577cfd267cSsstefan1       RFIs;
2587cfd267cSsstefan1 
2590f426935Ssstefan1   /// Map from ICV kind to the ICV description.
2600f426935Ssstefan1   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
2610f426935Ssstefan1                   InternalControlVar::ICV___last>
2620f426935Ssstefan1       ICVs;
2630f426935Ssstefan1 
2640f426935Ssstefan1   /// Helper to initialize all internal control variable information for those
2650f426935Ssstefan1   /// defined in OMPKinds.def.
2660f426935Ssstefan1   void initializeInternalControlVars() {
2670f426935Ssstefan1 #define ICV_RT_SET(_Name, RTL)                                                 \
2680f426935Ssstefan1   {                                                                            \
2690f426935Ssstefan1     auto &ICV = ICVs[_Name];                                                   \
2700f426935Ssstefan1     ICV.Setter = RTL;                                                          \
2710f426935Ssstefan1   }
2720f426935Ssstefan1 #define ICV_RT_GET(Name, RTL)                                                  \
2730f426935Ssstefan1   {                                                                            \
2740f426935Ssstefan1     auto &ICV = ICVs[Name];                                                    \
2750f426935Ssstefan1     ICV.Getter = RTL;                                                          \
2760f426935Ssstefan1   }
2770f426935Ssstefan1 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
2780f426935Ssstefan1   {                                                                            \
2790f426935Ssstefan1     auto &ICV = ICVs[Enum];                                                    \
2800f426935Ssstefan1     ICV.Name = _Name;                                                          \
2810f426935Ssstefan1     ICV.Kind = Enum;                                                           \
2820f426935Ssstefan1     ICV.InitKind = Init;                                                       \
2830f426935Ssstefan1     ICV.EnvVarName = _EnvVarName;                                              \
2840f426935Ssstefan1     switch (ICV.InitKind) {                                                    \
285951e43f3Ssstefan1     case ICV_IMPLEMENTATION_DEFINED:                                           \
2860f426935Ssstefan1       ICV.InitValue = nullptr;                                                 \
2870f426935Ssstefan1       break;                                                                   \
288951e43f3Ssstefan1     case ICV_ZERO:                                                             \
2896aab27baSsstefan1       ICV.InitValue = ConstantInt::get(                                        \
2906aab27baSsstefan1           Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
2910f426935Ssstefan1       break;                                                                   \
292951e43f3Ssstefan1     case ICV_FALSE:                                                            \
2936aab27baSsstefan1       ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
2940f426935Ssstefan1       break;                                                                   \
295951e43f3Ssstefan1     case ICV_LAST:                                                             \
2960f426935Ssstefan1       break;                                                                   \
2970f426935Ssstefan1     }                                                                          \
2980f426935Ssstefan1   }
2990f426935Ssstefan1 #include "llvm/Frontend/OpenMP/OMPKinds.def"
3000f426935Ssstefan1   }
3010f426935Ssstefan1 
3027cfd267cSsstefan1   /// Returns true if the function declaration \p F matches the runtime
3037cfd267cSsstefan1   /// function types, that is, return type \p RTFRetType, and argument types
3047cfd267cSsstefan1   /// \p RTFArgTypes.
3057cfd267cSsstefan1   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
3067cfd267cSsstefan1                                   SmallVector<Type *, 8> &RTFArgTypes) {
3077cfd267cSsstefan1     // TODO: We should output information to the user (under debug output
3087cfd267cSsstefan1     //       and via remarks).
3097cfd267cSsstefan1 
3107cfd267cSsstefan1     if (!F)
3117cfd267cSsstefan1       return false;
3127cfd267cSsstefan1     if (F->getReturnType() != RTFRetType)
3137cfd267cSsstefan1       return false;
3147cfd267cSsstefan1     if (F->arg_size() != RTFArgTypes.size())
3157cfd267cSsstefan1       return false;
3167cfd267cSsstefan1 
3177cfd267cSsstefan1     auto RTFTyIt = RTFArgTypes.begin();
3187cfd267cSsstefan1     for (Argument &Arg : F->args()) {
3197cfd267cSsstefan1       if (Arg.getType() != *RTFTyIt)
3207cfd267cSsstefan1         return false;
3217cfd267cSsstefan1 
3227cfd267cSsstefan1       ++RTFTyIt;
3237cfd267cSsstefan1     }
3247cfd267cSsstefan1 
3257cfd267cSsstefan1     return true;
3267cfd267cSsstefan1   }
3277cfd267cSsstefan1 
328b726c557SJohannes Doerfert   // Helper to collect all uses of the declaration in the UsesMap.
329b8235d2bSsstefan1   unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
3307cfd267cSsstefan1     unsigned NumUses = 0;
3317cfd267cSsstefan1     if (!RFI.Declaration)
3327cfd267cSsstefan1       return NumUses;
3337cfd267cSsstefan1     OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
3347cfd267cSsstefan1 
335b8235d2bSsstefan1     if (CollectStats) {
3367cfd267cSsstefan1       NumOpenMPRuntimeFunctionsIdentified += 1;
3377cfd267cSsstefan1       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
338b8235d2bSsstefan1     }
3397cfd267cSsstefan1 
3407cfd267cSsstefan1     // TODO: We directly convert uses into proper calls and unknown uses.
3417cfd267cSsstefan1     for (Use &U : RFI.Declaration->uses()) {
3427cfd267cSsstefan1       if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
3437cfd267cSsstefan1         if (ModuleSlice.count(UserI->getFunction())) {
3447cfd267cSsstefan1           RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
3457cfd267cSsstefan1           ++NumUses;
3467cfd267cSsstefan1         }
3477cfd267cSsstefan1       } else {
3487cfd267cSsstefan1         RFI.getOrCreateUseVector(nullptr).push_back(&U);
3497cfd267cSsstefan1         ++NumUses;
3507cfd267cSsstefan1       }
3517cfd267cSsstefan1     }
3527cfd267cSsstefan1     return NumUses;
353b8235d2bSsstefan1   }
3547cfd267cSsstefan1 
35597517055SGiorgis Georgakoudis   // Helper function to recollect uses of a runtime function.
35697517055SGiorgis Georgakoudis   void recollectUsesForFunction(RuntimeFunction RTF) {
35797517055SGiorgis Georgakoudis     auto &RFI = RFIs[RTF];
358b8235d2bSsstefan1     RFI.clearUsesMap();
359b8235d2bSsstefan1     collectUses(RFI, /*CollectStats*/ false);
360b8235d2bSsstefan1   }
36197517055SGiorgis Georgakoudis 
36297517055SGiorgis Georgakoudis   // Helper function to recollect uses of all runtime functions.
36397517055SGiorgis Georgakoudis   void recollectUses() {
36497517055SGiorgis Georgakoudis     for (int Idx = 0; Idx < RFIs.size(); ++Idx)
36597517055SGiorgis Georgakoudis       recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
366b8235d2bSsstefan1   }
367b8235d2bSsstefan1 
368b8235d2bSsstefan1   /// Helper to initialize all runtime function information for those defined
369b8235d2bSsstefan1   /// in OpenMPKinds.def.
370b8235d2bSsstefan1   void initializeRuntimeFunctions() {
3717cfd267cSsstefan1     Module &M = *((*ModuleSlice.begin())->getParent());
3727cfd267cSsstefan1 
3736aab27baSsstefan1     // Helper macros for handling __VA_ARGS__ in OMP_RTL
3746aab27baSsstefan1 #define OMP_TYPE(VarName, ...)                                                 \
3756aab27baSsstefan1   Type *VarName = OMPBuilder.VarName;                                          \
3766aab27baSsstefan1   (void)VarName;
3776aab27baSsstefan1 
3786aab27baSsstefan1 #define OMP_ARRAY_TYPE(VarName, ...)                                           \
3796aab27baSsstefan1   ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
3806aab27baSsstefan1   (void)VarName##Ty;                                                           \
3816aab27baSsstefan1   PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
3826aab27baSsstefan1   (void)VarName##PtrTy;
3836aab27baSsstefan1 
3846aab27baSsstefan1 #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
3856aab27baSsstefan1   FunctionType *VarName = OMPBuilder.VarName;                                  \
3866aab27baSsstefan1   (void)VarName;                                                               \
3876aab27baSsstefan1   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
3886aab27baSsstefan1   (void)VarName##Ptr;
3896aab27baSsstefan1 
3906aab27baSsstefan1 #define OMP_STRUCT_TYPE(VarName, ...)                                          \
3916aab27baSsstefan1   StructType *VarName = OMPBuilder.VarName;                                    \
3926aab27baSsstefan1   (void)VarName;                                                               \
3936aab27baSsstefan1   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
3946aab27baSsstefan1   (void)VarName##Ptr;
3956aab27baSsstefan1 
3967cfd267cSsstefan1 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
3977cfd267cSsstefan1   {                                                                            \
3987cfd267cSsstefan1     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
3997cfd267cSsstefan1     Function *F = M.getFunction(_Name);                                        \
4006aab27baSsstefan1     if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
4017cfd267cSsstefan1       auto &RFI = RFIs[_Enum];                                                 \
4027cfd267cSsstefan1       RFI.Kind = _Enum;                                                        \
4037cfd267cSsstefan1       RFI.Name = _Name;                                                        \
4047cfd267cSsstefan1       RFI.IsVarArg = _IsVarArg;                                                \
4056aab27baSsstefan1       RFI.ReturnType = OMPBuilder._ReturnType;                                 \
4067cfd267cSsstefan1       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
4077cfd267cSsstefan1       RFI.Declaration = F;                                                     \
408b8235d2bSsstefan1       unsigned NumUses = collectUses(RFI);                                     \
4097cfd267cSsstefan1       (void)NumUses;                                                           \
4107cfd267cSsstefan1       LLVM_DEBUG({                                                             \
4117cfd267cSsstefan1         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
4127cfd267cSsstefan1                << " found\n";                                                  \
4137cfd267cSsstefan1         if (RFI.Declaration)                                                   \
4147cfd267cSsstefan1           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
4157cfd267cSsstefan1                  << RFI.getNumFunctionsWithUses()                              \
4167cfd267cSsstefan1                  << " different functions.\n";                                 \
4177cfd267cSsstefan1       });                                                                      \
4187cfd267cSsstefan1     }                                                                          \
4197cfd267cSsstefan1   }
4207cfd267cSsstefan1 #include "llvm/Frontend/OpenMP/OMPKinds.def"
4217cfd267cSsstefan1 
4227cfd267cSsstefan1     // TODO: We should attach the attributes defined in OMPKinds.def.
4237cfd267cSsstefan1   }
424e8039ad4SJohannes Doerfert 
425e8039ad4SJohannes Doerfert   /// Collection of known kernels (\see Kernel) in the module.
426e8039ad4SJohannes Doerfert   SmallPtrSetImpl<Kernel> &Kernels;
4277cfd267cSsstefan1 };
4287cfd267cSsstefan1 
4298931add6SHamilton Tobon Mosquera /// Used to map the values physically (in the IR) stored in an offload
4308931add6SHamilton Tobon Mosquera /// array, to a vector in memory.
4318931add6SHamilton Tobon Mosquera struct OffloadArray {
4328931add6SHamilton Tobon Mosquera   /// Physical array (in the IR).
4338931add6SHamilton Tobon Mosquera   AllocaInst *Array = nullptr;
4348931add6SHamilton Tobon Mosquera   /// Mapped values.
4358931add6SHamilton Tobon Mosquera   SmallVector<Value *, 8> StoredValues;
4368931add6SHamilton Tobon Mosquera   /// Last stores made in the offload array.
4378931add6SHamilton Tobon Mosquera   SmallVector<StoreInst *, 8> LastAccesses;
4388931add6SHamilton Tobon Mosquera 
4398931add6SHamilton Tobon Mosquera   OffloadArray() = default;
4408931add6SHamilton Tobon Mosquera 
4418931add6SHamilton Tobon Mosquera   /// Initializes the OffloadArray with the values stored in \p Array before
4428931add6SHamilton Tobon Mosquera   /// instruction \p Before is reached. Returns false if the initialization
4438931add6SHamilton Tobon Mosquera   /// fails.
4448931add6SHamilton Tobon Mosquera   /// This MUST be used immediately after the construction of the object.
4458931add6SHamilton Tobon Mosquera   bool initialize(AllocaInst &Array, Instruction &Before) {
4468931add6SHamilton Tobon Mosquera     if (!Array.getAllocatedType()->isArrayTy())
4478931add6SHamilton Tobon Mosquera       return false;
4488931add6SHamilton Tobon Mosquera 
4498931add6SHamilton Tobon Mosquera     if (!getValues(Array, Before))
4508931add6SHamilton Tobon Mosquera       return false;
4518931add6SHamilton Tobon Mosquera 
4528931add6SHamilton Tobon Mosquera     this->Array = &Array;
4538931add6SHamilton Tobon Mosquera     return true;
4548931add6SHamilton Tobon Mosquera   }
4558931add6SHamilton Tobon Mosquera 
456da8bec47SJoseph Huber   static const unsigned DeviceIDArgNum = 1;
457da8bec47SJoseph Huber   static const unsigned BasePtrsArgNum = 3;
458da8bec47SJoseph Huber   static const unsigned PtrsArgNum = 4;
459da8bec47SJoseph Huber   static const unsigned SizesArgNum = 5;
4601d3d9b9cSHamilton Tobon Mosquera 
4618931add6SHamilton Tobon Mosquera private:
4628931add6SHamilton Tobon Mosquera   /// Traverses the BasicBlock where \p Array is, collecting the stores made to
4638931add6SHamilton Tobon Mosquera   /// \p Array, leaving StoredValues with the values stored before the
4648931add6SHamilton Tobon Mosquera   /// instruction \p Before is reached.
4658931add6SHamilton Tobon Mosquera   bool getValues(AllocaInst &Array, Instruction &Before) {
4668931add6SHamilton Tobon Mosquera     // Initialize container.
467d08d490aSJohannes Doerfert     const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
4688931add6SHamilton Tobon Mosquera     StoredValues.assign(NumValues, nullptr);
4698931add6SHamilton Tobon Mosquera     LastAccesses.assign(NumValues, nullptr);
4708931add6SHamilton Tobon Mosquera 
4718931add6SHamilton Tobon Mosquera     // TODO: This assumes the instruction \p Before is in the same
4728931add6SHamilton Tobon Mosquera     //  BasicBlock as Array. Make it general, for any control flow graph.
4738931add6SHamilton Tobon Mosquera     BasicBlock *BB = Array.getParent();
4748931add6SHamilton Tobon Mosquera     if (BB != Before.getParent())
4758931add6SHamilton Tobon Mosquera       return false;
4768931add6SHamilton Tobon Mosquera 
4778931add6SHamilton Tobon Mosquera     const DataLayout &DL = Array.getModule()->getDataLayout();
4788931add6SHamilton Tobon Mosquera     const unsigned int PointerSize = DL.getPointerSize();
4798931add6SHamilton Tobon Mosquera 
4808931add6SHamilton Tobon Mosquera     for (Instruction &I : *BB) {
4818931add6SHamilton Tobon Mosquera       if (&I == &Before)
4828931add6SHamilton Tobon Mosquera         break;
4838931add6SHamilton Tobon Mosquera 
4848931add6SHamilton Tobon Mosquera       if (!isa<StoreInst>(&I))
4858931add6SHamilton Tobon Mosquera         continue;
4868931add6SHamilton Tobon Mosquera 
4878931add6SHamilton Tobon Mosquera       auto *S = cast<StoreInst>(&I);
4888931add6SHamilton Tobon Mosquera       int64_t Offset = -1;
489d08d490aSJohannes Doerfert       auto *Dst =
490d08d490aSJohannes Doerfert           GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
4918931add6SHamilton Tobon Mosquera       if (Dst == &Array) {
4928931add6SHamilton Tobon Mosquera         int64_t Idx = Offset / PointerSize;
4938931add6SHamilton Tobon Mosquera         StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
4948931add6SHamilton Tobon Mosquera         LastAccesses[Idx] = S;
4958931add6SHamilton Tobon Mosquera       }
4968931add6SHamilton Tobon Mosquera     }
4978931add6SHamilton Tobon Mosquera 
4988931add6SHamilton Tobon Mosquera     return isFilled();
4998931add6SHamilton Tobon Mosquera   }
5008931add6SHamilton Tobon Mosquera 
5018931add6SHamilton Tobon Mosquera   /// Returns true if all values in StoredValues and
5028931add6SHamilton Tobon Mosquera   /// LastAccesses are not nullptrs.
5038931add6SHamilton Tobon Mosquera   bool isFilled() {
5048931add6SHamilton Tobon Mosquera     const unsigned NumValues = StoredValues.size();
5058931add6SHamilton Tobon Mosquera     for (unsigned I = 0; I < NumValues; ++I) {
5068931add6SHamilton Tobon Mosquera       if (!StoredValues[I] || !LastAccesses[I])
5078931add6SHamilton Tobon Mosquera         return false;
5088931add6SHamilton Tobon Mosquera     }
5098931add6SHamilton Tobon Mosquera 
5108931add6SHamilton Tobon Mosquera     return true;
5118931add6SHamilton Tobon Mosquera   }
5128931add6SHamilton Tobon Mosquera };
5138931add6SHamilton Tobon Mosquera 
5147cfd267cSsstefan1 struct OpenMPOpt {
5157cfd267cSsstefan1 
5167cfd267cSsstefan1   using OptimizationRemarkGetter =
5177cfd267cSsstefan1       function_ref<OptimizationRemarkEmitter &(Function *)>;
5187cfd267cSsstefan1 
5197cfd267cSsstefan1   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
5207cfd267cSsstefan1             OptimizationRemarkGetter OREGetter,
521b8235d2bSsstefan1             OMPInformationCache &OMPInfoCache, Attributor &A)
52277b79d79SMehdi Amini       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
523b8235d2bSsstefan1         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
5247cfd267cSsstefan1 
525a2281419SJoseph Huber   /// Check if any remarks are enabled for openmp-opt
526a2281419SJoseph Huber   bool remarksEnabled() {
527a2281419SJoseph Huber     auto &Ctx = M.getContext();
528a2281419SJoseph Huber     return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
529a2281419SJoseph Huber   }
530a2281419SJoseph Huber 
5319548b74aSJohannes Doerfert   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
532b2ad63d3SJoseph Huber   bool run(bool IsModulePass) {
53354bd3751SJohannes Doerfert     if (SCC.empty())
53454bd3751SJohannes Doerfert       return false;
53554bd3751SJohannes Doerfert 
5369548b74aSJohannes Doerfert     bool Changed = false;
5379548b74aSJohannes Doerfert 
5389548b74aSJohannes Doerfert     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
53977b79d79SMehdi Amini                       << " functions in a slice with "
54077b79d79SMehdi Amini                       << OMPInfoCache.ModuleSlice.size() << " functions\n");
5419548b74aSJohannes Doerfert 
542b2ad63d3SJoseph Huber     if (IsModulePass) {
54318283125SJoseph Huber       Changed |= runAttributor();
54418283125SJoseph Huber 
545b2ad63d3SJoseph Huber       if (remarksEnabled())
546b2ad63d3SJoseph Huber         analysisGlobalization();
547b2ad63d3SJoseph Huber     } else {
548e8039ad4SJohannes Doerfert       if (PrintICVValues)
549e8039ad4SJohannes Doerfert         printICVs();
550e8039ad4SJohannes Doerfert       if (PrintOpenMPKernels)
551e8039ad4SJohannes Doerfert         printKernels();
552e8039ad4SJohannes Doerfert 
5535b0581aeSJohannes Doerfert       Changed |= rewriteDeviceCodeStateMachine();
5545b0581aeSJohannes Doerfert 
555e8039ad4SJohannes Doerfert       Changed |= runAttributor();
556e8039ad4SJohannes Doerfert 
557e8039ad4SJohannes Doerfert       // Recollect uses, in case Attributor deleted any.
558e8039ad4SJohannes Doerfert       OMPInfoCache.recollectUses();
559e8039ad4SJohannes Doerfert 
560e8039ad4SJohannes Doerfert       Changed |= deleteParallelRegions();
561496f8e5bSHamilton Tobon Mosquera       if (HideMemoryTransferLatency)
562496f8e5bSHamilton Tobon Mosquera         Changed |= hideMemTransfersLatency();
5633a6bfcf2SGiorgis Georgakoudis       Changed |= deduplicateRuntimeCalls();
5643a6bfcf2SGiorgis Georgakoudis       if (EnableParallelRegionMerging) {
5653a6bfcf2SGiorgis Georgakoudis         if (mergeParallelRegions()) {
5663a6bfcf2SGiorgis Georgakoudis           deduplicateRuntimeCalls();
5673a6bfcf2SGiorgis Georgakoudis           Changed = true;
5683a6bfcf2SGiorgis Georgakoudis         }
5693a6bfcf2SGiorgis Georgakoudis       }
570b2ad63d3SJoseph Huber     }
571e8039ad4SJohannes Doerfert 
572e8039ad4SJohannes Doerfert     return Changed;
573e8039ad4SJohannes Doerfert   }
574e8039ad4SJohannes Doerfert 
5750f426935Ssstefan1   /// Print initial ICV values for testing.
5760f426935Ssstefan1   /// FIXME: This should be done from the Attributor once it is added.
577e8039ad4SJohannes Doerfert   void printICVs() const {
578cb9cfa0dSsstefan1     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
579cb9cfa0dSsstefan1                                  ICV_proc_bind};
5800f426935Ssstefan1 
5810f426935Ssstefan1     for (Function *F : OMPInfoCache.ModuleSlice) {
5820f426935Ssstefan1       for (auto ICV : ICVs) {
5830f426935Ssstefan1         auto ICVInfo = OMPInfoCache.ICVs[ICV];
584*2db182ffSJoseph Huber         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
585*2db182ffSJoseph Huber           return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
5860f426935Ssstefan1                      << " Value: "
5870f426935Ssstefan1                      << (ICVInfo.InitValue
5880f426935Ssstefan1                              ? ICVInfo.InitValue->getValue().toString(10, true)
5890f426935Ssstefan1                              : "IMPLEMENTATION_DEFINED");
5900f426935Ssstefan1         };
5910f426935Ssstefan1 
592*2db182ffSJoseph Huber         emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
5930f426935Ssstefan1       }
5940f426935Ssstefan1     }
5950f426935Ssstefan1   }
5960f426935Ssstefan1 
597e8039ad4SJohannes Doerfert   /// Print OpenMP GPU kernels for testing.
598e8039ad4SJohannes Doerfert   void printKernels() const {
599e8039ad4SJohannes Doerfert     for (Function *F : SCC) {
600e8039ad4SJohannes Doerfert       if (!OMPInfoCache.Kernels.count(F))
601e8039ad4SJohannes Doerfert         continue;
602b8235d2bSsstefan1 
603*2db182ffSJoseph Huber       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
604*2db182ffSJoseph Huber         return ORA << "OpenMP GPU kernel "
605e8039ad4SJohannes Doerfert                    << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
606e8039ad4SJohannes Doerfert       };
607b8235d2bSsstefan1 
608*2db182ffSJoseph Huber       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
609e8039ad4SJohannes Doerfert     }
6109548b74aSJohannes Doerfert   }
6119548b74aSJohannes Doerfert 
6127cfd267cSsstefan1   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
6137cfd267cSsstefan1   /// given it has to be the callee or a nullptr is returned.
6147cfd267cSsstefan1   static CallInst *getCallIfRegularCall(
6157cfd267cSsstefan1       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
6167cfd267cSsstefan1     CallInst *CI = dyn_cast<CallInst>(U.getUser());
6177cfd267cSsstefan1     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
6187cfd267cSsstefan1         (!RFI || CI->getCalledFunction() == RFI->Declaration))
6197cfd267cSsstefan1       return CI;
6207cfd267cSsstefan1     return nullptr;
6217cfd267cSsstefan1   }
6227cfd267cSsstefan1 
6237cfd267cSsstefan1   /// Return the call if \p V is a regular call. If \p RFI is given it has to be
6247cfd267cSsstefan1   /// the callee or a nullptr is returned.
6257cfd267cSsstefan1   static CallInst *getCallIfRegularCall(
6267cfd267cSsstefan1       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
6277cfd267cSsstefan1     CallInst *CI = dyn_cast<CallInst>(&V);
6287cfd267cSsstefan1     if (CI && !CI->hasOperandBundles() &&
6297cfd267cSsstefan1         (!RFI || CI->getCalledFunction() == RFI->Declaration))
6307cfd267cSsstefan1       return CI;
6317cfd267cSsstefan1     return nullptr;
6327cfd267cSsstefan1   }
6337cfd267cSsstefan1 
6349548b74aSJohannes Doerfert private:
6353a6bfcf2SGiorgis Georgakoudis   /// Merge parallel regions when it is safe.
6363a6bfcf2SGiorgis Georgakoudis   bool mergeParallelRegions() {
6373a6bfcf2SGiorgis Georgakoudis     const unsigned CallbackCalleeOperand = 2;
6383a6bfcf2SGiorgis Georgakoudis     const unsigned CallbackFirstArgOperand = 3;
6393a6bfcf2SGiorgis Georgakoudis     using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
6403a6bfcf2SGiorgis Georgakoudis 
6413a6bfcf2SGiorgis Georgakoudis     // Check if there are any __kmpc_fork_call calls to merge.
6423a6bfcf2SGiorgis Georgakoudis     OMPInformationCache::RuntimeFunctionInfo &RFI =
6433a6bfcf2SGiorgis Georgakoudis         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
6443a6bfcf2SGiorgis Georgakoudis 
6453a6bfcf2SGiorgis Georgakoudis     if (!RFI.Declaration)
6463a6bfcf2SGiorgis Georgakoudis       return false;
6473a6bfcf2SGiorgis Georgakoudis 
64897517055SGiorgis Georgakoudis     // Unmergable calls that prevent merging a parallel region.
64997517055SGiorgis Georgakoudis     OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
65097517055SGiorgis Georgakoudis         OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
65197517055SGiorgis Georgakoudis         OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
65297517055SGiorgis Georgakoudis     };
6533a6bfcf2SGiorgis Georgakoudis 
6543a6bfcf2SGiorgis Georgakoudis     bool Changed = false;
6553a6bfcf2SGiorgis Georgakoudis     LoopInfo *LI = nullptr;
6563a6bfcf2SGiorgis Georgakoudis     DominatorTree *DT = nullptr;
6573a6bfcf2SGiorgis Georgakoudis 
6583a6bfcf2SGiorgis Georgakoudis     SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
6593a6bfcf2SGiorgis Georgakoudis 
6603a6bfcf2SGiorgis Georgakoudis     BasicBlock *StartBB = nullptr, *EndBB = nullptr;
6613a6bfcf2SGiorgis Georgakoudis     auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
6623a6bfcf2SGiorgis Georgakoudis                          BasicBlock &ContinuationIP) {
6633a6bfcf2SGiorgis Georgakoudis       BasicBlock *CGStartBB = CodeGenIP.getBlock();
6643a6bfcf2SGiorgis Georgakoudis       BasicBlock *CGEndBB =
6653a6bfcf2SGiorgis Georgakoudis           SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
6663a6bfcf2SGiorgis Georgakoudis       assert(StartBB != nullptr && "StartBB should not be null");
6673a6bfcf2SGiorgis Georgakoudis       CGStartBB->getTerminator()->setSuccessor(0, StartBB);
6683a6bfcf2SGiorgis Georgakoudis       assert(EndBB != nullptr && "EndBB should not be null");
6693a6bfcf2SGiorgis Georgakoudis       EndBB->getTerminator()->setSuccessor(0, CGEndBB);
6703a6bfcf2SGiorgis Georgakoudis     };
6713a6bfcf2SGiorgis Georgakoudis 
672240dd924SAlex Zinenko     auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
673240dd924SAlex Zinenko                       Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
674240dd924SAlex Zinenko       ReplacementValue = &Inner;
6753a6bfcf2SGiorgis Georgakoudis       return CodeGenIP;
6763a6bfcf2SGiorgis Georgakoudis     };
6773a6bfcf2SGiorgis Georgakoudis 
6783a6bfcf2SGiorgis Georgakoudis     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
6793a6bfcf2SGiorgis Georgakoudis 
68097517055SGiorgis Georgakoudis     /// Create a sequential execution region within a merged parallel region,
68197517055SGiorgis Georgakoudis     /// encapsulated in a master construct with a barrier for synchronization.
68297517055SGiorgis Georgakoudis     auto CreateSequentialRegion = [&](Function *OuterFn,
68397517055SGiorgis Georgakoudis                                       BasicBlock *OuterPredBB,
68497517055SGiorgis Georgakoudis                                       Instruction *SeqStartI,
68597517055SGiorgis Georgakoudis                                       Instruction *SeqEndI) {
68697517055SGiorgis Georgakoudis       // Isolate the instructions of the sequential region to a separate
68797517055SGiorgis Georgakoudis       // block.
68897517055SGiorgis Georgakoudis       BasicBlock *ParentBB = SeqStartI->getParent();
68997517055SGiorgis Georgakoudis       BasicBlock *SeqEndBB =
69097517055SGiorgis Georgakoudis           SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
69197517055SGiorgis Georgakoudis       BasicBlock *SeqAfterBB =
69297517055SGiorgis Georgakoudis           SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
69397517055SGiorgis Georgakoudis       BasicBlock *SeqStartBB =
69497517055SGiorgis Georgakoudis           SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
69597517055SGiorgis Georgakoudis 
69697517055SGiorgis Georgakoudis       assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
69797517055SGiorgis Georgakoudis              "Expected a different CFG");
69897517055SGiorgis Georgakoudis       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
69997517055SGiorgis Georgakoudis       ParentBB->getTerminator()->eraseFromParent();
70097517055SGiorgis Georgakoudis 
70197517055SGiorgis Georgakoudis       auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
70297517055SGiorgis Georgakoudis                            BasicBlock &ContinuationIP) {
70397517055SGiorgis Georgakoudis         BasicBlock *CGStartBB = CodeGenIP.getBlock();
70497517055SGiorgis Georgakoudis         BasicBlock *CGEndBB =
70597517055SGiorgis Georgakoudis             SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
70697517055SGiorgis Georgakoudis         assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
70797517055SGiorgis Georgakoudis         CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
70897517055SGiorgis Georgakoudis         assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
70997517055SGiorgis Georgakoudis         SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
71097517055SGiorgis Georgakoudis       };
71197517055SGiorgis Georgakoudis       auto FiniCB = [&](InsertPointTy CodeGenIP) {};
71297517055SGiorgis Georgakoudis 
71397517055SGiorgis Georgakoudis       // Find outputs from the sequential region to outside users and
71497517055SGiorgis Georgakoudis       // broadcast their values to them.
71597517055SGiorgis Georgakoudis       for (Instruction &I : *SeqStartBB) {
71697517055SGiorgis Georgakoudis         SmallPtrSet<Instruction *, 4> OutsideUsers;
71797517055SGiorgis Georgakoudis         for (User *Usr : I.users()) {
71897517055SGiorgis Georgakoudis           Instruction &UsrI = *cast<Instruction>(Usr);
71997517055SGiorgis Georgakoudis           // Ignore outputs to LT intrinsics, code extraction for the merged
72097517055SGiorgis Georgakoudis           // parallel region will fix them.
72197517055SGiorgis Georgakoudis           if (UsrI.isLifetimeStartOrEnd())
72297517055SGiorgis Georgakoudis             continue;
72397517055SGiorgis Georgakoudis 
72497517055SGiorgis Georgakoudis           if (UsrI.getParent() != SeqStartBB)
72597517055SGiorgis Georgakoudis             OutsideUsers.insert(&UsrI);
72697517055SGiorgis Georgakoudis         }
72797517055SGiorgis Georgakoudis 
72897517055SGiorgis Georgakoudis         if (OutsideUsers.empty())
72997517055SGiorgis Georgakoudis           continue;
73097517055SGiorgis Georgakoudis 
73197517055SGiorgis Georgakoudis         // Emit an alloca in the outer region to store the broadcasted
73297517055SGiorgis Georgakoudis         // value.
73397517055SGiorgis Georgakoudis         const DataLayout &DL = M.getDataLayout();
73497517055SGiorgis Georgakoudis         AllocaInst *AllocaI = new AllocaInst(
73597517055SGiorgis Georgakoudis             I.getType(), DL.getAllocaAddrSpace(), nullptr,
73697517055SGiorgis Georgakoudis             I.getName() + ".seq.output.alloc", &OuterFn->front().front());
73797517055SGiorgis Georgakoudis 
73897517055SGiorgis Georgakoudis         // Emit a store instruction in the sequential BB to update the
73997517055SGiorgis Georgakoudis         // value.
74097517055SGiorgis Georgakoudis         new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
74197517055SGiorgis Georgakoudis 
74297517055SGiorgis Georgakoudis         // Emit a load instruction and replace the use of the output value
74397517055SGiorgis Georgakoudis         // with it.
74497517055SGiorgis Georgakoudis         for (Instruction *UsrI : OutsideUsers) {
7455b70c12fSJohannes Doerfert           LoadInst *LoadI = new LoadInst(
7465b70c12fSJohannes Doerfert               I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
74797517055SGiorgis Georgakoudis           UsrI->replaceUsesOfWith(&I, LoadI);
74897517055SGiorgis Georgakoudis         }
74997517055SGiorgis Georgakoudis       }
75097517055SGiorgis Georgakoudis 
75197517055SGiorgis Georgakoudis       OpenMPIRBuilder::LocationDescription Loc(
75297517055SGiorgis Georgakoudis           InsertPointTy(ParentBB, ParentBB->end()), DL);
75397517055SGiorgis Georgakoudis       InsertPointTy SeqAfterIP =
75497517055SGiorgis Georgakoudis           OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
75597517055SGiorgis Georgakoudis 
75697517055SGiorgis Georgakoudis       OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
75797517055SGiorgis Georgakoudis 
75897517055SGiorgis Georgakoudis       BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
75997517055SGiorgis Georgakoudis 
76097517055SGiorgis Georgakoudis       LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
76197517055SGiorgis Georgakoudis                         << "\n");
76297517055SGiorgis Georgakoudis     };
76397517055SGiorgis Georgakoudis 
7643a6bfcf2SGiorgis Georgakoudis     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
7653a6bfcf2SGiorgis Georgakoudis     // contained in BB and only separated by instructions that can be
7663a6bfcf2SGiorgis Georgakoudis     // redundantly executed in parallel. The block BB is split before the first
7673a6bfcf2SGiorgis Georgakoudis     // call (in MergableCIs) and after the last so the entire region we merge
7683a6bfcf2SGiorgis Georgakoudis     // into a single parallel region is contained in a single basic block
7693a6bfcf2SGiorgis Georgakoudis     // without any other instructions. We use the OpenMPIRBuilder to outline
7703a6bfcf2SGiorgis Georgakoudis     // that block and call the resulting function via __kmpc_fork_call.
7713a6bfcf2SGiorgis Georgakoudis     auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
7723a6bfcf2SGiorgis Georgakoudis       // TODO: Change the interface to allow single CIs expanded, e.g, to
7733a6bfcf2SGiorgis Georgakoudis       // include an outer loop.
7743a6bfcf2SGiorgis Georgakoudis       assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
7753a6bfcf2SGiorgis Georgakoudis 
7763a6bfcf2SGiorgis Georgakoudis       auto Remark = [&](OptimizationRemark OR) {
7773a6bfcf2SGiorgis Georgakoudis         OR << "Parallel region at "
7783a6bfcf2SGiorgis Georgakoudis            << ore::NV("OpenMPParallelMergeFront",
7793a6bfcf2SGiorgis Georgakoudis                       MergableCIs.front()->getDebugLoc())
7803a6bfcf2SGiorgis Georgakoudis            << " merged with parallel regions at ";
78123b0ab2aSKazu Hirata         for (auto *CI : llvm::drop_begin(MergableCIs)) {
7823a6bfcf2SGiorgis Georgakoudis           OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
7833a6bfcf2SGiorgis Georgakoudis           if (CI != MergableCIs.back())
7843a6bfcf2SGiorgis Georgakoudis             OR << ", ";
7853a6bfcf2SGiorgis Georgakoudis         }
7863a6bfcf2SGiorgis Georgakoudis         return OR;
7873a6bfcf2SGiorgis Georgakoudis       };
7883a6bfcf2SGiorgis Georgakoudis 
7893a6bfcf2SGiorgis Georgakoudis       emitRemark<OptimizationRemark>(MergableCIs.front(),
7903a6bfcf2SGiorgis Georgakoudis                                      "OpenMPParallelRegionMerging", Remark);
7913a6bfcf2SGiorgis Georgakoudis 
7923a6bfcf2SGiorgis Georgakoudis       Function *OriginalFn = BB->getParent();
7933a6bfcf2SGiorgis Georgakoudis       LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
7943a6bfcf2SGiorgis Georgakoudis                         << " parallel regions in " << OriginalFn->getName()
7953a6bfcf2SGiorgis Georgakoudis                         << "\n");
7963a6bfcf2SGiorgis Georgakoudis 
7973a6bfcf2SGiorgis Georgakoudis       // Isolate the calls to merge in a separate block.
7983a6bfcf2SGiorgis Georgakoudis       EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
7993a6bfcf2SGiorgis Georgakoudis       BasicBlock *AfterBB =
8003a6bfcf2SGiorgis Georgakoudis           SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
8013a6bfcf2SGiorgis Georgakoudis       StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
8023a6bfcf2SGiorgis Georgakoudis                            "omp.par.merged");
8033a6bfcf2SGiorgis Georgakoudis 
8043a6bfcf2SGiorgis Georgakoudis       assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
8053a6bfcf2SGiorgis Georgakoudis       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
8063a6bfcf2SGiorgis Georgakoudis       BB->getTerminator()->eraseFromParent();
8073a6bfcf2SGiorgis Georgakoudis 
80897517055SGiorgis Georgakoudis       // Create sequential regions for sequential instructions that are
80997517055SGiorgis Georgakoudis       // in-between mergable parallel regions.
81097517055SGiorgis Georgakoudis       for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
81197517055SGiorgis Georgakoudis            It != End; ++It) {
81297517055SGiorgis Georgakoudis         Instruction *ForkCI = *It;
81397517055SGiorgis Georgakoudis         Instruction *NextForkCI = *(It + 1);
81497517055SGiorgis Georgakoudis 
81597517055SGiorgis Georgakoudis         // Continue if there are not in-between instructions.
81697517055SGiorgis Georgakoudis         if (ForkCI->getNextNode() == NextForkCI)
81797517055SGiorgis Georgakoudis           continue;
81897517055SGiorgis Georgakoudis 
81997517055SGiorgis Georgakoudis         CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
82097517055SGiorgis Georgakoudis                                NextForkCI->getPrevNode());
82197517055SGiorgis Georgakoudis       }
82297517055SGiorgis Georgakoudis 
8233a6bfcf2SGiorgis Georgakoudis       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
8243a6bfcf2SGiorgis Georgakoudis                                                DL);
8253a6bfcf2SGiorgis Georgakoudis       IRBuilder<>::InsertPoint AllocaIP(
8263a6bfcf2SGiorgis Georgakoudis           &OriginalFn->getEntryBlock(),
8273a6bfcf2SGiorgis Georgakoudis           OriginalFn->getEntryBlock().getFirstInsertionPt());
8283a6bfcf2SGiorgis Georgakoudis       // Create the merged parallel region with default proc binding, to
8293a6bfcf2SGiorgis Georgakoudis       // avoid overriding binding settings, and without explicit cancellation.
830e5dba2d7SMichael Kruse       InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
8313a6bfcf2SGiorgis Georgakoudis           Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
8323a6bfcf2SGiorgis Georgakoudis           OMP_PROC_BIND_default, /* IsCancellable */ false);
8333a6bfcf2SGiorgis Georgakoudis       BranchInst::Create(AfterBB, AfterIP.getBlock());
8343a6bfcf2SGiorgis Georgakoudis 
8353a6bfcf2SGiorgis Georgakoudis       // Perform the actual outlining.
836b1191206SMichael Kruse       OMPInfoCache.OMPBuilder.finalize(OriginalFn,
837b1191206SMichael Kruse                                        /* AllowExtractorSinking */ true);
8383a6bfcf2SGiorgis Georgakoudis 
8393a6bfcf2SGiorgis Georgakoudis       Function *OutlinedFn = MergableCIs.front()->getCaller();
8403a6bfcf2SGiorgis Georgakoudis 
8413a6bfcf2SGiorgis Georgakoudis       // Replace the __kmpc_fork_call calls with direct calls to the outlined
8423a6bfcf2SGiorgis Georgakoudis       // callbacks.
8433a6bfcf2SGiorgis Georgakoudis       SmallVector<Value *, 8> Args;
8443a6bfcf2SGiorgis Georgakoudis       for (auto *CI : MergableCIs) {
8453a6bfcf2SGiorgis Georgakoudis         Value *Callee =
8463a6bfcf2SGiorgis Georgakoudis             CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
8473a6bfcf2SGiorgis Georgakoudis         FunctionType *FT =
8483a6bfcf2SGiorgis Georgakoudis             cast<FunctionType>(Callee->getType()->getPointerElementType());
8493a6bfcf2SGiorgis Georgakoudis         Args.clear();
8503a6bfcf2SGiorgis Georgakoudis         Args.push_back(OutlinedFn->getArg(0));
8513a6bfcf2SGiorgis Georgakoudis         Args.push_back(OutlinedFn->getArg(1));
8523a6bfcf2SGiorgis Georgakoudis         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
8533a6bfcf2SGiorgis Georgakoudis              U < E; ++U)
8543a6bfcf2SGiorgis Georgakoudis           Args.push_back(CI->getArgOperand(U));
8553a6bfcf2SGiorgis Georgakoudis 
8563a6bfcf2SGiorgis Georgakoudis         CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
8573a6bfcf2SGiorgis Georgakoudis         if (CI->getDebugLoc())
8583a6bfcf2SGiorgis Georgakoudis           NewCI->setDebugLoc(CI->getDebugLoc());
8593a6bfcf2SGiorgis Georgakoudis 
8603a6bfcf2SGiorgis Georgakoudis         // Forward parameter attributes from the callback to the callee.
8613a6bfcf2SGiorgis Georgakoudis         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
8623a6bfcf2SGiorgis Georgakoudis              U < E; ++U)
8633a6bfcf2SGiorgis Georgakoudis           for (const Attribute &A : CI->getAttributes().getParamAttributes(U))
8643a6bfcf2SGiorgis Georgakoudis             NewCI->addParamAttr(
8653a6bfcf2SGiorgis Georgakoudis                 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
8663a6bfcf2SGiorgis Georgakoudis 
8673a6bfcf2SGiorgis Georgakoudis         // Emit an explicit barrier to replace the implicit fork-join barrier.
8683a6bfcf2SGiorgis Georgakoudis         if (CI != MergableCIs.back()) {
8693a6bfcf2SGiorgis Georgakoudis           // TODO: Remove barrier if the merged parallel region includes the
8703a6bfcf2SGiorgis Georgakoudis           // 'nowait' clause.
871e5dba2d7SMichael Kruse           OMPInfoCache.OMPBuilder.createBarrier(
8723a6bfcf2SGiorgis Georgakoudis               InsertPointTy(NewCI->getParent(),
8733a6bfcf2SGiorgis Georgakoudis                             NewCI->getNextNode()->getIterator()),
8743a6bfcf2SGiorgis Georgakoudis               OMPD_parallel);
8753a6bfcf2SGiorgis Georgakoudis         }
8763a6bfcf2SGiorgis Georgakoudis 
8773a6bfcf2SGiorgis Georgakoudis         auto Remark = [&](OptimizationRemark OR) {
8783a6bfcf2SGiorgis Georgakoudis           return OR << "Parallel region at "
8793a6bfcf2SGiorgis Georgakoudis                     << ore::NV("OpenMPParallelMerge", CI->getDebugLoc())
8803a6bfcf2SGiorgis Georgakoudis                     << " merged with "
8813a6bfcf2SGiorgis Georgakoudis                     << ore::NV("OpenMPParallelMergeFront",
8823a6bfcf2SGiorgis Georgakoudis                                MergableCIs.front()->getDebugLoc());
8833a6bfcf2SGiorgis Georgakoudis         };
8843a6bfcf2SGiorgis Georgakoudis         if (CI != MergableCIs.front())
8853a6bfcf2SGiorgis Georgakoudis           emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionMerging",
8863a6bfcf2SGiorgis Georgakoudis                                          Remark);
8873a6bfcf2SGiorgis Georgakoudis 
8883a6bfcf2SGiorgis Georgakoudis         CI->eraseFromParent();
8893a6bfcf2SGiorgis Georgakoudis       }
8903a6bfcf2SGiorgis Georgakoudis 
8913a6bfcf2SGiorgis Georgakoudis       assert(OutlinedFn != OriginalFn && "Outlining failed");
8927fea561eSArthur Eubanks       CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
8933a6bfcf2SGiorgis Georgakoudis       CGUpdater.reanalyzeFunction(*OriginalFn);
8943a6bfcf2SGiorgis Georgakoudis 
8953a6bfcf2SGiorgis Georgakoudis       NumOpenMPParallelRegionsMerged += MergableCIs.size();
8963a6bfcf2SGiorgis Georgakoudis 
8973a6bfcf2SGiorgis Georgakoudis       return true;
8983a6bfcf2SGiorgis Georgakoudis     };
8993a6bfcf2SGiorgis Georgakoudis 
9003a6bfcf2SGiorgis Georgakoudis     // Helper function that identifes sequences of
9013a6bfcf2SGiorgis Georgakoudis     // __kmpc_fork_call uses in a basic block.
9023a6bfcf2SGiorgis Georgakoudis     auto DetectPRsCB = [&](Use &U, Function &F) {
9033a6bfcf2SGiorgis Georgakoudis       CallInst *CI = getCallIfRegularCall(U, &RFI);
9043a6bfcf2SGiorgis Georgakoudis       BB2PRMap[CI->getParent()].insert(CI);
9053a6bfcf2SGiorgis Georgakoudis 
9063a6bfcf2SGiorgis Georgakoudis       return false;
9073a6bfcf2SGiorgis Georgakoudis     };
9083a6bfcf2SGiorgis Georgakoudis 
9093a6bfcf2SGiorgis Georgakoudis     BB2PRMap.clear();
9103a6bfcf2SGiorgis Georgakoudis     RFI.foreachUse(SCC, DetectPRsCB);
9113a6bfcf2SGiorgis Georgakoudis     SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
9123a6bfcf2SGiorgis Georgakoudis     // Find mergable parallel regions within a basic block that are
9133a6bfcf2SGiorgis Georgakoudis     // safe to merge, that is any in-between instructions can safely
9143a6bfcf2SGiorgis Georgakoudis     // execute in parallel after merging.
9153a6bfcf2SGiorgis Georgakoudis     // TODO: support merging across basic-blocks.
9163a6bfcf2SGiorgis Georgakoudis     for (auto &It : BB2PRMap) {
9173a6bfcf2SGiorgis Georgakoudis       auto &CIs = It.getSecond();
9183a6bfcf2SGiorgis Georgakoudis       if (CIs.size() < 2)
9193a6bfcf2SGiorgis Georgakoudis         continue;
9203a6bfcf2SGiorgis Georgakoudis 
9213a6bfcf2SGiorgis Georgakoudis       BasicBlock *BB = It.getFirst();
9223a6bfcf2SGiorgis Georgakoudis       SmallVector<CallInst *, 4> MergableCIs;
9233a6bfcf2SGiorgis Georgakoudis 
92497517055SGiorgis Georgakoudis       /// Returns true if the instruction is mergable, false otherwise.
92597517055SGiorgis Georgakoudis       /// A terminator instruction is unmergable by definition since merging
92697517055SGiorgis Georgakoudis       /// works within a BB. Instructions before the mergable region are
92797517055SGiorgis Georgakoudis       /// mergable if they are not calls to OpenMP runtime functions that may
92897517055SGiorgis Georgakoudis       /// set different execution parameters for subsequent parallel regions.
92997517055SGiorgis Georgakoudis       /// Instructions in-between parallel regions are mergable if they are not
93097517055SGiorgis Georgakoudis       /// calls to any non-intrinsic function since that may call a non-mergable
93197517055SGiorgis Georgakoudis       /// OpenMP runtime function.
93297517055SGiorgis Georgakoudis       auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
93397517055SGiorgis Georgakoudis         // We do not merge across BBs, hence return false (unmergable) if the
93497517055SGiorgis Georgakoudis         // instruction is a terminator.
93597517055SGiorgis Georgakoudis         if (I.isTerminator())
93697517055SGiorgis Georgakoudis           return false;
93797517055SGiorgis Georgakoudis 
93897517055SGiorgis Georgakoudis         if (!isa<CallInst>(&I))
93997517055SGiorgis Georgakoudis           return true;
94097517055SGiorgis Georgakoudis 
94197517055SGiorgis Georgakoudis         CallInst *CI = cast<CallInst>(&I);
94297517055SGiorgis Georgakoudis         if (IsBeforeMergableRegion) {
94397517055SGiorgis Georgakoudis           Function *CalledFunction = CI->getCalledFunction();
94497517055SGiorgis Georgakoudis           if (!CalledFunction)
94597517055SGiorgis Georgakoudis             return false;
94697517055SGiorgis Georgakoudis           // Return false (unmergable) if the call before the parallel
94797517055SGiorgis Georgakoudis           // region calls an explicit affinity (proc_bind) or number of
94897517055SGiorgis Georgakoudis           // threads (num_threads) compiler-generated function. Those settings
94997517055SGiorgis Georgakoudis           // may be incompatible with following parallel regions.
95097517055SGiorgis Georgakoudis           // TODO: ICV tracking to detect compatibility.
95197517055SGiorgis Georgakoudis           for (const auto &RFI : UnmergableCallsInfo) {
95297517055SGiorgis Georgakoudis             if (CalledFunction == RFI.Declaration)
95397517055SGiorgis Georgakoudis               return false;
95497517055SGiorgis Georgakoudis           }
95597517055SGiorgis Georgakoudis         } else {
95697517055SGiorgis Georgakoudis           // Return false (unmergable) if there is a call instruction
95797517055SGiorgis Georgakoudis           // in-between parallel regions when it is not an intrinsic. It
95897517055SGiorgis Georgakoudis           // may call an unmergable OpenMP runtime function in its callpath.
95997517055SGiorgis Georgakoudis           // TODO: Keep track of possible OpenMP calls in the callpath.
96097517055SGiorgis Georgakoudis           if (!isa<IntrinsicInst>(CI))
96197517055SGiorgis Georgakoudis             return false;
96297517055SGiorgis Georgakoudis         }
96397517055SGiorgis Georgakoudis 
96497517055SGiorgis Georgakoudis         return true;
96597517055SGiorgis Georgakoudis       };
9663a6bfcf2SGiorgis Georgakoudis       // Find maximal number of parallel region CIs that are safe to merge.
96797517055SGiorgis Georgakoudis       for (auto It = BB->begin(), End = BB->end(); It != End;) {
96897517055SGiorgis Georgakoudis         Instruction &I = *It;
96997517055SGiorgis Georgakoudis         ++It;
97097517055SGiorgis Georgakoudis 
9713a6bfcf2SGiorgis Georgakoudis         if (CIs.count(&I)) {
9723a6bfcf2SGiorgis Georgakoudis           MergableCIs.push_back(cast<CallInst>(&I));
9733a6bfcf2SGiorgis Georgakoudis           continue;
9743a6bfcf2SGiorgis Georgakoudis         }
9753a6bfcf2SGiorgis Georgakoudis 
97697517055SGiorgis Georgakoudis         // Continue expanding if the instruction is mergable.
97797517055SGiorgis Georgakoudis         if (IsMergable(I, MergableCIs.empty()))
9783a6bfcf2SGiorgis Georgakoudis           continue;
9793a6bfcf2SGiorgis Georgakoudis 
98097517055SGiorgis Georgakoudis         // Forward the instruction iterator to skip the next parallel region
98197517055SGiorgis Georgakoudis         // since there is an unmergable instruction which can affect it.
98297517055SGiorgis Georgakoudis         for (; It != End; ++It) {
98397517055SGiorgis Georgakoudis           Instruction &SkipI = *It;
98497517055SGiorgis Georgakoudis           if (CIs.count(&SkipI)) {
98597517055SGiorgis Georgakoudis             LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
98697517055SGiorgis Georgakoudis                               << " due to " << I << "\n");
98797517055SGiorgis Georgakoudis             ++It;
98897517055SGiorgis Georgakoudis             break;
98997517055SGiorgis Georgakoudis           }
99097517055SGiorgis Georgakoudis         }
99197517055SGiorgis Georgakoudis 
99297517055SGiorgis Georgakoudis         // Store mergable regions found.
9933a6bfcf2SGiorgis Georgakoudis         if (MergableCIs.size() > 1) {
9943a6bfcf2SGiorgis Georgakoudis           MergableCIsVector.push_back(MergableCIs);
9953a6bfcf2SGiorgis Georgakoudis           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
9963a6bfcf2SGiorgis Georgakoudis                             << " parallel regions in block " << BB->getName()
9973a6bfcf2SGiorgis Georgakoudis                             << " of function " << BB->getParent()->getName()
9983a6bfcf2SGiorgis Georgakoudis                             << "\n";);
9993a6bfcf2SGiorgis Georgakoudis         }
10003a6bfcf2SGiorgis Georgakoudis 
10013a6bfcf2SGiorgis Georgakoudis         MergableCIs.clear();
10023a6bfcf2SGiorgis Georgakoudis       }
10033a6bfcf2SGiorgis Georgakoudis 
10043a6bfcf2SGiorgis Georgakoudis       if (!MergableCIsVector.empty()) {
10053a6bfcf2SGiorgis Georgakoudis         Changed = true;
10063a6bfcf2SGiorgis Georgakoudis 
10073a6bfcf2SGiorgis Georgakoudis         for (auto &MergableCIs : MergableCIsVector)
10083a6bfcf2SGiorgis Georgakoudis           Merge(MergableCIs, BB);
1009b2ad63d3SJoseph Huber         MergableCIsVector.clear();
10103a6bfcf2SGiorgis Georgakoudis       }
10113a6bfcf2SGiorgis Georgakoudis     }
10123a6bfcf2SGiorgis Georgakoudis 
10133a6bfcf2SGiorgis Georgakoudis     if (Changed) {
101497517055SGiorgis Georgakoudis       /// Re-collect use for fork calls, emitted barrier calls, and
101597517055SGiorgis Georgakoudis       /// any emitted master/end_master calls.
101697517055SGiorgis Georgakoudis       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
101797517055SGiorgis Georgakoudis       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
101897517055SGiorgis Georgakoudis       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
101997517055SGiorgis Georgakoudis       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
10203a6bfcf2SGiorgis Georgakoudis     }
10213a6bfcf2SGiorgis Georgakoudis 
10223a6bfcf2SGiorgis Georgakoudis     return Changed;
10233a6bfcf2SGiorgis Georgakoudis   }
10243a6bfcf2SGiorgis Georgakoudis 
10259d38f98dSJohannes Doerfert   /// Try to delete parallel regions if possible.
1026e565db49SJohannes Doerfert   bool deleteParallelRegions() {
1027e565db49SJohannes Doerfert     const unsigned CallbackCalleeOperand = 2;
1028e565db49SJohannes Doerfert 
10297cfd267cSsstefan1     OMPInformationCache::RuntimeFunctionInfo &RFI =
10307cfd267cSsstefan1         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
10317cfd267cSsstefan1 
1032e565db49SJohannes Doerfert     if (!RFI.Declaration)
1033e565db49SJohannes Doerfert       return false;
1034e565db49SJohannes Doerfert 
1035e565db49SJohannes Doerfert     bool Changed = false;
1036e565db49SJohannes Doerfert     auto DeleteCallCB = [&](Use &U, Function &) {
1037e565db49SJohannes Doerfert       CallInst *CI = getCallIfRegularCall(U);
1038e565db49SJohannes Doerfert       if (!CI)
1039e565db49SJohannes Doerfert         return false;
1040e565db49SJohannes Doerfert       auto *Fn = dyn_cast<Function>(
1041e565db49SJohannes Doerfert           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1042e565db49SJohannes Doerfert       if (!Fn)
1043e565db49SJohannes Doerfert         return false;
1044e565db49SJohannes Doerfert       if (!Fn->onlyReadsMemory())
1045e565db49SJohannes Doerfert         return false;
1046e565db49SJohannes Doerfert       if (!Fn->hasFnAttribute(Attribute::WillReturn))
1047e565db49SJohannes Doerfert         return false;
1048e565db49SJohannes Doerfert 
1049e565db49SJohannes Doerfert       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1050e565db49SJohannes Doerfert                         << CI->getCaller()->getName() << "\n");
10514d4ea9acSHuber, Joseph 
10524d4ea9acSHuber, Joseph       auto Remark = [&](OptimizationRemark OR) {
10534d4ea9acSHuber, Joseph         return OR << "Parallel region in "
10544d4ea9acSHuber, Joseph                   << ore::NV("OpenMPParallelDelete", CI->getCaller()->getName())
10554d4ea9acSHuber, Joseph                   << " deleted";
10564d4ea9acSHuber, Joseph       };
10574d4ea9acSHuber, Joseph       emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionDeletion",
10584d4ea9acSHuber, Joseph                                      Remark);
10594d4ea9acSHuber, Joseph 
1060e565db49SJohannes Doerfert       CGUpdater.removeCallSite(*CI);
1061e565db49SJohannes Doerfert       CI->eraseFromParent();
1062e565db49SJohannes Doerfert       Changed = true;
106355eb714aSRoman Lebedev       ++NumOpenMPParallelRegionsDeleted;
1064e565db49SJohannes Doerfert       return true;
1065e565db49SJohannes Doerfert     };
1066e565db49SJohannes Doerfert 
1067624d34afSJohannes Doerfert     RFI.foreachUse(SCC, DeleteCallCB);
1068e565db49SJohannes Doerfert 
1069e565db49SJohannes Doerfert     return Changed;
1070e565db49SJohannes Doerfert   }
1071e565db49SJohannes Doerfert 
1072b726c557SJohannes Doerfert   /// Try to eliminate runtime calls by reusing existing ones.
10739548b74aSJohannes Doerfert   bool deduplicateRuntimeCalls() {
10749548b74aSJohannes Doerfert     bool Changed = false;
10759548b74aSJohannes Doerfert 
1076e28936f6SJohannes Doerfert     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1077e28936f6SJohannes Doerfert         OMPRTL_omp_get_num_threads,
1078e28936f6SJohannes Doerfert         OMPRTL_omp_in_parallel,
1079e28936f6SJohannes Doerfert         OMPRTL_omp_get_cancellation,
1080e28936f6SJohannes Doerfert         OMPRTL_omp_get_thread_limit,
1081e28936f6SJohannes Doerfert         OMPRTL_omp_get_supported_active_levels,
1082e28936f6SJohannes Doerfert         OMPRTL_omp_get_level,
1083e28936f6SJohannes Doerfert         OMPRTL_omp_get_ancestor_thread_num,
1084e28936f6SJohannes Doerfert         OMPRTL_omp_get_team_size,
1085e28936f6SJohannes Doerfert         OMPRTL_omp_get_active_level,
1086e28936f6SJohannes Doerfert         OMPRTL_omp_in_final,
1087e28936f6SJohannes Doerfert         OMPRTL_omp_get_proc_bind,
1088e28936f6SJohannes Doerfert         OMPRTL_omp_get_num_places,
1089e28936f6SJohannes Doerfert         OMPRTL_omp_get_num_procs,
1090e28936f6SJohannes Doerfert         OMPRTL_omp_get_place_num,
1091e28936f6SJohannes Doerfert         OMPRTL_omp_get_partition_num_places,
1092e28936f6SJohannes Doerfert         OMPRTL_omp_get_partition_place_nums};
1093e28936f6SJohannes Doerfert 
1094bc93c2d7SMarek Kurdej     // Global-tid is handled separately.
10959548b74aSJohannes Doerfert     SmallSetVector<Value *, 16> GTIdArgs;
10969548b74aSJohannes Doerfert     collectGlobalThreadIdArguments(GTIdArgs);
10979548b74aSJohannes Doerfert     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
10989548b74aSJohannes Doerfert                       << " global thread ID arguments\n");
10999548b74aSJohannes Doerfert 
11009548b74aSJohannes Doerfert     for (Function *F : SCC) {
1101e28936f6SJohannes Doerfert       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
11024e29d256Sserge-sans-paille         Changed |= deduplicateRuntimeCalls(
11034e29d256Sserge-sans-paille             *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1104e28936f6SJohannes Doerfert 
1105e28936f6SJohannes Doerfert       // __kmpc_global_thread_num is special as we can replace it with an
1106e28936f6SJohannes Doerfert       // argument in enough cases to make it worth trying.
11079548b74aSJohannes Doerfert       Value *GTIdArg = nullptr;
11089548b74aSJohannes Doerfert       for (Argument &Arg : F->args())
11099548b74aSJohannes Doerfert         if (GTIdArgs.count(&Arg)) {
11109548b74aSJohannes Doerfert           GTIdArg = &Arg;
11119548b74aSJohannes Doerfert           break;
11129548b74aSJohannes Doerfert         }
11139548b74aSJohannes Doerfert       Changed |= deduplicateRuntimeCalls(
11147cfd267cSsstefan1           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
11159548b74aSJohannes Doerfert     }
11169548b74aSJohannes Doerfert 
11179548b74aSJohannes Doerfert     return Changed;
11189548b74aSJohannes Doerfert   }
11199548b74aSJohannes Doerfert 
1120496f8e5bSHamilton Tobon Mosquera   /// Tries to hide the latency of runtime calls that involve host to
1121496f8e5bSHamilton Tobon Mosquera   /// device memory transfers by splitting them into their "issue" and "wait"
1122496f8e5bSHamilton Tobon Mosquera   /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1123496f8e5bSHamilton Tobon Mosquera   /// moved downards as much as possible. The "issue" issues the memory transfer
1124496f8e5bSHamilton Tobon Mosquera   /// asynchronously, returning a handle. The "wait" waits in the returned
1125496f8e5bSHamilton Tobon Mosquera   /// handle for the memory transfer to finish.
1126496f8e5bSHamilton Tobon Mosquera   bool hideMemTransfersLatency() {
1127496f8e5bSHamilton Tobon Mosquera     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1128496f8e5bSHamilton Tobon Mosquera     bool Changed = false;
1129496f8e5bSHamilton Tobon Mosquera     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1130496f8e5bSHamilton Tobon Mosquera       auto *RTCall = getCallIfRegularCall(U, &RFI);
1131496f8e5bSHamilton Tobon Mosquera       if (!RTCall)
1132496f8e5bSHamilton Tobon Mosquera         return false;
1133496f8e5bSHamilton Tobon Mosquera 
11348931add6SHamilton Tobon Mosquera       OffloadArray OffloadArrays[3];
11358931add6SHamilton Tobon Mosquera       if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
11368931add6SHamilton Tobon Mosquera         return false;
11378931add6SHamilton Tobon Mosquera 
11388931add6SHamilton Tobon Mosquera       LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
11398931add6SHamilton Tobon Mosquera 
1140bd2fa181SHamilton Tobon Mosquera       // TODO: Check if can be moved upwards.
1141bd2fa181SHamilton Tobon Mosquera       bool WasSplit = false;
1142bd2fa181SHamilton Tobon Mosquera       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1143bd2fa181SHamilton Tobon Mosquera       if (WaitMovementPoint)
1144bd2fa181SHamilton Tobon Mosquera         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1145bd2fa181SHamilton Tobon Mosquera 
1146496f8e5bSHamilton Tobon Mosquera       Changed |= WasSplit;
1147496f8e5bSHamilton Tobon Mosquera       return WasSplit;
1148496f8e5bSHamilton Tobon Mosquera     };
1149496f8e5bSHamilton Tobon Mosquera     RFI.foreachUse(SCC, SplitMemTransfers);
1150496f8e5bSHamilton Tobon Mosquera 
1151496f8e5bSHamilton Tobon Mosquera     return Changed;
1152496f8e5bSHamilton Tobon Mosquera   }
1153496f8e5bSHamilton Tobon Mosquera 
1154a2281419SJoseph Huber   void analysisGlobalization() {
115582453e75SJoseph Huber     RuntimeFunction GlobalizationRuntimeIDs[] = {
115682453e75SJoseph Huber         OMPRTL___kmpc_data_sharing_coalesced_push_stack,
115782453e75SJoseph Huber         OMPRTL___kmpc_data_sharing_push_stack};
1158a2281419SJoseph Huber 
115982453e75SJoseph Huber     for (const auto GlobalizationCallID : GlobalizationRuntimeIDs) {
116082453e75SJoseph Huber       auto &RFI = OMPInfoCache.RFIs[GlobalizationCallID];
116182453e75SJoseph Huber 
116282453e75SJoseph Huber       auto CheckGlobalization = [&](Use &U, Function &Decl) {
1163a2281419SJoseph Huber         if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1164a2281419SJoseph Huber           auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1165a2281419SJoseph Huber             return ORA
1166a2281419SJoseph Huber                    << "Found thread data sharing on the GPU. "
1167a2281419SJoseph Huber                    << "Expect degraded performance due to data globalization.";
1168a2281419SJoseph Huber           };
1169a2281419SJoseph Huber           emitRemark<OptimizationRemarkAnalysis>(CI, "OpenMPGlobalization",
1170a2281419SJoseph Huber                                                  Remark);
1171a2281419SJoseph Huber         }
1172a2281419SJoseph Huber 
1173a2281419SJoseph Huber         return false;
1174a2281419SJoseph Huber       };
1175a2281419SJoseph Huber 
117682453e75SJoseph Huber       RFI.foreachUse(SCC, CheckGlobalization);
117782453e75SJoseph Huber     }
1178a2281419SJoseph Huber   }
1179a2281419SJoseph Huber 
11808931add6SHamilton Tobon Mosquera   /// Maps the values stored in the offload arrays passed as arguments to
11818931add6SHamilton Tobon Mosquera   /// \p RuntimeCall into the offload arrays in \p OAs.
11828931add6SHamilton Tobon Mosquera   bool getValuesInOffloadArrays(CallInst &RuntimeCall,
11838931add6SHamilton Tobon Mosquera                                 MutableArrayRef<OffloadArray> OAs) {
11848931add6SHamilton Tobon Mosquera     assert(OAs.size() == 3 && "Need space for three offload arrays!");
11858931add6SHamilton Tobon Mosquera 
11868931add6SHamilton Tobon Mosquera     // A runtime call that involves memory offloading looks something like:
11878931add6SHamilton Tobon Mosquera     // call void @__tgt_target_data_begin_mapper(arg0, arg1,
11888931add6SHamilton Tobon Mosquera     //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
11898931add6SHamilton Tobon Mosquera     // ...)
11908931add6SHamilton Tobon Mosquera     // So, the idea is to access the allocas that allocate space for these
11918931add6SHamilton Tobon Mosquera     // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
11928931add6SHamilton Tobon Mosquera     // Therefore:
11938931add6SHamilton Tobon Mosquera     // i8** %offload_baseptrs.
11941d3d9b9cSHamilton Tobon Mosquera     Value *BasePtrsArg =
11951d3d9b9cSHamilton Tobon Mosquera         RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
11968931add6SHamilton Tobon Mosquera     // i8** %offload_ptrs.
11971d3d9b9cSHamilton Tobon Mosquera     Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
11988931add6SHamilton Tobon Mosquera     // i8** %offload_sizes.
11991d3d9b9cSHamilton Tobon Mosquera     Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
12008931add6SHamilton Tobon Mosquera 
12018931add6SHamilton Tobon Mosquera     // Get values stored in **offload_baseptrs.
12028931add6SHamilton Tobon Mosquera     auto *V = getUnderlyingObject(BasePtrsArg);
12038931add6SHamilton Tobon Mosquera     if (!isa<AllocaInst>(V))
12048931add6SHamilton Tobon Mosquera       return false;
12058931add6SHamilton Tobon Mosquera     auto *BasePtrsArray = cast<AllocaInst>(V);
12068931add6SHamilton Tobon Mosquera     if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
12078931add6SHamilton Tobon Mosquera       return false;
12088931add6SHamilton Tobon Mosquera 
12098931add6SHamilton Tobon Mosquera     // Get values stored in **offload_baseptrs.
12108931add6SHamilton Tobon Mosquera     V = getUnderlyingObject(PtrsArg);
12118931add6SHamilton Tobon Mosquera     if (!isa<AllocaInst>(V))
12128931add6SHamilton Tobon Mosquera       return false;
12138931add6SHamilton Tobon Mosquera     auto *PtrsArray = cast<AllocaInst>(V);
12148931add6SHamilton Tobon Mosquera     if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
12158931add6SHamilton Tobon Mosquera       return false;
12168931add6SHamilton Tobon Mosquera 
12178931add6SHamilton Tobon Mosquera     // Get values stored in **offload_sizes.
12188931add6SHamilton Tobon Mosquera     V = getUnderlyingObject(SizesArg);
12198931add6SHamilton Tobon Mosquera     // If it's a [constant] global array don't analyze it.
12208931add6SHamilton Tobon Mosquera     if (isa<GlobalValue>(V))
12218931add6SHamilton Tobon Mosquera       return isa<Constant>(V);
12228931add6SHamilton Tobon Mosquera     if (!isa<AllocaInst>(V))
12238931add6SHamilton Tobon Mosquera       return false;
12248931add6SHamilton Tobon Mosquera 
12258931add6SHamilton Tobon Mosquera     auto *SizesArray = cast<AllocaInst>(V);
12268931add6SHamilton Tobon Mosquera     if (!OAs[2].initialize(*SizesArray, RuntimeCall))
12278931add6SHamilton Tobon Mosquera       return false;
12288931add6SHamilton Tobon Mosquera 
12298931add6SHamilton Tobon Mosquera     return true;
12308931add6SHamilton Tobon Mosquera   }
12318931add6SHamilton Tobon Mosquera 
12328931add6SHamilton Tobon Mosquera   /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
12338931add6SHamilton Tobon Mosquera   /// For now this is a way to test that the function getValuesInOffloadArrays
12348931add6SHamilton Tobon Mosquera   /// is working properly.
12358931add6SHamilton Tobon Mosquera   /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
12368931add6SHamilton Tobon Mosquera   void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
12378931add6SHamilton Tobon Mosquera     assert(OAs.size() == 3 && "There are three offload arrays to debug!");
12388931add6SHamilton Tobon Mosquera 
12398931add6SHamilton Tobon Mosquera     LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
12408931add6SHamilton Tobon Mosquera     std::string ValuesStr;
12418931add6SHamilton Tobon Mosquera     raw_string_ostream Printer(ValuesStr);
12428931add6SHamilton Tobon Mosquera     std::string Separator = " --- ";
12438931add6SHamilton Tobon Mosquera 
12448931add6SHamilton Tobon Mosquera     for (auto *BP : OAs[0].StoredValues) {
12458931add6SHamilton Tobon Mosquera       BP->print(Printer);
12468931add6SHamilton Tobon Mosquera       Printer << Separator;
12478931add6SHamilton Tobon Mosquera     }
12488931add6SHamilton Tobon Mosquera     LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
12498931add6SHamilton Tobon Mosquera     ValuesStr.clear();
12508931add6SHamilton Tobon Mosquera 
12518931add6SHamilton Tobon Mosquera     for (auto *P : OAs[1].StoredValues) {
12528931add6SHamilton Tobon Mosquera       P->print(Printer);
12538931add6SHamilton Tobon Mosquera       Printer << Separator;
12548931add6SHamilton Tobon Mosquera     }
12558931add6SHamilton Tobon Mosquera     LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
12568931add6SHamilton Tobon Mosquera     ValuesStr.clear();
12578931add6SHamilton Tobon Mosquera 
12588931add6SHamilton Tobon Mosquera     for (auto *S : OAs[2].StoredValues) {
12598931add6SHamilton Tobon Mosquera       S->print(Printer);
12608931add6SHamilton Tobon Mosquera       Printer << Separator;
12618931add6SHamilton Tobon Mosquera     }
12628931add6SHamilton Tobon Mosquera     LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
12638931add6SHamilton Tobon Mosquera   }
12648931add6SHamilton Tobon Mosquera 
1265bd2fa181SHamilton Tobon Mosquera   /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1266bd2fa181SHamilton Tobon Mosquera   /// moved. Returns nullptr if the movement is not possible, or not worth it.
1267bd2fa181SHamilton Tobon Mosquera   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1268bd2fa181SHamilton Tobon Mosquera     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1269bd2fa181SHamilton Tobon Mosquera     //  Make it traverse the CFG.
1270bd2fa181SHamilton Tobon Mosquera 
1271bd2fa181SHamilton Tobon Mosquera     Instruction *CurrentI = &RuntimeCall;
1272bd2fa181SHamilton Tobon Mosquera     bool IsWorthIt = false;
1273bd2fa181SHamilton Tobon Mosquera     while ((CurrentI = CurrentI->getNextNode())) {
1274bd2fa181SHamilton Tobon Mosquera 
1275bd2fa181SHamilton Tobon Mosquera       // TODO: Once we detect the regions to be offloaded we should use the
1276bd2fa181SHamilton Tobon Mosquera       //  alias analysis manager to check if CurrentI may modify one of
1277bd2fa181SHamilton Tobon Mosquera       //  the offloaded regions.
1278bd2fa181SHamilton Tobon Mosquera       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1279bd2fa181SHamilton Tobon Mosquera         if (IsWorthIt)
1280bd2fa181SHamilton Tobon Mosquera           return CurrentI;
1281bd2fa181SHamilton Tobon Mosquera 
1282bd2fa181SHamilton Tobon Mosquera         return nullptr;
1283bd2fa181SHamilton Tobon Mosquera       }
1284bd2fa181SHamilton Tobon Mosquera 
1285bd2fa181SHamilton Tobon Mosquera       // FIXME: For now if we move it over anything without side effect
1286bd2fa181SHamilton Tobon Mosquera       //  is worth it.
1287bd2fa181SHamilton Tobon Mosquera       IsWorthIt = true;
1288bd2fa181SHamilton Tobon Mosquera     }
1289bd2fa181SHamilton Tobon Mosquera 
1290bd2fa181SHamilton Tobon Mosquera     // Return end of BasicBlock.
1291bd2fa181SHamilton Tobon Mosquera     return RuntimeCall.getParent()->getTerminator();
1292bd2fa181SHamilton Tobon Mosquera   }
1293bd2fa181SHamilton Tobon Mosquera 
1294496f8e5bSHamilton Tobon Mosquera   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1295bd2fa181SHamilton Tobon Mosquera   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1296bd2fa181SHamilton Tobon Mosquera                                Instruction &WaitMovementPoint) {
1297bd31abc1SHamilton Tobon Mosquera     // Create stack allocated handle (__tgt_async_info) at the beginning of the
1298bd31abc1SHamilton Tobon Mosquera     // function. Used for storing information of the async transfer, allowing to
1299bd31abc1SHamilton Tobon Mosquera     // wait on it later.
1300496f8e5bSHamilton Tobon Mosquera     auto &IRBuilder = OMPInfoCache.OMPBuilder;
1301bd31abc1SHamilton Tobon Mosquera     auto *F = RuntimeCall.getCaller();
1302bd31abc1SHamilton Tobon Mosquera     Instruction *FirstInst = &(F->getEntryBlock().front());
1303bd31abc1SHamilton Tobon Mosquera     AllocaInst *Handle = new AllocaInst(
1304bd31abc1SHamilton Tobon Mosquera         IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1305bd31abc1SHamilton Tobon Mosquera 
1306496f8e5bSHamilton Tobon Mosquera     // Add "issue" runtime call declaration:
1307496f8e5bSHamilton Tobon Mosquera     // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1308496f8e5bSHamilton Tobon Mosquera     //   i8**, i8**, i64*, i64*)
1309496f8e5bSHamilton Tobon Mosquera     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1310496f8e5bSHamilton Tobon Mosquera         M, OMPRTL___tgt_target_data_begin_mapper_issue);
1311496f8e5bSHamilton Tobon Mosquera 
1312496f8e5bSHamilton Tobon Mosquera     // Change RuntimeCall call site for its asynchronous version.
131397e55cfeSJoseph Huber     SmallVector<Value *, 16> Args;
1314bd2fa181SHamilton Tobon Mosquera     for (auto &Arg : RuntimeCall.args())
1315496f8e5bSHamilton Tobon Mosquera       Args.push_back(Arg.get());
1316bd31abc1SHamilton Tobon Mosquera     Args.push_back(Handle);
1317496f8e5bSHamilton Tobon Mosquera 
1318496f8e5bSHamilton Tobon Mosquera     CallInst *IssueCallsite =
1319bd31abc1SHamilton Tobon Mosquera         CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1320bd2fa181SHamilton Tobon Mosquera     RuntimeCall.eraseFromParent();
1321496f8e5bSHamilton Tobon Mosquera 
1322496f8e5bSHamilton Tobon Mosquera     // Add "wait" runtime call declaration:
1323496f8e5bSHamilton Tobon Mosquera     // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1324496f8e5bSHamilton Tobon Mosquera     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1325496f8e5bSHamilton Tobon Mosquera         M, OMPRTL___tgt_target_data_begin_mapper_wait);
1326496f8e5bSHamilton Tobon Mosquera 
1327496f8e5bSHamilton Tobon Mosquera     Value *WaitParams[2] = {
1328da8bec47SJoseph Huber         IssueCallsite->getArgOperand(
1329da8bec47SJoseph Huber             OffloadArray::DeviceIDArgNum), // device_id.
1330bd31abc1SHamilton Tobon Mosquera         Handle                             // handle to wait on.
1331496f8e5bSHamilton Tobon Mosquera     };
1332bd2fa181SHamilton Tobon Mosquera     CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1333496f8e5bSHamilton Tobon Mosquera 
1334496f8e5bSHamilton Tobon Mosquera     return true;
1335496f8e5bSHamilton Tobon Mosquera   }
1336496f8e5bSHamilton Tobon Mosquera 
1337dc3b5b00SJohannes Doerfert   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1338dc3b5b00SJohannes Doerfert                                     bool GlobalOnly, bool &SingleChoice) {
1339dc3b5b00SJohannes Doerfert     if (CurrentIdent == NextIdent)
1340dc3b5b00SJohannes Doerfert       return CurrentIdent;
1341dc3b5b00SJohannes Doerfert 
1342396b7253SJohannes Doerfert     // TODO: Figure out how to actually combine multiple debug locations. For
1343dc3b5b00SJohannes Doerfert     //       now we just keep an existing one if there is a single choice.
1344dc3b5b00SJohannes Doerfert     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1345dc3b5b00SJohannes Doerfert       SingleChoice = !CurrentIdent;
1346dc3b5b00SJohannes Doerfert       return NextIdent;
1347dc3b5b00SJohannes Doerfert     }
1348396b7253SJohannes Doerfert     return nullptr;
1349396b7253SJohannes Doerfert   }
1350396b7253SJohannes Doerfert 
1351396b7253SJohannes Doerfert   /// Return an `struct ident_t*` value that represents the ones used in the
1352396b7253SJohannes Doerfert   /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1353396b7253SJohannes Doerfert   /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1354396b7253SJohannes Doerfert   /// return value we create one from scratch. We also do not yet combine
1355396b7253SJohannes Doerfert   /// information, e.g., the source locations, see combinedIdentStruct.
13567cfd267cSsstefan1   Value *
13577cfd267cSsstefan1   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
13587cfd267cSsstefan1                                  Function &F, bool GlobalOnly) {
1359dc3b5b00SJohannes Doerfert     bool SingleChoice = true;
1360396b7253SJohannes Doerfert     Value *Ident = nullptr;
1361396b7253SJohannes Doerfert     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1362396b7253SJohannes Doerfert       CallInst *CI = getCallIfRegularCall(U, &RFI);
1363396b7253SJohannes Doerfert       if (!CI || &F != &Caller)
1364396b7253SJohannes Doerfert         return false;
1365396b7253SJohannes Doerfert       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1366dc3b5b00SJohannes Doerfert                                   /* GlobalOnly */ true, SingleChoice);
1367396b7253SJohannes Doerfert       return false;
1368396b7253SJohannes Doerfert     };
1369624d34afSJohannes Doerfert     RFI.foreachUse(SCC, CombineIdentStruct);
1370396b7253SJohannes Doerfert 
1371dc3b5b00SJohannes Doerfert     if (!Ident || !SingleChoice) {
1372396b7253SJohannes Doerfert       // The IRBuilder uses the insertion block to get to the module, this is
1373396b7253SJohannes Doerfert       // unfortunate but we work around it for now.
13747cfd267cSsstefan1       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
13757cfd267cSsstefan1         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1376396b7253SJohannes Doerfert             &F.getEntryBlock(), F.getEntryBlock().begin()));
1377396b7253SJohannes Doerfert       // Create a fallback location if non was found.
1378396b7253SJohannes Doerfert       // TODO: Use the debug locations of the calls instead.
13797cfd267cSsstefan1       Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
13807cfd267cSsstefan1       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
1381396b7253SJohannes Doerfert     }
1382396b7253SJohannes Doerfert     return Ident;
1383396b7253SJohannes Doerfert   }
1384396b7253SJohannes Doerfert 
1385b726c557SJohannes Doerfert   /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
13869548b74aSJohannes Doerfert   /// \p ReplVal if given.
13877cfd267cSsstefan1   bool deduplicateRuntimeCalls(Function &F,
13887cfd267cSsstefan1                                OMPInformationCache::RuntimeFunctionInfo &RFI,
13899548b74aSJohannes Doerfert                                Value *ReplVal = nullptr) {
13908855fec3SJohannes Doerfert     auto *UV = RFI.getUseVector(F);
13918855fec3SJohannes Doerfert     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1392b1fbf438SRoman Lebedev       return false;
1393b1fbf438SRoman Lebedev 
13947cfd267cSsstefan1     LLVM_DEBUG(
13957cfd267cSsstefan1         dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
13967cfd267cSsstefan1                << (ReplVal ? " with an existing value\n" : "\n") << "\n");
13977cfd267cSsstefan1 
1398ab3da5ddSMichael Liao     assert((!ReplVal || (isa<Argument>(ReplVal) &&
1399ab3da5ddSMichael Liao                          cast<Argument>(ReplVal)->getParent() == &F)) &&
14009548b74aSJohannes Doerfert            "Unexpected replacement value!");
1401396b7253SJohannes Doerfert 
1402396b7253SJohannes Doerfert     // TODO: Use dominance to find a good position instead.
14036aab27baSsstefan1     auto CanBeMoved = [this](CallBase &CB) {
1404396b7253SJohannes Doerfert       unsigned NumArgs = CB.getNumArgOperands();
1405396b7253SJohannes Doerfert       if (NumArgs == 0)
1406396b7253SJohannes Doerfert         return true;
14076aab27baSsstefan1       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1408396b7253SJohannes Doerfert         return false;
1409396b7253SJohannes Doerfert       for (unsigned u = 1; u < NumArgs; ++u)
1410396b7253SJohannes Doerfert         if (isa<Instruction>(CB.getArgOperand(u)))
1411396b7253SJohannes Doerfert           return false;
1412396b7253SJohannes Doerfert       return true;
1413396b7253SJohannes Doerfert     };
1414396b7253SJohannes Doerfert 
14159548b74aSJohannes Doerfert     if (!ReplVal) {
14168855fec3SJohannes Doerfert       for (Use *U : *UV)
14179548b74aSJohannes Doerfert         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1418396b7253SJohannes Doerfert           if (!CanBeMoved(*CI))
1419396b7253SJohannes Doerfert             continue;
14204d4ea9acSHuber, Joseph 
14214d4ea9acSHuber, Joseph           auto Remark = [&](OptimizationRemark OR) {
14224d4ea9acSHuber, Joseph             return OR << "OpenMP runtime call "
1423*2db182ffSJoseph Huber                       << ore::NV("OpenMPOptRuntime", RFI.Name)
1424*2db182ffSJoseph Huber                       << " moved to beginning of OpenMP region";
14254d4ea9acSHuber, Joseph           };
1426*2db182ffSJoseph Huber           emitRemark<OptimizationRemark>(&F, "OpenMPRuntimeCodeMotion", Remark);
14274d4ea9acSHuber, Joseph 
14289548b74aSJohannes Doerfert           CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
14299548b74aSJohannes Doerfert           ReplVal = CI;
14309548b74aSJohannes Doerfert           break;
14319548b74aSJohannes Doerfert         }
14329548b74aSJohannes Doerfert       if (!ReplVal)
14339548b74aSJohannes Doerfert         return false;
14349548b74aSJohannes Doerfert     }
14359548b74aSJohannes Doerfert 
1436396b7253SJohannes Doerfert     // If we use a call as a replacement value we need to make sure the ident is
1437396b7253SJohannes Doerfert     // valid at the new location. For now we just pick a global one, either
1438396b7253SJohannes Doerfert     // existing and used by one of the calls, or created from scratch.
1439396b7253SJohannes Doerfert     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1440396b7253SJohannes Doerfert       if (CI->getNumArgOperands() > 0 &&
14416aab27baSsstefan1           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1442396b7253SJohannes Doerfert         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1443396b7253SJohannes Doerfert                                                       /* GlobalOnly */ true);
1444396b7253SJohannes Doerfert         CI->setArgOperand(0, Ident);
1445396b7253SJohannes Doerfert       }
1446396b7253SJohannes Doerfert     }
1447396b7253SJohannes Doerfert 
14489548b74aSJohannes Doerfert     bool Changed = false;
14499548b74aSJohannes Doerfert     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
14509548b74aSJohannes Doerfert       CallInst *CI = getCallIfRegularCall(U, &RFI);
14519548b74aSJohannes Doerfert       if (!CI || CI == ReplVal || &F != &Caller)
14529548b74aSJohannes Doerfert         return false;
14539548b74aSJohannes Doerfert       assert(CI->getCaller() == &F && "Unexpected call!");
14544d4ea9acSHuber, Joseph 
14554d4ea9acSHuber, Joseph       auto Remark = [&](OptimizationRemark OR) {
14564d4ea9acSHuber, Joseph         return OR << "OpenMP runtime call "
14574d4ea9acSHuber, Joseph                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated";
14584d4ea9acSHuber, Joseph       };
1459*2db182ffSJoseph Huber       emitRemark<OptimizationRemark>(&F, "OpenMPRuntimeDeduplicated", Remark);
14604d4ea9acSHuber, Joseph 
14619548b74aSJohannes Doerfert       CGUpdater.removeCallSite(*CI);
14629548b74aSJohannes Doerfert       CI->replaceAllUsesWith(ReplVal);
14639548b74aSJohannes Doerfert       CI->eraseFromParent();
14649548b74aSJohannes Doerfert       ++NumOpenMPRuntimeCallsDeduplicated;
14659548b74aSJohannes Doerfert       Changed = true;
14669548b74aSJohannes Doerfert       return true;
14679548b74aSJohannes Doerfert     };
1468624d34afSJohannes Doerfert     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
14699548b74aSJohannes Doerfert 
14709548b74aSJohannes Doerfert     return Changed;
14719548b74aSJohannes Doerfert   }
14729548b74aSJohannes Doerfert 
14739548b74aSJohannes Doerfert   /// Collect arguments that represent the global thread id in \p GTIdArgs.
14749548b74aSJohannes Doerfert   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
14759548b74aSJohannes Doerfert     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
14769548b74aSJohannes Doerfert     //       initialization. We could define an AbstractAttribute instead and
14779548b74aSJohannes Doerfert     //       run the Attributor here once it can be run as an SCC pass.
14789548b74aSJohannes Doerfert 
14799548b74aSJohannes Doerfert     // Helper to check the argument \p ArgNo at all call sites of \p F for
14809548b74aSJohannes Doerfert     // a GTId.
14819548b74aSJohannes Doerfert     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
14829548b74aSJohannes Doerfert       if (!F.hasLocalLinkage())
14839548b74aSJohannes Doerfert         return false;
14849548b74aSJohannes Doerfert       for (Use &U : F.uses()) {
14859548b74aSJohannes Doerfert         if (CallInst *CI = getCallIfRegularCall(U)) {
14869548b74aSJohannes Doerfert           Value *ArgOp = CI->getArgOperand(ArgNo);
14879548b74aSJohannes Doerfert           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
14887cfd267cSsstefan1               getCallIfRegularCall(
14897cfd267cSsstefan1                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
14909548b74aSJohannes Doerfert             continue;
14919548b74aSJohannes Doerfert         }
14929548b74aSJohannes Doerfert         return false;
14939548b74aSJohannes Doerfert       }
14949548b74aSJohannes Doerfert       return true;
14959548b74aSJohannes Doerfert     };
14969548b74aSJohannes Doerfert 
14979548b74aSJohannes Doerfert     // Helper to identify uses of a GTId as GTId arguments.
14989548b74aSJohannes Doerfert     auto AddUserArgs = [&](Value &GTId) {
14999548b74aSJohannes Doerfert       for (Use &U : GTId.uses())
15009548b74aSJohannes Doerfert         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
15019548b74aSJohannes Doerfert           if (CI->isArgOperand(&U))
15029548b74aSJohannes Doerfert             if (Function *Callee = CI->getCalledFunction())
15039548b74aSJohannes Doerfert               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
15049548b74aSJohannes Doerfert                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
15059548b74aSJohannes Doerfert     };
15069548b74aSJohannes Doerfert 
15079548b74aSJohannes Doerfert     // The argument users of __kmpc_global_thread_num calls are GTIds.
15087cfd267cSsstefan1     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
15097cfd267cSsstefan1         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
15107cfd267cSsstefan1 
1511624d34afSJohannes Doerfert     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
15128855fec3SJohannes Doerfert       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
15139548b74aSJohannes Doerfert         AddUserArgs(*CI);
15148855fec3SJohannes Doerfert       return false;
15158855fec3SJohannes Doerfert     });
15169548b74aSJohannes Doerfert 
15179548b74aSJohannes Doerfert     // Transitively search for more arguments by looking at the users of the
15189548b74aSJohannes Doerfert     // ones we know already. During the search the GTIdArgs vector is extended
15199548b74aSJohannes Doerfert     // so we cannot cache the size nor can we use a range based for.
15209548b74aSJohannes Doerfert     for (unsigned u = 0; u < GTIdArgs.size(); ++u)
15219548b74aSJohannes Doerfert       AddUserArgs(*GTIdArgs[u]);
15229548b74aSJohannes Doerfert   }
15239548b74aSJohannes Doerfert 
15245b0581aeSJohannes Doerfert   /// Kernel (=GPU) optimizations and utility functions
15255b0581aeSJohannes Doerfert   ///
15265b0581aeSJohannes Doerfert   ///{{
15275b0581aeSJohannes Doerfert 
15285b0581aeSJohannes Doerfert   /// Check if \p F is a kernel, hence entry point for target offloading.
15295b0581aeSJohannes Doerfert   bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
15305b0581aeSJohannes Doerfert 
15315b0581aeSJohannes Doerfert   /// Cache to remember the unique kernel for a function.
15325b0581aeSJohannes Doerfert   DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
15335b0581aeSJohannes Doerfert 
15345b0581aeSJohannes Doerfert   /// Find the unique kernel that will execute \p F, if any.
15355b0581aeSJohannes Doerfert   Kernel getUniqueKernelFor(Function &F);
15365b0581aeSJohannes Doerfert 
15375b0581aeSJohannes Doerfert   /// Find the unique kernel that will execute \p I, if any.
15385b0581aeSJohannes Doerfert   Kernel getUniqueKernelFor(Instruction &I) {
15395b0581aeSJohannes Doerfert     return getUniqueKernelFor(*I.getFunction());
15405b0581aeSJohannes Doerfert   }
15415b0581aeSJohannes Doerfert 
15425b0581aeSJohannes Doerfert   /// Rewrite the device (=GPU) code state machine create in non-SPMD mode in
15435b0581aeSJohannes Doerfert   /// the cases we can avoid taking the address of a function.
15445b0581aeSJohannes Doerfert   bool rewriteDeviceCodeStateMachine();
15455b0581aeSJohannes Doerfert 
15465b0581aeSJohannes Doerfert   ///
15475b0581aeSJohannes Doerfert   ///}}
15485b0581aeSJohannes Doerfert 
15494d4ea9acSHuber, Joseph   /// Emit a remark generically
15504d4ea9acSHuber, Joseph   ///
15514d4ea9acSHuber, Joseph   /// This template function can be used to generically emit a remark. The
15524d4ea9acSHuber, Joseph   /// RemarkKind should be one of the following:
15534d4ea9acSHuber, Joseph   ///   - OptimizationRemark to indicate a successful optimization attempt
15544d4ea9acSHuber, Joseph   ///   - OptimizationRemarkMissed to report a failed optimization attempt
15554d4ea9acSHuber, Joseph   ///   - OptimizationRemarkAnalysis to provide additional information about an
15564d4ea9acSHuber, Joseph   ///     optimization attempt
15574d4ea9acSHuber, Joseph   ///
15584d4ea9acSHuber, Joseph   /// The remark is built using a callback function provided by the caller that
15594d4ea9acSHuber, Joseph   /// takes a RemarkKind as input and returns a RemarkKind.
1560*2db182ffSJoseph Huber   template <typename RemarkKind, typename RemarkCallBack>
1561*2db182ffSJoseph Huber   void emitRemark(Instruction *I, StringRef RemarkName,
1562e8039ad4SJohannes Doerfert                   RemarkCallBack &&RemarkCB) const {
1563*2db182ffSJoseph Huber     Function *F = I->getParent()->getParent();
15644d4ea9acSHuber, Joseph     auto &ORE = OREGetter(F);
15654d4ea9acSHuber, Joseph 
1566*2db182ffSJoseph Huber     ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
15674d4ea9acSHuber, Joseph   }
15684d4ea9acSHuber, Joseph 
1569*2db182ffSJoseph Huber   /// Emit a remark on a function.
1570*2db182ffSJoseph Huber   template <typename RemarkKind, typename RemarkCallBack>
1571*2db182ffSJoseph Huber   void emitRemark(Function *F, StringRef RemarkName,
1572*2db182ffSJoseph Huber                   RemarkCallBack &&RemarkCB) const {
15730f426935Ssstefan1     auto &ORE = OREGetter(F);
15740f426935Ssstefan1 
1575*2db182ffSJoseph Huber     ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
15760f426935Ssstefan1   }
15770f426935Ssstefan1 
1578b726c557SJohannes Doerfert   /// The underlying module.
15799548b74aSJohannes Doerfert   Module &M;
15809548b74aSJohannes Doerfert 
15819548b74aSJohannes Doerfert   /// The SCC we are operating on.
1582ee17263aSJohannes Doerfert   SmallVectorImpl<Function *> &SCC;
15839548b74aSJohannes Doerfert 
15849548b74aSJohannes Doerfert   /// Callback to update the call graph, the first argument is a removed call,
15859548b74aSJohannes Doerfert   /// the second an optional replacement call.
15869548b74aSJohannes Doerfert   CallGraphUpdater &CGUpdater;
15879548b74aSJohannes Doerfert 
15884d4ea9acSHuber, Joseph   /// Callback to get an OptimizationRemarkEmitter from a Function *
15894d4ea9acSHuber, Joseph   OptimizationRemarkGetter OREGetter;
15904d4ea9acSHuber, Joseph 
15917cfd267cSsstefan1   /// OpenMP-specific information cache. Also Used for Attributor runs.
15927cfd267cSsstefan1   OMPInformationCache &OMPInfoCache;
1593b8235d2bSsstefan1 
1594b8235d2bSsstefan1   /// Attributor instance.
1595b8235d2bSsstefan1   Attributor &A;
1596b8235d2bSsstefan1 
1597b8235d2bSsstefan1   /// Helper function to run Attributor on SCC.
1598b8235d2bSsstefan1   bool runAttributor() {
1599b8235d2bSsstefan1     if (SCC.empty())
1600b8235d2bSsstefan1       return false;
1601b8235d2bSsstefan1 
1602b8235d2bSsstefan1     registerAAs();
1603b8235d2bSsstefan1 
1604b8235d2bSsstefan1     ChangeStatus Changed = A.run();
1605b8235d2bSsstefan1 
1606b8235d2bSsstefan1     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
1607b8235d2bSsstefan1                       << " functions, result: " << Changed << ".\n");
1608b8235d2bSsstefan1 
1609b8235d2bSsstefan1     return Changed == ChangeStatus::CHANGED;
1610b8235d2bSsstefan1   }
1611b8235d2bSsstefan1 
1612b8235d2bSsstefan1   /// Populate the Attributor with abstract attribute opportunities in the
1613b8235d2bSsstefan1   /// function.
1614b8235d2bSsstefan1   void registerAAs() {
16155dfd7cc4Ssstefan1     if (SCC.empty())
16165dfd7cc4Ssstefan1       return;
1617b8235d2bSsstefan1 
16185dfd7cc4Ssstefan1     // Create CallSite AA for all Getters.
16195dfd7cc4Ssstefan1     for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
16205dfd7cc4Ssstefan1       auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
16215dfd7cc4Ssstefan1 
16225dfd7cc4Ssstefan1       auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
16235dfd7cc4Ssstefan1 
16245dfd7cc4Ssstefan1       auto CreateAA = [&](Use &U, Function &Caller) {
16255dfd7cc4Ssstefan1         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
16265dfd7cc4Ssstefan1         if (!CI)
16275dfd7cc4Ssstefan1           return false;
16285dfd7cc4Ssstefan1 
16295dfd7cc4Ssstefan1         auto &CB = cast<CallBase>(*CI);
16305dfd7cc4Ssstefan1 
16315dfd7cc4Ssstefan1         IRPosition CBPos = IRPosition::callsite_function(CB);
16325dfd7cc4Ssstefan1         A.getOrCreateAAFor<AAICVTracker>(CBPos);
16335dfd7cc4Ssstefan1         return false;
16345dfd7cc4Ssstefan1       };
16355dfd7cc4Ssstefan1 
16365dfd7cc4Ssstefan1       GetterRFI.foreachUse(SCC, CreateAA);
1637b8235d2bSsstefan1     }
163818283125SJoseph Huber 
163918283125SJoseph Huber     for (auto &F : M) {
164018283125SJoseph Huber       if (!F.isDeclaration())
164118283125SJoseph Huber         A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
164218283125SJoseph Huber     }
1643b8235d2bSsstefan1   }
1644b8235d2bSsstefan1 };
1645b8235d2bSsstefan1 
16465b0581aeSJohannes Doerfert Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
16475b0581aeSJohannes Doerfert   if (!OMPInfoCache.ModuleSlice.count(&F))
16485b0581aeSJohannes Doerfert     return nullptr;
16495b0581aeSJohannes Doerfert 
16505b0581aeSJohannes Doerfert   // Use a scope to keep the lifetime of the CachedKernel short.
16515b0581aeSJohannes Doerfert   {
16525b0581aeSJohannes Doerfert     Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
16535b0581aeSJohannes Doerfert     if (CachedKernel)
16545b0581aeSJohannes Doerfert       return *CachedKernel;
16555b0581aeSJohannes Doerfert 
16565b0581aeSJohannes Doerfert     // TODO: We should use an AA to create an (optimistic and callback
16575b0581aeSJohannes Doerfert     //       call-aware) call graph. For now we stick to simple patterns that
16585b0581aeSJohannes Doerfert     //       are less powerful, basically the worst fixpoint.
16595b0581aeSJohannes Doerfert     if (isKernel(F)) {
16605b0581aeSJohannes Doerfert       CachedKernel = Kernel(&F);
16615b0581aeSJohannes Doerfert       return *CachedKernel;
16625b0581aeSJohannes Doerfert     }
16635b0581aeSJohannes Doerfert 
16645b0581aeSJohannes Doerfert     CachedKernel = nullptr;
1665994bb6ebSJohannes Doerfert     if (!F.hasLocalLinkage()) {
1666994bb6ebSJohannes Doerfert 
1667994bb6ebSJohannes Doerfert       // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
1668*2db182ffSJoseph Huber       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1669*2db182ffSJoseph Huber         return ORA
1670*2db182ffSJoseph Huber                << "[OMP100] Potentially unknown OpenMP target region caller";
1671994bb6ebSJohannes Doerfert       };
1672*2db182ffSJoseph Huber       emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
1673994bb6ebSJohannes Doerfert 
16745b0581aeSJohannes Doerfert       return nullptr;
16755b0581aeSJohannes Doerfert     }
1676994bb6ebSJohannes Doerfert   }
16775b0581aeSJohannes Doerfert 
16785b0581aeSJohannes Doerfert   auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
16795b0581aeSJohannes Doerfert     if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
16805b0581aeSJohannes Doerfert       // Allow use in equality comparisons.
16815b0581aeSJohannes Doerfert       if (Cmp->isEquality())
16825b0581aeSJohannes Doerfert         return getUniqueKernelFor(*Cmp);
16835b0581aeSJohannes Doerfert       return nullptr;
16845b0581aeSJohannes Doerfert     }
16855b0581aeSJohannes Doerfert     if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
16865b0581aeSJohannes Doerfert       // Allow direct calls.
16875b0581aeSJohannes Doerfert       if (CB->isCallee(&U))
16885b0581aeSJohannes Doerfert         return getUniqueKernelFor(*CB);
1689a2dbfb6bSGiorgis Georgakoudis 
1690a2dbfb6bSGiorgis Georgakoudis       OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1691a2dbfb6bSGiorgis Georgakoudis           OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1692a2dbfb6bSGiorgis Georgakoudis       // Allow the use in __kmpc_parallel_51 calls.
1693a2dbfb6bSGiorgis Georgakoudis       if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
16945b0581aeSJohannes Doerfert         return getUniqueKernelFor(*CB);
16955b0581aeSJohannes Doerfert       return nullptr;
16965b0581aeSJohannes Doerfert     }
16975b0581aeSJohannes Doerfert     // Disallow every other use.
16985b0581aeSJohannes Doerfert     return nullptr;
16995b0581aeSJohannes Doerfert   };
17005b0581aeSJohannes Doerfert 
17015b0581aeSJohannes Doerfert   // TODO: In the future we want to track more than just a unique kernel.
17025b0581aeSJohannes Doerfert   SmallPtrSet<Kernel, 2> PotentialKernels;
17038d8ce85bSsstefan1   OMPInformationCache::foreachUse(F, [&](const Use &U) {
17045b0581aeSJohannes Doerfert     PotentialKernels.insert(GetUniqueKernelForUse(U));
17055b0581aeSJohannes Doerfert   });
17065b0581aeSJohannes Doerfert 
17075b0581aeSJohannes Doerfert   Kernel K = nullptr;
17085b0581aeSJohannes Doerfert   if (PotentialKernels.size() == 1)
17095b0581aeSJohannes Doerfert     K = *PotentialKernels.begin();
17105b0581aeSJohannes Doerfert 
17115b0581aeSJohannes Doerfert   // Cache the result.
17125b0581aeSJohannes Doerfert   UniqueKernelMap[&F] = K;
17135b0581aeSJohannes Doerfert 
17145b0581aeSJohannes Doerfert   return K;
17155b0581aeSJohannes Doerfert }
17165b0581aeSJohannes Doerfert 
17175b0581aeSJohannes Doerfert bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1718a2dbfb6bSGiorgis Georgakoudis   OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1719a2dbfb6bSGiorgis Georgakoudis       OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
17205b0581aeSJohannes Doerfert 
17215b0581aeSJohannes Doerfert   bool Changed = false;
1722a2dbfb6bSGiorgis Georgakoudis   if (!KernelParallelRFI)
17235b0581aeSJohannes Doerfert     return Changed;
17245b0581aeSJohannes Doerfert 
17255b0581aeSJohannes Doerfert   for (Function *F : SCC) {
17265b0581aeSJohannes Doerfert 
1727a2dbfb6bSGiorgis Georgakoudis     // Check if the function is a use in a __kmpc_parallel_51 call at
17285b0581aeSJohannes Doerfert     // all.
17295b0581aeSJohannes Doerfert     bool UnknownUse = false;
1730a2dbfb6bSGiorgis Georgakoudis     bool KernelParallelUse = false;
17315b0581aeSJohannes Doerfert     unsigned NumDirectCalls = 0;
17325b0581aeSJohannes Doerfert 
17335b0581aeSJohannes Doerfert     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
17348d8ce85bSsstefan1     OMPInformationCache::foreachUse(*F, [&](Use &U) {
17355b0581aeSJohannes Doerfert       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
17365b0581aeSJohannes Doerfert         if (CB->isCallee(&U)) {
17375b0581aeSJohannes Doerfert           ++NumDirectCalls;
17385b0581aeSJohannes Doerfert           return;
17395b0581aeSJohannes Doerfert         }
17405b0581aeSJohannes Doerfert 
174181db6144SMichael Liao       if (isa<ICmpInst>(U.getUser())) {
17425b0581aeSJohannes Doerfert         ToBeReplacedStateMachineUses.push_back(&U);
17435b0581aeSJohannes Doerfert         return;
17445b0581aeSJohannes Doerfert       }
1745a2dbfb6bSGiorgis Georgakoudis 
1746a2dbfb6bSGiorgis Georgakoudis       // Find wrapper functions that represent parallel kernels.
1747a2dbfb6bSGiorgis Georgakoudis       CallInst *CI =
1748a2dbfb6bSGiorgis Georgakoudis           OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
1749a2dbfb6bSGiorgis Georgakoudis       const unsigned int WrapperFunctionArgNo = 6;
1750a2dbfb6bSGiorgis Georgakoudis       if (!KernelParallelUse && CI &&
1751a2dbfb6bSGiorgis Georgakoudis           CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
1752a2dbfb6bSGiorgis Georgakoudis         KernelParallelUse = true;
17535b0581aeSJohannes Doerfert         ToBeReplacedStateMachineUses.push_back(&U);
17545b0581aeSJohannes Doerfert         return;
17555b0581aeSJohannes Doerfert       }
17565b0581aeSJohannes Doerfert       UnknownUse = true;
17575b0581aeSJohannes Doerfert     });
17585b0581aeSJohannes Doerfert 
1759a2dbfb6bSGiorgis Georgakoudis     // Do not emit a remark if we haven't seen a __kmpc_parallel_51
1760fec1f210SJohannes Doerfert     // use.
1761a2dbfb6bSGiorgis Georgakoudis     if (!KernelParallelUse)
17625b0581aeSJohannes Doerfert       continue;
17635b0581aeSJohannes Doerfert 
1764fec1f210SJohannes Doerfert     {
1765*2db182ffSJoseph Huber       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1766*2db182ffSJoseph Huber         return ORA << "Found a parallel region that is called in a target "
1767fec1f210SJohannes Doerfert                       "region but not part of a combined target construct nor "
1768a2dbfb6bSGiorgis Georgakoudis                       "nested inside a target construct without intermediate "
1769fec1f210SJohannes Doerfert                       "code. This can lead to excessive register usage for "
1770fec1f210SJohannes Doerfert                       "unrelated target regions in the same translation unit "
1771fec1f210SJohannes Doerfert                       "due to spurious call edges assumed by ptxas.";
1772fec1f210SJohannes Doerfert       };
1773*2db182ffSJoseph Huber       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPParallelRegionInNonSPMD",
1774*2db182ffSJoseph Huber                                              Remark);
1775fec1f210SJohannes Doerfert     }
1776fec1f210SJohannes Doerfert 
1777fec1f210SJohannes Doerfert     // If this ever hits, we should investigate.
1778fec1f210SJohannes Doerfert     // TODO: Checking the number of uses is not a necessary restriction and
1779fec1f210SJohannes Doerfert     // should be lifted.
1780fec1f210SJohannes Doerfert     if (UnknownUse || NumDirectCalls != 1 ||
1781fec1f210SJohannes Doerfert         ToBeReplacedStateMachineUses.size() != 2) {
1782fec1f210SJohannes Doerfert       {
1783*2db182ffSJoseph Huber         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1784*2db182ffSJoseph Huber           return ORA << "Parallel region is used in "
1785fec1f210SJohannes Doerfert                      << (UnknownUse ? "unknown" : "unexpected")
1786fec1f210SJohannes Doerfert                      << " ways; will not attempt to rewrite the state machine.";
1787fec1f210SJohannes Doerfert         };
1788*2db182ffSJoseph Huber         emitRemark<OptimizationRemarkAnalysis>(
1789*2db182ffSJoseph Huber             F, "OpenMPParallelRegionInNonSPMD", Remark);
1790fec1f210SJohannes Doerfert       }
17915b0581aeSJohannes Doerfert       continue;
1792fec1f210SJohannes Doerfert     }
17935b0581aeSJohannes Doerfert 
1794a2dbfb6bSGiorgis Georgakoudis     // Even if we have __kmpc_parallel_51 calls, we (for now) give
17955b0581aeSJohannes Doerfert     // up if the function is not called from a unique kernel.
17965b0581aeSJohannes Doerfert     Kernel K = getUniqueKernelFor(*F);
1797fec1f210SJohannes Doerfert     if (!K) {
1798fec1f210SJohannes Doerfert       {
1799*2db182ffSJoseph Huber         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1800*2db182ffSJoseph Huber           return ORA << "Parallel region is not known to be called from a "
1801fec1f210SJohannes Doerfert                         "unique single target region, maybe the surrounding "
1802fec1f210SJohannes Doerfert                         "function has external linkage?; will not attempt to "
1803fec1f210SJohannes Doerfert                         "rewrite the state machine use.";
1804fec1f210SJohannes Doerfert         };
1805*2db182ffSJoseph Huber         emitRemark<OptimizationRemarkAnalysis>(
1806*2db182ffSJoseph Huber             F, "OpenMPParallelRegionInMultipleKernesl", Remark);
1807fec1f210SJohannes Doerfert       }
18085b0581aeSJohannes Doerfert       continue;
1809fec1f210SJohannes Doerfert     }
18105b0581aeSJohannes Doerfert 
18115b0581aeSJohannes Doerfert     // We now know F is a parallel body function called only from the kernel K.
18125b0581aeSJohannes Doerfert     // We also identified the state machine uses in which we replace the
18135b0581aeSJohannes Doerfert     // function pointer by a new global symbol for identification purposes. This
18145b0581aeSJohannes Doerfert     // ensures only direct calls to the function are left.
18155b0581aeSJohannes Doerfert 
1816fec1f210SJohannes Doerfert     {
1817*2db182ffSJoseph Huber       auto RemarkParalleRegion = [&](OptimizationRemarkAnalysis ORA) {
1818*2db182ffSJoseph Huber         return ORA << "Specialize parallel region that is only reached from a "
1819fec1f210SJohannes Doerfert                       "single target region to avoid spurious call edges and "
1820fec1f210SJohannes Doerfert                       "excessive register usage in other target regions. "
1821fec1f210SJohannes Doerfert                       "(parallel region ID: "
1822fec1f210SJohannes Doerfert                    << ore::NV("OpenMPParallelRegion", F->getName())
1823fec1f210SJohannes Doerfert                    << ", kernel ID: "
1824fec1f210SJohannes Doerfert                    << ore::NV("OpenMPTargetRegion", K->getName()) << ")";
1825fec1f210SJohannes Doerfert       };
1826*2db182ffSJoseph Huber       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPParallelRegionInNonSPMD",
1827fec1f210SJohannes Doerfert                                              RemarkParalleRegion);
1828*2db182ffSJoseph Huber       auto RemarkKernel = [&](OptimizationRemarkAnalysis ORA) {
1829*2db182ffSJoseph Huber         return ORA << "Target region containing the parallel region that is "
1830fec1f210SJohannes Doerfert                       "specialized. (parallel region ID: "
1831fec1f210SJohannes Doerfert                    << ore::NV("OpenMPParallelRegion", F->getName())
1832fec1f210SJohannes Doerfert                    << ", kernel ID: "
1833fec1f210SJohannes Doerfert                    << ore::NV("OpenMPTargetRegion", K->getName()) << ")";
1834fec1f210SJohannes Doerfert       };
1835*2db182ffSJoseph Huber       emitRemark<OptimizationRemarkAnalysis>(K, "OpenMPParallelRegionInNonSPMD",
1836*2db182ffSJoseph Huber                                              RemarkKernel);
1837fec1f210SJohannes Doerfert     }
1838fec1f210SJohannes Doerfert 
18395b0581aeSJohannes Doerfert     Module &M = *F->getParent();
18405b0581aeSJohannes Doerfert     Type *Int8Ty = Type::getInt8Ty(M.getContext());
18415b0581aeSJohannes Doerfert 
18425b0581aeSJohannes Doerfert     auto *ID = new GlobalVariable(
18435b0581aeSJohannes Doerfert         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
18445b0581aeSJohannes Doerfert         UndefValue::get(Int8Ty), F->getName() + ".ID");
18455b0581aeSJohannes Doerfert 
18465b0581aeSJohannes Doerfert     for (Use *U : ToBeReplacedStateMachineUses)
18475b0581aeSJohannes Doerfert       U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));
18485b0581aeSJohannes Doerfert 
18495b0581aeSJohannes Doerfert     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
18505b0581aeSJohannes Doerfert 
18515b0581aeSJohannes Doerfert     Changed = true;
18525b0581aeSJohannes Doerfert   }
18535b0581aeSJohannes Doerfert 
18545b0581aeSJohannes Doerfert   return Changed;
18555b0581aeSJohannes Doerfert }
18565b0581aeSJohannes Doerfert 
1857b8235d2bSsstefan1 /// Abstract Attribute for tracking ICV values.
1858b8235d2bSsstefan1 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
1859b8235d2bSsstefan1   using Base = StateWrapper<BooleanState, AbstractAttribute>;
1860b8235d2bSsstefan1   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
1861b8235d2bSsstefan1 
18625dfd7cc4Ssstefan1   void initialize(Attributor &A) override {
18635dfd7cc4Ssstefan1     Function *F = getAnchorScope();
18645dfd7cc4Ssstefan1     if (!F || !A.isFunctionIPOAmendable(*F))
18655dfd7cc4Ssstefan1       indicatePessimisticFixpoint();
18665dfd7cc4Ssstefan1   }
18675dfd7cc4Ssstefan1 
1868b8235d2bSsstefan1   /// Returns true if value is assumed to be tracked.
1869b8235d2bSsstefan1   bool isAssumedTracked() const { return getAssumed(); }
1870b8235d2bSsstefan1 
1871b8235d2bSsstefan1   /// Returns true if value is known to be tracked.
1872b8235d2bSsstefan1   bool isKnownTracked() const { return getAssumed(); }
1873b8235d2bSsstefan1 
1874b8235d2bSsstefan1   /// Create an abstract attribute biew for the position \p IRP.
1875b8235d2bSsstefan1   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
1876b8235d2bSsstefan1 
1877b8235d2bSsstefan1   /// Return the value with which \p I can be replaced for specific \p ICV.
18785dfd7cc4Ssstefan1   virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
18795dfd7cc4Ssstefan1                                                 const Instruction *I,
18805dfd7cc4Ssstefan1                                                 Attributor &A) const {
18815dfd7cc4Ssstefan1     return None;
18825dfd7cc4Ssstefan1   }
18835dfd7cc4Ssstefan1 
18845dfd7cc4Ssstefan1   /// Return an assumed unique ICV value if a single candidate is found. If
18855dfd7cc4Ssstefan1   /// there cannot be one, return a nullptr. If it is not clear yet, return the
18865dfd7cc4Ssstefan1   /// Optional::NoneType.
18875dfd7cc4Ssstefan1   virtual Optional<Value *>
18885dfd7cc4Ssstefan1   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
18895dfd7cc4Ssstefan1 
18905dfd7cc4Ssstefan1   // Currently only nthreads is being tracked.
18915dfd7cc4Ssstefan1   // this array will only grow with time.
18925dfd7cc4Ssstefan1   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
1893b8235d2bSsstefan1 
1894b8235d2bSsstefan1   /// See AbstractAttribute::getName()
1895b8235d2bSsstefan1   const std::string getName() const override { return "AAICVTracker"; }
1896b8235d2bSsstefan1 
1897233af895SLuofan Chen   /// See AbstractAttribute::getIdAddr()
1898233af895SLuofan Chen   const char *getIdAddr() const override { return &ID; }
1899233af895SLuofan Chen 
1900233af895SLuofan Chen   /// This function should return true if the type of the \p AA is AAICVTracker
1901233af895SLuofan Chen   static bool classof(const AbstractAttribute *AA) {
1902233af895SLuofan Chen     return (AA->getIdAddr() == &ID);
1903233af895SLuofan Chen   }
1904233af895SLuofan Chen 
1905b8235d2bSsstefan1   static const char ID;
1906b8235d2bSsstefan1 };
1907b8235d2bSsstefan1 
1908b8235d2bSsstefan1 struct AAICVTrackerFunction : public AAICVTracker {
1909b8235d2bSsstefan1   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
1910b8235d2bSsstefan1       : AAICVTracker(IRP, A) {}
1911b8235d2bSsstefan1 
1912b8235d2bSsstefan1   // FIXME: come up with better string.
19135dfd7cc4Ssstefan1   const std::string getAsStr() const override { return "ICVTrackerFunction"; }
1914b8235d2bSsstefan1 
1915b8235d2bSsstefan1   // FIXME: come up with some stats.
1916b8235d2bSsstefan1   void trackStatistics() const override {}
1917b8235d2bSsstefan1 
19185dfd7cc4Ssstefan1   /// We don't manifest anything for this AA.
1919b8235d2bSsstefan1   ChangeStatus manifest(Attributor &A) override {
19205dfd7cc4Ssstefan1     return ChangeStatus::UNCHANGED;
1921b8235d2bSsstefan1   }
1922b8235d2bSsstefan1 
1923b8235d2bSsstefan1   // Map of ICV to their values at specific program point.
19245dfd7cc4Ssstefan1   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
1925b8235d2bSsstefan1                   InternalControlVar::ICV___last>
19265dfd7cc4Ssstefan1       ICVReplacementValuesMap;
1927b8235d2bSsstefan1 
1928b8235d2bSsstefan1   ChangeStatus updateImpl(Attributor &A) override {
1929b8235d2bSsstefan1     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
1930b8235d2bSsstefan1 
1931b8235d2bSsstefan1     Function *F = getAnchorScope();
1932b8235d2bSsstefan1 
1933b8235d2bSsstefan1     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
1934b8235d2bSsstefan1 
1935b8235d2bSsstefan1     for (InternalControlVar ICV : TrackableICVs) {
1936b8235d2bSsstefan1       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
1937b8235d2bSsstefan1 
19385dfd7cc4Ssstefan1       auto &ValuesMap = ICVReplacementValuesMap[ICV];
1939b8235d2bSsstefan1       auto TrackValues = [&](Use &U, Function &) {
1940b8235d2bSsstefan1         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
1941b8235d2bSsstefan1         if (!CI)
1942b8235d2bSsstefan1           return false;
1943b8235d2bSsstefan1 
1944b8235d2bSsstefan1         // FIXME: handle setters with more that 1 arguments.
1945b8235d2bSsstefan1         /// Track new value.
19465dfd7cc4Ssstefan1         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
1947b8235d2bSsstefan1           HasChanged = ChangeStatus::CHANGED;
1948b8235d2bSsstefan1 
1949b8235d2bSsstefan1         return false;
1950b8235d2bSsstefan1       };
1951b8235d2bSsstefan1 
19525dfd7cc4Ssstefan1       auto CallCheck = [&](Instruction &I) {
19535dfd7cc4Ssstefan1         Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
19545dfd7cc4Ssstefan1         if (ReplVal.hasValue() &&
19555dfd7cc4Ssstefan1             ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
19565dfd7cc4Ssstefan1           HasChanged = ChangeStatus::CHANGED;
19575dfd7cc4Ssstefan1 
19585dfd7cc4Ssstefan1         return true;
19595dfd7cc4Ssstefan1       };
19605dfd7cc4Ssstefan1 
19615dfd7cc4Ssstefan1       // Track all changes of an ICV.
1962b8235d2bSsstefan1       SetterRFI.foreachUse(TrackValues, F);
19635dfd7cc4Ssstefan1 
19645dfd7cc4Ssstefan1       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
19655dfd7cc4Ssstefan1                                 /* CheckBBLivenessOnly */ true);
19665dfd7cc4Ssstefan1 
19675dfd7cc4Ssstefan1       /// TODO: Figure out a way to avoid adding entry in
19685dfd7cc4Ssstefan1       /// ICVReplacementValuesMap
19695dfd7cc4Ssstefan1       Instruction *Entry = &F->getEntryBlock().front();
19705dfd7cc4Ssstefan1       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
19715dfd7cc4Ssstefan1         ValuesMap.insert(std::make_pair(Entry, nullptr));
1972b8235d2bSsstefan1     }
1973b8235d2bSsstefan1 
1974b8235d2bSsstefan1     return HasChanged;
1975b8235d2bSsstefan1   }
1976b8235d2bSsstefan1 
19775dfd7cc4Ssstefan1   /// Hepler to check if \p I is a call and get the value for it if it is
19785dfd7cc4Ssstefan1   /// unique.
19795dfd7cc4Ssstefan1   Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
19805dfd7cc4Ssstefan1                                     InternalControlVar &ICV) const {
1981b8235d2bSsstefan1 
19825dfd7cc4Ssstefan1     const auto *CB = dyn_cast<CallBase>(I);
1983dcaec812SJohannes Doerfert     if (!CB || CB->hasFnAttr("no_openmp") ||
1984dcaec812SJohannes Doerfert         CB->hasFnAttr("no_openmp_routines"))
19855dfd7cc4Ssstefan1       return None;
19865dfd7cc4Ssstefan1 
1987b8235d2bSsstefan1     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
1988b8235d2bSsstefan1     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
19895dfd7cc4Ssstefan1     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
19905dfd7cc4Ssstefan1     Function *CalledFunction = CB->getCalledFunction();
1991b8235d2bSsstefan1 
19924eef14f9SWei Wang     // Indirect call, assume ICV changes.
19934eef14f9SWei Wang     if (CalledFunction == nullptr)
19944eef14f9SWei Wang       return nullptr;
19955dfd7cc4Ssstefan1     if (CalledFunction == GetterRFI.Declaration)
19965dfd7cc4Ssstefan1       return None;
19975dfd7cc4Ssstefan1     if (CalledFunction == SetterRFI.Declaration) {
19985dfd7cc4Ssstefan1       if (ICVReplacementValuesMap[ICV].count(I))
19995dfd7cc4Ssstefan1         return ICVReplacementValuesMap[ICV].lookup(I);
20005dfd7cc4Ssstefan1 
20015dfd7cc4Ssstefan1       return nullptr;
20025dfd7cc4Ssstefan1     }
20035dfd7cc4Ssstefan1 
20045dfd7cc4Ssstefan1     // Since we don't know, assume it changes the ICV.
20055dfd7cc4Ssstefan1     if (CalledFunction->isDeclaration())
20065dfd7cc4Ssstefan1       return nullptr;
20075dfd7cc4Ssstefan1 
20085b70c12fSJohannes Doerfert     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
20095b70c12fSJohannes Doerfert         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
20105dfd7cc4Ssstefan1 
20115dfd7cc4Ssstefan1     if (ICVTrackingAA.isAssumedTracked())
20125dfd7cc4Ssstefan1       return ICVTrackingAA.getUniqueReplacementValue(ICV);
20135dfd7cc4Ssstefan1 
20145dfd7cc4Ssstefan1     // If we don't know, assume it changes.
20155dfd7cc4Ssstefan1     return nullptr;
20165dfd7cc4Ssstefan1   }
20175dfd7cc4Ssstefan1 
20185dfd7cc4Ssstefan1   // We don't check unique value for a function, so return None.
20195dfd7cc4Ssstefan1   Optional<Value *>
20205dfd7cc4Ssstefan1   getUniqueReplacementValue(InternalControlVar ICV) const override {
20215dfd7cc4Ssstefan1     return None;
20225dfd7cc4Ssstefan1   }
20235dfd7cc4Ssstefan1 
20245dfd7cc4Ssstefan1   /// Return the value with which \p I can be replaced for specific \p ICV.
20255dfd7cc4Ssstefan1   Optional<Value *> getReplacementValue(InternalControlVar ICV,
20265dfd7cc4Ssstefan1                                         const Instruction *I,
20275dfd7cc4Ssstefan1                                         Attributor &A) const override {
20285dfd7cc4Ssstefan1     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
20295dfd7cc4Ssstefan1     if (ValuesMap.count(I))
20305dfd7cc4Ssstefan1       return ValuesMap.lookup(I);
20315dfd7cc4Ssstefan1 
20325dfd7cc4Ssstefan1     SmallVector<const Instruction *, 16> Worklist;
20335dfd7cc4Ssstefan1     SmallPtrSet<const Instruction *, 16> Visited;
20345dfd7cc4Ssstefan1     Worklist.push_back(I);
20355dfd7cc4Ssstefan1 
20365dfd7cc4Ssstefan1     Optional<Value *> ReplVal;
20375dfd7cc4Ssstefan1 
20385dfd7cc4Ssstefan1     while (!Worklist.empty()) {
20395dfd7cc4Ssstefan1       const Instruction *CurrInst = Worklist.pop_back_val();
20405dfd7cc4Ssstefan1       if (!Visited.insert(CurrInst).second)
2041b8235d2bSsstefan1         continue;
2042b8235d2bSsstefan1 
20435dfd7cc4Ssstefan1       const BasicBlock *CurrBB = CurrInst->getParent();
20445dfd7cc4Ssstefan1 
20455dfd7cc4Ssstefan1       // Go up and look for all potential setters/calls that might change the
20465dfd7cc4Ssstefan1       // ICV.
20475dfd7cc4Ssstefan1       while ((CurrInst = CurrInst->getPrevNode())) {
20485dfd7cc4Ssstefan1         if (ValuesMap.count(CurrInst)) {
20495dfd7cc4Ssstefan1           Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
20505dfd7cc4Ssstefan1           // Unknown value, track new.
20515dfd7cc4Ssstefan1           if (!ReplVal.hasValue()) {
20525dfd7cc4Ssstefan1             ReplVal = NewReplVal;
20535dfd7cc4Ssstefan1             break;
20545dfd7cc4Ssstefan1           }
20555dfd7cc4Ssstefan1 
20565dfd7cc4Ssstefan1           // If we found a new value, we can't know the icv value anymore.
20575dfd7cc4Ssstefan1           if (NewReplVal.hasValue())
20585dfd7cc4Ssstefan1             if (ReplVal != NewReplVal)
2059b8235d2bSsstefan1               return nullptr;
2060b8235d2bSsstefan1 
20615dfd7cc4Ssstefan1           break;
2062b8235d2bSsstefan1         }
2063b8235d2bSsstefan1 
20645dfd7cc4Ssstefan1         Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
20655dfd7cc4Ssstefan1         if (!NewReplVal.hasValue())
20665dfd7cc4Ssstefan1           continue;
20675dfd7cc4Ssstefan1 
20685dfd7cc4Ssstefan1         // Unknown value, track new.
20695dfd7cc4Ssstefan1         if (!ReplVal.hasValue()) {
20705dfd7cc4Ssstefan1           ReplVal = NewReplVal;
20715dfd7cc4Ssstefan1           break;
2072b8235d2bSsstefan1         }
2073b8235d2bSsstefan1 
20745dfd7cc4Ssstefan1         // if (NewReplVal.hasValue())
20755dfd7cc4Ssstefan1         // We found a new value, we can't know the icv value anymore.
20765dfd7cc4Ssstefan1         if (ReplVal != NewReplVal)
2077b8235d2bSsstefan1           return nullptr;
2078b8235d2bSsstefan1       }
20795dfd7cc4Ssstefan1 
20805dfd7cc4Ssstefan1       // If we are in the same BB and we have a value, we are done.
20815dfd7cc4Ssstefan1       if (CurrBB == I->getParent() && ReplVal.hasValue())
20825dfd7cc4Ssstefan1         return ReplVal;
20835dfd7cc4Ssstefan1 
20845dfd7cc4Ssstefan1       // Go through all predecessors and add terminators for analysis.
20855dfd7cc4Ssstefan1       for (const BasicBlock *Pred : predecessors(CurrBB))
20865dfd7cc4Ssstefan1         if (const Instruction *Terminator = Pred->getTerminator())
20875dfd7cc4Ssstefan1           Worklist.push_back(Terminator);
20885dfd7cc4Ssstefan1     }
20895dfd7cc4Ssstefan1 
20905dfd7cc4Ssstefan1     return ReplVal;
20915dfd7cc4Ssstefan1   }
20925dfd7cc4Ssstefan1 };
20935dfd7cc4Ssstefan1 
20945dfd7cc4Ssstefan1 struct AAICVTrackerFunctionReturned : AAICVTracker {
20955dfd7cc4Ssstefan1   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
20965dfd7cc4Ssstefan1       : AAICVTracker(IRP, A) {}
20975dfd7cc4Ssstefan1 
20985dfd7cc4Ssstefan1   // FIXME: come up with better string.
20995dfd7cc4Ssstefan1   const std::string getAsStr() const override {
21005dfd7cc4Ssstefan1     return "ICVTrackerFunctionReturned";
21015dfd7cc4Ssstefan1   }
21025dfd7cc4Ssstefan1 
21035dfd7cc4Ssstefan1   // FIXME: come up with some stats.
21045dfd7cc4Ssstefan1   void trackStatistics() const override {}
21055dfd7cc4Ssstefan1 
21065dfd7cc4Ssstefan1   /// We don't manifest anything for this AA.
21075dfd7cc4Ssstefan1   ChangeStatus manifest(Attributor &A) override {
21085dfd7cc4Ssstefan1     return ChangeStatus::UNCHANGED;
21095dfd7cc4Ssstefan1   }
21105dfd7cc4Ssstefan1 
21115dfd7cc4Ssstefan1   // Map of ICV to their values at specific program point.
21125dfd7cc4Ssstefan1   EnumeratedArray<Optional<Value *>, InternalControlVar,
21135dfd7cc4Ssstefan1                   InternalControlVar::ICV___last>
21145dfd7cc4Ssstefan1       ICVReplacementValuesMap;
21155dfd7cc4Ssstefan1 
21165dfd7cc4Ssstefan1   /// Return the value with which \p I can be replaced for specific \p ICV.
21175dfd7cc4Ssstefan1   Optional<Value *>
21185dfd7cc4Ssstefan1   getUniqueReplacementValue(InternalControlVar ICV) const override {
21195dfd7cc4Ssstefan1     return ICVReplacementValuesMap[ICV];
21205dfd7cc4Ssstefan1   }
21215dfd7cc4Ssstefan1 
21225dfd7cc4Ssstefan1   ChangeStatus updateImpl(Attributor &A) override {
21235dfd7cc4Ssstefan1     ChangeStatus Changed = ChangeStatus::UNCHANGED;
21245dfd7cc4Ssstefan1     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
21255b70c12fSJohannes Doerfert         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
21265dfd7cc4Ssstefan1 
21275dfd7cc4Ssstefan1     if (!ICVTrackingAA.isAssumedTracked())
21285dfd7cc4Ssstefan1       return indicatePessimisticFixpoint();
21295dfd7cc4Ssstefan1 
21305dfd7cc4Ssstefan1     for (InternalControlVar ICV : TrackableICVs) {
21315dfd7cc4Ssstefan1       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
21325dfd7cc4Ssstefan1       Optional<Value *> UniqueICVValue;
21335dfd7cc4Ssstefan1 
21345dfd7cc4Ssstefan1       auto CheckReturnInst = [&](Instruction &I) {
21355dfd7cc4Ssstefan1         Optional<Value *> NewReplVal =
21365dfd7cc4Ssstefan1             ICVTrackingAA.getReplacementValue(ICV, &I, A);
21375dfd7cc4Ssstefan1 
21385dfd7cc4Ssstefan1         // If we found a second ICV value there is no unique returned value.
21395dfd7cc4Ssstefan1         if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
21405dfd7cc4Ssstefan1           return false;
21415dfd7cc4Ssstefan1 
21425dfd7cc4Ssstefan1         UniqueICVValue = NewReplVal;
21435dfd7cc4Ssstefan1 
21445dfd7cc4Ssstefan1         return true;
21455dfd7cc4Ssstefan1       };
21465dfd7cc4Ssstefan1 
21475dfd7cc4Ssstefan1       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
21485dfd7cc4Ssstefan1                                      /* CheckBBLivenessOnly */ true))
21495dfd7cc4Ssstefan1         UniqueICVValue = nullptr;
21505dfd7cc4Ssstefan1 
21515dfd7cc4Ssstefan1       if (UniqueICVValue == ReplVal)
21525dfd7cc4Ssstefan1         continue;
21535dfd7cc4Ssstefan1 
21545dfd7cc4Ssstefan1       ReplVal = UniqueICVValue;
21555dfd7cc4Ssstefan1       Changed = ChangeStatus::CHANGED;
21565dfd7cc4Ssstefan1     }
21575dfd7cc4Ssstefan1 
21585dfd7cc4Ssstefan1     return Changed;
21595dfd7cc4Ssstefan1   }
21605dfd7cc4Ssstefan1 };
21615dfd7cc4Ssstefan1 
21625dfd7cc4Ssstefan1 struct AAICVTrackerCallSite : AAICVTracker {
21635dfd7cc4Ssstefan1   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
21645dfd7cc4Ssstefan1       : AAICVTracker(IRP, A) {}
21655dfd7cc4Ssstefan1 
21665dfd7cc4Ssstefan1   void initialize(Attributor &A) override {
21675dfd7cc4Ssstefan1     Function *F = getAnchorScope();
21685dfd7cc4Ssstefan1     if (!F || !A.isFunctionIPOAmendable(*F))
21695dfd7cc4Ssstefan1       indicatePessimisticFixpoint();
21705dfd7cc4Ssstefan1 
21715dfd7cc4Ssstefan1     // We only initialize this AA for getters, so we need to know which ICV it
21725dfd7cc4Ssstefan1     // gets.
21735dfd7cc4Ssstefan1     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
21745dfd7cc4Ssstefan1     for (InternalControlVar ICV : TrackableICVs) {
21755dfd7cc4Ssstefan1       auto ICVInfo = OMPInfoCache.ICVs[ICV];
21765dfd7cc4Ssstefan1       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
21775dfd7cc4Ssstefan1       if (Getter.Declaration == getAssociatedFunction()) {
21785dfd7cc4Ssstefan1         AssociatedICV = ICVInfo.Kind;
21795dfd7cc4Ssstefan1         return;
21805dfd7cc4Ssstefan1       }
21815dfd7cc4Ssstefan1     }
21825dfd7cc4Ssstefan1 
21835dfd7cc4Ssstefan1     /// Unknown ICV.
21845dfd7cc4Ssstefan1     indicatePessimisticFixpoint();
21855dfd7cc4Ssstefan1   }
21865dfd7cc4Ssstefan1 
21875dfd7cc4Ssstefan1   ChangeStatus manifest(Attributor &A) override {
21885dfd7cc4Ssstefan1     if (!ReplVal.hasValue() || !ReplVal.getValue())
21895dfd7cc4Ssstefan1       return ChangeStatus::UNCHANGED;
21905dfd7cc4Ssstefan1 
21915dfd7cc4Ssstefan1     A.changeValueAfterManifest(*getCtxI(), **ReplVal);
21925dfd7cc4Ssstefan1     A.deleteAfterManifest(*getCtxI());
21935dfd7cc4Ssstefan1 
21945dfd7cc4Ssstefan1     return ChangeStatus::CHANGED;
21955dfd7cc4Ssstefan1   }
21965dfd7cc4Ssstefan1 
21975dfd7cc4Ssstefan1   // FIXME: come up with better string.
21985dfd7cc4Ssstefan1   const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
21995dfd7cc4Ssstefan1 
22005dfd7cc4Ssstefan1   // FIXME: come up with some stats.
22015dfd7cc4Ssstefan1   void trackStatistics() const override {}
22025dfd7cc4Ssstefan1 
22035dfd7cc4Ssstefan1   InternalControlVar AssociatedICV;
22045dfd7cc4Ssstefan1   Optional<Value *> ReplVal;
22055dfd7cc4Ssstefan1 
22065dfd7cc4Ssstefan1   ChangeStatus updateImpl(Attributor &A) override {
22075dfd7cc4Ssstefan1     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
22085b70c12fSJohannes Doerfert         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
22095dfd7cc4Ssstefan1 
22105dfd7cc4Ssstefan1     // We don't have any information, so we assume it changes the ICV.
22115dfd7cc4Ssstefan1     if (!ICVTrackingAA.isAssumedTracked())
22125dfd7cc4Ssstefan1       return indicatePessimisticFixpoint();
22135dfd7cc4Ssstefan1 
22145dfd7cc4Ssstefan1     Optional<Value *> NewReplVal =
22155dfd7cc4Ssstefan1         ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
22165dfd7cc4Ssstefan1 
22175dfd7cc4Ssstefan1     if (ReplVal == NewReplVal)
22185dfd7cc4Ssstefan1       return ChangeStatus::UNCHANGED;
22195dfd7cc4Ssstefan1 
22205dfd7cc4Ssstefan1     ReplVal = NewReplVal;
22215dfd7cc4Ssstefan1     return ChangeStatus::CHANGED;
22225dfd7cc4Ssstefan1   }
22235dfd7cc4Ssstefan1 
22245dfd7cc4Ssstefan1   // Return the value with which associated value can be replaced for specific
22255dfd7cc4Ssstefan1   // \p ICV.
22265dfd7cc4Ssstefan1   Optional<Value *>
22275dfd7cc4Ssstefan1   getUniqueReplacementValue(InternalControlVar ICV) const override {
22285dfd7cc4Ssstefan1     return ReplVal;
22295dfd7cc4Ssstefan1   }
22305dfd7cc4Ssstefan1 };
22315dfd7cc4Ssstefan1 
22325dfd7cc4Ssstefan1 struct AAICVTrackerCallSiteReturned : AAICVTracker {
22335dfd7cc4Ssstefan1   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
22345dfd7cc4Ssstefan1       : AAICVTracker(IRP, A) {}
22355dfd7cc4Ssstefan1 
22365dfd7cc4Ssstefan1   // FIXME: come up with better string.
22375dfd7cc4Ssstefan1   const std::string getAsStr() const override {
22385dfd7cc4Ssstefan1     return "ICVTrackerCallSiteReturned";
22395dfd7cc4Ssstefan1   }
22405dfd7cc4Ssstefan1 
22415dfd7cc4Ssstefan1   // FIXME: come up with some stats.
22425dfd7cc4Ssstefan1   void trackStatistics() const override {}
22435dfd7cc4Ssstefan1 
22445dfd7cc4Ssstefan1   /// We don't manifest anything for this AA.
22455dfd7cc4Ssstefan1   ChangeStatus manifest(Attributor &A) override {
22465dfd7cc4Ssstefan1     return ChangeStatus::UNCHANGED;
22475dfd7cc4Ssstefan1   }
22485dfd7cc4Ssstefan1 
22495dfd7cc4Ssstefan1   // Map of ICV to their values at specific program point.
22505dfd7cc4Ssstefan1   EnumeratedArray<Optional<Value *>, InternalControlVar,
22515dfd7cc4Ssstefan1                   InternalControlVar::ICV___last>
22525dfd7cc4Ssstefan1       ICVReplacementValuesMap;
22535dfd7cc4Ssstefan1 
22545dfd7cc4Ssstefan1   /// Return the value with which associated value can be replaced for specific
22555dfd7cc4Ssstefan1   /// \p ICV.
22565dfd7cc4Ssstefan1   Optional<Value *>
22575dfd7cc4Ssstefan1   getUniqueReplacementValue(InternalControlVar ICV) const override {
22585dfd7cc4Ssstefan1     return ICVReplacementValuesMap[ICV];
22595dfd7cc4Ssstefan1   }
22605dfd7cc4Ssstefan1 
22615dfd7cc4Ssstefan1   ChangeStatus updateImpl(Attributor &A) override {
22625dfd7cc4Ssstefan1     ChangeStatus Changed = ChangeStatus::UNCHANGED;
22635dfd7cc4Ssstefan1     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
22645b70c12fSJohannes Doerfert         *this, IRPosition::returned(*getAssociatedFunction()),
22655b70c12fSJohannes Doerfert         DepClassTy::REQUIRED);
22665dfd7cc4Ssstefan1 
22675dfd7cc4Ssstefan1     // We don't have any information, so we assume it changes the ICV.
22685dfd7cc4Ssstefan1     if (!ICVTrackingAA.isAssumedTracked())
22695dfd7cc4Ssstefan1       return indicatePessimisticFixpoint();
22705dfd7cc4Ssstefan1 
22715dfd7cc4Ssstefan1     for (InternalControlVar ICV : TrackableICVs) {
22725dfd7cc4Ssstefan1       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
22735dfd7cc4Ssstefan1       Optional<Value *> NewReplVal =
22745dfd7cc4Ssstefan1           ICVTrackingAA.getUniqueReplacementValue(ICV);
22755dfd7cc4Ssstefan1 
22765dfd7cc4Ssstefan1       if (ReplVal == NewReplVal)
22775dfd7cc4Ssstefan1         continue;
22785dfd7cc4Ssstefan1 
22795dfd7cc4Ssstefan1       ReplVal = NewReplVal;
22805dfd7cc4Ssstefan1       Changed = ChangeStatus::CHANGED;
22815dfd7cc4Ssstefan1     }
22825dfd7cc4Ssstefan1     return Changed;
22835dfd7cc4Ssstefan1   }
22849548b74aSJohannes Doerfert };
228518283125SJoseph Huber 
228618283125SJoseph Huber struct AAExecutionDomainFunction : public AAExecutionDomain {
228718283125SJoseph Huber   AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
228818283125SJoseph Huber       : AAExecutionDomain(IRP, A) {}
228918283125SJoseph Huber 
229018283125SJoseph Huber   const std::string getAsStr() const override {
229118283125SJoseph Huber     return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
229218283125SJoseph Huber            "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
229318283125SJoseph Huber   }
229418283125SJoseph Huber 
229518283125SJoseph Huber   /// See AbstractAttribute::trackStatistics().
229618283125SJoseph Huber   void trackStatistics() const override {}
229718283125SJoseph Huber 
229818283125SJoseph Huber   void initialize(Attributor &A) override {
229918283125SJoseph Huber     Function *F = getAnchorScope();
230018283125SJoseph Huber     for (const auto &BB : *F)
230118283125SJoseph Huber       SingleThreadedBBs.insert(&BB);
230218283125SJoseph Huber     NumBBs = SingleThreadedBBs.size();
230318283125SJoseph Huber   }
230418283125SJoseph Huber 
230518283125SJoseph Huber   ChangeStatus manifest(Attributor &A) override {
230618283125SJoseph Huber     LLVM_DEBUG({
230718283125SJoseph Huber       for (const BasicBlock *BB : SingleThreadedBBs)
230818283125SJoseph Huber         dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
230918283125SJoseph Huber                << BB->getName() << " is executed by a single thread.\n";
231018283125SJoseph Huber     });
231118283125SJoseph Huber     return ChangeStatus::UNCHANGED;
231218283125SJoseph Huber   }
231318283125SJoseph Huber 
231418283125SJoseph Huber   ChangeStatus updateImpl(Attributor &A) override;
231518283125SJoseph Huber 
231618283125SJoseph Huber   /// Check if an instruction is executed by a single thread.
231718283125SJoseph Huber   bool isSingleThreadExecution(const Instruction &I) const override {
231818283125SJoseph Huber     return isSingleThreadExecution(*I.getParent());
231918283125SJoseph Huber   }
232018283125SJoseph Huber 
232118283125SJoseph Huber   bool isSingleThreadExecution(const BasicBlock &BB) const override {
232218283125SJoseph Huber     return SingleThreadedBBs.contains(&BB);
232318283125SJoseph Huber   }
232418283125SJoseph Huber 
232518283125SJoseph Huber   /// Set of basic blocks that are executed by a single thread.
232618283125SJoseph Huber   DenseSet<const BasicBlock *> SingleThreadedBBs;
232718283125SJoseph Huber 
232818283125SJoseph Huber   /// Total number of basic blocks in this function.
232918283125SJoseph Huber   long unsigned NumBBs;
233018283125SJoseph Huber };
233118283125SJoseph Huber 
233218283125SJoseph Huber ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
233318283125SJoseph Huber   Function *F = getAnchorScope();
233418283125SJoseph Huber   ReversePostOrderTraversal<Function *> RPOT(F);
233518283125SJoseph Huber   auto NumSingleThreadedBBs = SingleThreadedBBs.size();
233618283125SJoseph Huber 
233718283125SJoseph Huber   bool AllCallSitesKnown;
233818283125SJoseph Huber   auto PredForCallSite = [&](AbstractCallSite ACS) {
233918283125SJoseph Huber     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
234018283125SJoseph Huber         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
234118283125SJoseph Huber         DepClassTy::REQUIRED);
234218283125SJoseph Huber     return ExecutionDomainAA.isSingleThreadExecution(*ACS.getInstruction());
234318283125SJoseph Huber   };
234418283125SJoseph Huber 
234518283125SJoseph Huber   if (!A.checkForAllCallSites(PredForCallSite, *this,
234618283125SJoseph Huber                               /* RequiresAllCallSites */ true,
234718283125SJoseph Huber                               AllCallSitesKnown))
234818283125SJoseph Huber     SingleThreadedBBs.erase(&F->getEntryBlock());
234918283125SJoseph Huber 
235018283125SJoseph Huber   // Check if the edge into the successor block compares a thread-id function to
235118283125SJoseph Huber   // a constant zero.
235218283125SJoseph Huber   // TODO: Use AAValueSimplify to simplify and propogate constants.
235318283125SJoseph Huber   // TODO: Check more than a single use for thread ID's.
235418283125SJoseph Huber   auto IsSingleThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
235518283125SJoseph Huber     if (!Edge || !Edge->isConditional())
235618283125SJoseph Huber       return false;
235718283125SJoseph Huber     if (Edge->getSuccessor(0) != SuccessorBB)
235818283125SJoseph Huber       return false;
235918283125SJoseph Huber 
236018283125SJoseph Huber     auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
236118283125SJoseph Huber     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
236218283125SJoseph Huber       return false;
236318283125SJoseph Huber 
236418283125SJoseph Huber     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
236518283125SJoseph Huber     if (!C || !C->isZero())
236618283125SJoseph Huber       return false;
236718283125SJoseph Huber 
236868abc3d2SJoseph Huber     if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
236968abc3d2SJoseph Huber       if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
237018283125SJoseph Huber         return true;
237168abc3d2SJoseph Huber     if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
237268abc3d2SJoseph Huber       if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
237368abc3d2SJoseph Huber         return true;
237418283125SJoseph Huber 
237518283125SJoseph Huber     return false;
237618283125SJoseph Huber   };
237718283125SJoseph Huber 
237818283125SJoseph Huber   // Merge all the predecessor states into the current basic block. A basic
237918283125SJoseph Huber   // block is executed by a single thread if all of its predecessors are.
238018283125SJoseph Huber   auto MergePredecessorStates = [&](BasicBlock *BB) {
238118283125SJoseph Huber     if (pred_begin(BB) == pred_end(BB))
238218283125SJoseph Huber       return SingleThreadedBBs.contains(BB);
238318283125SJoseph Huber 
238418283125SJoseph Huber     bool IsSingleThreaded = true;
238518283125SJoseph Huber     for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
238618283125SJoseph Huber          PredBB != PredEndBB; ++PredBB) {
238718283125SJoseph Huber       if (!IsSingleThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
238818283125SJoseph Huber                               BB))
238918283125SJoseph Huber         IsSingleThreaded &= SingleThreadedBBs.contains(*PredBB);
239018283125SJoseph Huber     }
239118283125SJoseph Huber 
239218283125SJoseph Huber     return IsSingleThreaded;
239318283125SJoseph Huber   };
239418283125SJoseph Huber 
239518283125SJoseph Huber   for (auto *BB : RPOT) {
239618283125SJoseph Huber     if (!MergePredecessorStates(BB))
239718283125SJoseph Huber       SingleThreadedBBs.erase(BB);
239818283125SJoseph Huber   }
239918283125SJoseph Huber 
240018283125SJoseph Huber   return (NumSingleThreadedBBs == SingleThreadedBBs.size())
240118283125SJoseph Huber              ? ChangeStatus::UNCHANGED
240218283125SJoseph Huber              : ChangeStatus::CHANGED;
240318283125SJoseph Huber }
240418283125SJoseph Huber 
24059548b74aSJohannes Doerfert } // namespace
24069548b74aSJohannes Doerfert 
// Out-of-line definitions of the static class IDs. AAICVTracker identifies
// its attribute kind by the address of ID (see its getIdAddr() and classof()).
const char AAICVTracker::ID = 0;
const char AAExecutionDomain::ID = 0;
2409b8235d2bSsstefan1 
2410b8235d2bSsstefan1 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
2411b8235d2bSsstefan1                                               Attributor &A) {
2412b8235d2bSsstefan1   AAICVTracker *AA = nullptr;
2413b8235d2bSsstefan1   switch (IRP.getPositionKind()) {
2414b8235d2bSsstefan1   case IRPosition::IRP_INVALID:
2415b8235d2bSsstefan1   case IRPosition::IRP_FLOAT:
2416b8235d2bSsstefan1   case IRPosition::IRP_ARGUMENT:
2417b8235d2bSsstefan1   case IRPosition::IRP_CALL_SITE_ARGUMENT:
24181de70a72SJohannes Doerfert     llvm_unreachable("ICVTracker can only be created for function position!");
24195dfd7cc4Ssstefan1   case IRPosition::IRP_RETURNED:
24205dfd7cc4Ssstefan1     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
24215dfd7cc4Ssstefan1     break;
24225dfd7cc4Ssstefan1   case IRPosition::IRP_CALL_SITE_RETURNED:
24235dfd7cc4Ssstefan1     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
24245dfd7cc4Ssstefan1     break;
24255dfd7cc4Ssstefan1   case IRPosition::IRP_CALL_SITE:
24265dfd7cc4Ssstefan1     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
24275dfd7cc4Ssstefan1     break;
2428b8235d2bSsstefan1   case IRPosition::IRP_FUNCTION:
2429b8235d2bSsstefan1     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
2430b8235d2bSsstefan1     break;
2431b8235d2bSsstefan1   }
2432b8235d2bSsstefan1 
2433b8235d2bSsstefan1   return *AA;
2434b8235d2bSsstefan1 }
2435b8235d2bSsstefan1 
243618283125SJoseph Huber AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
243718283125SJoseph Huber                                                         Attributor &A) {
243818283125SJoseph Huber   AAExecutionDomainFunction *AA = nullptr;
243918283125SJoseph Huber   switch (IRP.getPositionKind()) {
244018283125SJoseph Huber   case IRPosition::IRP_INVALID:
244118283125SJoseph Huber   case IRPosition::IRP_FLOAT:
244218283125SJoseph Huber   case IRPosition::IRP_ARGUMENT:
244318283125SJoseph Huber   case IRPosition::IRP_CALL_SITE_ARGUMENT:
244418283125SJoseph Huber   case IRPosition::IRP_RETURNED:
244518283125SJoseph Huber   case IRPosition::IRP_CALL_SITE_RETURNED:
244618283125SJoseph Huber   case IRPosition::IRP_CALL_SITE:
244718283125SJoseph Huber     llvm_unreachable(
244818283125SJoseph Huber         "AAExecutionDomain can only be created for function position!");
244918283125SJoseph Huber   case IRPosition::IRP_FUNCTION:
245018283125SJoseph Huber     AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
245118283125SJoseph Huber     break;
245218283125SJoseph Huber   }
245318283125SJoseph Huber 
245418283125SJoseph Huber   return *AA;
245518283125SJoseph Huber }
245618283125SJoseph Huber 
2457b2ad63d3SJoseph Huber PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
2458b2ad63d3SJoseph Huber   if (!containsOpenMP(M, OMPInModule))
2459b2ad63d3SJoseph Huber     return PreservedAnalyses::all();
2460b2ad63d3SJoseph Huber 
2461b2ad63d3SJoseph Huber   if (DisableOpenMPOptimizations)
2462b2ad63d3SJoseph Huber     return PreservedAnalyses::all();
2463b2ad63d3SJoseph Huber 
2464b2ad63d3SJoseph Huber   // Look at every function definition in the Module.
2465b2ad63d3SJoseph Huber   SmallVector<Function *, 16> SCC;
2466b2ad63d3SJoseph Huber   for (Function &Fn : M)
2467b2ad63d3SJoseph Huber     if (!Fn.isDeclaration())
2468b2ad63d3SJoseph Huber       SCC.push_back(&Fn);
2469b2ad63d3SJoseph Huber 
2470b2ad63d3SJoseph Huber   if (SCC.empty())
2471b2ad63d3SJoseph Huber     return PreservedAnalyses::all();
2472b2ad63d3SJoseph Huber 
2473b2ad63d3SJoseph Huber   FunctionAnalysisManager &FAM =
2474b2ad63d3SJoseph Huber       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
2475b2ad63d3SJoseph Huber 
2476b2ad63d3SJoseph Huber   AnalysisGetter AG(FAM);
2477b2ad63d3SJoseph Huber 
2478b2ad63d3SJoseph Huber   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
2479b2ad63d3SJoseph Huber     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
2480b2ad63d3SJoseph Huber   };
2481b2ad63d3SJoseph Huber 
2482b2ad63d3SJoseph Huber   BumpPtrAllocator Allocator;
2483b2ad63d3SJoseph Huber   CallGraphUpdater CGUpdater;
2484b2ad63d3SJoseph Huber 
2485b2ad63d3SJoseph Huber   SetVector<Function *> Functions(SCC.begin(), SCC.end());
2486b2ad63d3SJoseph Huber   OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions,
2487b2ad63d3SJoseph Huber                                 OMPInModule.getKernels());
2488b2ad63d3SJoseph Huber 
2489b2ad63d3SJoseph Huber   Attributor A(Functions, InfoCache, CGUpdater);
2490b2ad63d3SJoseph Huber 
2491b2ad63d3SJoseph Huber   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
2492b2ad63d3SJoseph Huber   bool Changed = OMPOpt.run(true);
2493b2ad63d3SJoseph Huber   if (Changed)
2494b2ad63d3SJoseph Huber     return PreservedAnalyses::none();
2495b2ad63d3SJoseph Huber 
2496b2ad63d3SJoseph Huber   return PreservedAnalyses::all();
2497b2ad63d3SJoseph Huber }
2498b2ad63d3SJoseph Huber 
2499b2ad63d3SJoseph Huber PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
25009548b74aSJohannes Doerfert                                           CGSCCAnalysisManager &AM,
2501b2ad63d3SJoseph Huber                                           LazyCallGraph &CG,
2502b2ad63d3SJoseph Huber                                           CGSCCUpdateResult &UR) {
25039548b74aSJohannes Doerfert   if (!containsOpenMP(*C.begin()->getFunction().getParent(), OMPInModule))
25049548b74aSJohannes Doerfert     return PreservedAnalyses::all();
25059548b74aSJohannes Doerfert 
25069548b74aSJohannes Doerfert   if (DisableOpenMPOptimizations)
25079548b74aSJohannes Doerfert     return PreservedAnalyses::all();
25089548b74aSJohannes Doerfert 
2509ee17263aSJohannes Doerfert   SmallVector<Function *, 16> SCC;
2510351d234dSRoman Lebedev   // If there are kernels in the module, we have to run on all SCC's.
2511351d234dSRoman Lebedev   bool SCCIsInteresting = !OMPInModule.getKernels().empty();
2512351d234dSRoman Lebedev   for (LazyCallGraph::Node &N : C) {
2513351d234dSRoman Lebedev     Function *Fn = &N.getFunction();
2514351d234dSRoman Lebedev     SCC.push_back(Fn);
25159548b74aSJohannes Doerfert 
2516351d234dSRoman Lebedev     // Do we already know that the SCC contains kernels,
2517351d234dSRoman Lebedev     // or that OpenMP functions are called from this SCC?
2518351d234dSRoman Lebedev     if (SCCIsInteresting)
2519351d234dSRoman Lebedev       continue;
2520351d234dSRoman Lebedev     // If not, let's check that.
2521351d234dSRoman Lebedev     SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn);
2522351d234dSRoman Lebedev   }
2523351d234dSRoman Lebedev 
2524351d234dSRoman Lebedev   if (!SCCIsInteresting || SCC.empty())
25259548b74aSJohannes Doerfert     return PreservedAnalyses::all();
25269548b74aSJohannes Doerfert 
25274d4ea9acSHuber, Joseph   FunctionAnalysisManager &FAM =
25284d4ea9acSHuber, Joseph       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
25297cfd267cSsstefan1 
25307cfd267cSsstefan1   AnalysisGetter AG(FAM);
25317cfd267cSsstefan1 
25327cfd267cSsstefan1   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
25334d4ea9acSHuber, Joseph     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
25344d4ea9acSHuber, Joseph   };
25354d4ea9acSHuber, Joseph 
2536b2ad63d3SJoseph Huber   BumpPtrAllocator Allocator;
25379548b74aSJohannes Doerfert   CallGraphUpdater CGUpdater;
25389548b74aSJohannes Doerfert   CGUpdater.initialize(CG, C, AM, UR);
25397cfd267cSsstefan1 
25407cfd267cSsstefan1   SetVector<Function *> Functions(SCC.begin(), SCC.end());
25417cfd267cSsstefan1   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
2542624d34afSJohannes Doerfert                                 /*CGSCC*/ Functions, OMPInModule.getKernels());
25437cfd267cSsstefan1 
25448b57ed09SJoseph Huber   Attributor A(Functions, InfoCache, CGUpdater, nullptr, false);
2545b8235d2bSsstefan1 
2546b8235d2bSsstefan1   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
2547b2ad63d3SJoseph Huber   bool Changed = OMPOpt.run(false);
2548694ded37SGiorgis Georgakoudis   if (Changed)
2549694ded37SGiorgis Georgakoudis     return PreservedAnalyses::none();
2550694ded37SGiorgis Georgakoudis 
25519548b74aSJohannes Doerfert   return PreservedAnalyses::all();
25529548b74aSJohannes Doerfert }
25538b57ed09SJoseph Huber 
25549548b74aSJohannes Doerfert namespace {
25559548b74aSJohannes Doerfert 
2556b2ad63d3SJoseph Huber struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
25579548b74aSJohannes Doerfert   CallGraphUpdater CGUpdater;
25589548b74aSJohannes Doerfert   OpenMPInModule OMPInModule;
25599548b74aSJohannes Doerfert   static char ID;
25609548b74aSJohannes Doerfert 
2561b2ad63d3SJoseph Huber   OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
2562b2ad63d3SJoseph Huber     initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
25639548b74aSJohannes Doerfert   }
25649548b74aSJohannes Doerfert 
25659548b74aSJohannes Doerfert   void getAnalysisUsage(AnalysisUsage &AU) const override {
25669548b74aSJohannes Doerfert     CallGraphSCCPass::getAnalysisUsage(AU);
25679548b74aSJohannes Doerfert   }
25689548b74aSJohannes Doerfert 
25699548b74aSJohannes Doerfert   bool doInitialization(CallGraph &CG) override {
25709548b74aSJohannes Doerfert     // Disable the pass if there is no OpenMP (runtime call) in the module.
25719548b74aSJohannes Doerfert     containsOpenMP(CG.getModule(), OMPInModule);
25729548b74aSJohannes Doerfert     return false;
25739548b74aSJohannes Doerfert   }
25749548b74aSJohannes Doerfert 
25759548b74aSJohannes Doerfert   bool runOnSCC(CallGraphSCC &CGSCC) override {
25769548b74aSJohannes Doerfert     if (!containsOpenMP(CGSCC.getCallGraph().getModule(), OMPInModule))
25779548b74aSJohannes Doerfert       return false;
25789548b74aSJohannes Doerfert     if (DisableOpenMPOptimizations || skipSCC(CGSCC))
25799548b74aSJohannes Doerfert       return false;
25809548b74aSJohannes Doerfert 
2581ee17263aSJohannes Doerfert     SmallVector<Function *, 16> SCC;
2582351d234dSRoman Lebedev     // If there are kernels in the module, we have to run on all SCC's.
2583351d234dSRoman Lebedev     bool SCCIsInteresting = !OMPInModule.getKernels().empty();
2584351d234dSRoman Lebedev     for (CallGraphNode *CGN : CGSCC) {
2585351d234dSRoman Lebedev       Function *Fn = CGN->getFunction();
2586351d234dSRoman Lebedev       if (!Fn || Fn->isDeclaration())
2587351d234dSRoman Lebedev         continue;
2588ee17263aSJohannes Doerfert       SCC.push_back(Fn);
25899548b74aSJohannes Doerfert 
2590351d234dSRoman Lebedev       // Do we already know that the SCC contains kernels,
2591351d234dSRoman Lebedev       // or that OpenMP functions are called from this SCC?
2592351d234dSRoman Lebedev       if (SCCIsInteresting)
2593351d234dSRoman Lebedev         continue;
2594351d234dSRoman Lebedev       // If not, let's check that.
2595351d234dSRoman Lebedev       SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn);
2596351d234dSRoman Lebedev     }
2597351d234dSRoman Lebedev 
2598351d234dSRoman Lebedev     if (!SCCIsInteresting || SCC.empty())
25999548b74aSJohannes Doerfert       return false;
26009548b74aSJohannes Doerfert 
26019548b74aSJohannes Doerfert     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
26029548b74aSJohannes Doerfert     CGUpdater.initialize(CG, CGSCC);
26039548b74aSJohannes Doerfert 
26044d4ea9acSHuber, Joseph     // Maintain a map of functions to avoid rebuilding the ORE
26054d4ea9acSHuber, Joseph     DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
26064d4ea9acSHuber, Joseph     auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
26074d4ea9acSHuber, Joseph       std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
26084d4ea9acSHuber, Joseph       if (!ORE)
26094d4ea9acSHuber, Joseph         ORE = std::make_unique<OptimizationRemarkEmitter>(F);
26104d4ea9acSHuber, Joseph       return *ORE;
26114d4ea9acSHuber, Joseph     };
26124d4ea9acSHuber, Joseph 
26137cfd267cSsstefan1     AnalysisGetter AG;
26147cfd267cSsstefan1     SetVector<Function *> Functions(SCC.begin(), SCC.end());
26157cfd267cSsstefan1     BumpPtrAllocator Allocator;
2616e8039ad4SJohannes Doerfert     OMPInformationCache InfoCache(
2617e8039ad4SJohannes Doerfert         *(Functions.back()->getParent()), AG, Allocator,
2618624d34afSJohannes Doerfert         /*CGSCC*/ Functions, OMPInModule.getKernels());
26197cfd267cSsstefan1 
26208b57ed09SJoseph Huber     Attributor A(Functions, InfoCache, CGUpdater, nullptr, false);
2621b8235d2bSsstefan1 
2622b8235d2bSsstefan1     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
2623b2ad63d3SJoseph Huber     return OMPOpt.run(false);
26249548b74aSJohannes Doerfert   }
26259548b74aSJohannes Doerfert 
26269548b74aSJohannes Doerfert   bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
26279548b74aSJohannes Doerfert };
26289548b74aSJohannes Doerfert 
26299548b74aSJohannes Doerfert } // end anonymous namespace
26309548b74aSJohannes Doerfert 
2631e8039ad4SJohannes Doerfert void OpenMPInModule::identifyKernels(Module &M) {
2632e8039ad4SJohannes Doerfert 
2633e8039ad4SJohannes Doerfert   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
2634e8039ad4SJohannes Doerfert   if (!MD)
2635e8039ad4SJohannes Doerfert     return;
2636e8039ad4SJohannes Doerfert 
2637e8039ad4SJohannes Doerfert   for (auto *Op : MD->operands()) {
2638e8039ad4SJohannes Doerfert     if (Op->getNumOperands() < 2)
2639e8039ad4SJohannes Doerfert       continue;
2640e8039ad4SJohannes Doerfert     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
2641e8039ad4SJohannes Doerfert     if (!KindID || KindID->getString() != "kernel")
2642e8039ad4SJohannes Doerfert       continue;
2643e8039ad4SJohannes Doerfert 
2644e8039ad4SJohannes Doerfert     Function *KernelFn =
2645e8039ad4SJohannes Doerfert         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
2646e8039ad4SJohannes Doerfert     if (!KernelFn)
2647e8039ad4SJohannes Doerfert       continue;
2648e8039ad4SJohannes Doerfert 
2649e8039ad4SJohannes Doerfert     ++NumOpenMPTargetRegionKernels;
2650e8039ad4SJohannes Doerfert 
2651e8039ad4SJohannes Doerfert     Kernels.insert(KernelFn);
2652e8039ad4SJohannes Doerfert   }
2653e8039ad4SJohannes Doerfert }
2654e8039ad4SJohannes Doerfert 
26559548b74aSJohannes Doerfert bool llvm::omp::containsOpenMP(Module &M, OpenMPInModule &OMPInModule) {
26569548b74aSJohannes Doerfert   if (OMPInModule.isKnown())
26579548b74aSJohannes Doerfert     return OMPInModule;
2658dce6bc18SJohannes Doerfert 
2659351d234dSRoman Lebedev   auto RecordFunctionsContainingUsesOf = [&](Function *F) {
2660351d234dSRoman Lebedev     for (User *U : F->users())
2661351d234dSRoman Lebedev       if (auto *I = dyn_cast<Instruction>(U))
2662351d234dSRoman Lebedev         OMPInModule.FuncsWithOMPRuntimeCalls.insert(I->getFunction());
2663351d234dSRoman Lebedev   };
2664351d234dSRoman Lebedev 
2665dce6bc18SJohannes Doerfert   // MSVC doesn't like long if-else chains for some reason and instead just
2666dce6bc18SJohannes Doerfert   // issues an error. Work around it..
2667dce6bc18SJohannes Doerfert   do {
26689548b74aSJohannes Doerfert #define OMP_RTL(_Enum, _Name, ...)                                             \
2669351d234dSRoman Lebedev   if (Function *F = M.getFunction(_Name)) {                                    \
2670351d234dSRoman Lebedev     RecordFunctionsContainingUsesOf(F);                                        \
2671dce6bc18SJohannes Doerfert     OMPInModule = true;                                                        \
2672dce6bc18SJohannes Doerfert   }
26739548b74aSJohannes Doerfert #include "llvm/Frontend/OpenMP/OMPKinds.def"
2674dce6bc18SJohannes Doerfert   } while (false);
2675e8039ad4SJohannes Doerfert 
2676e8039ad4SJohannes Doerfert   // Identify kernels once. TODO: We should split the OMPInformationCache into a
2677e8039ad4SJohannes Doerfert   // module and an SCC part. The kernel information, among other things, could
2678e8039ad4SJohannes Doerfert   // go into the module part.
2679e8039ad4SJohannes Doerfert   if (OMPInModule.isKnown() && OMPInModule) {
2680e8039ad4SJohannes Doerfert     OMPInModule.identifyKernels(M);
2681e8039ad4SJohannes Doerfert     return true;
2682e8039ad4SJohannes Doerfert   }
2683e8039ad4SJohannes Doerfert 
26849548b74aSJohannes Doerfert   return OMPInModule = false;
26859548b74aSJohannes Doerfert }
26869548b74aSJohannes Doerfert 
// Storage for the legacy pass ID; its address is what CallGraphSCCPass(ID)
// receives in the pass constructor.
char OpenMPOptCGSCCLegacyPass::ID = 0;

// Register the legacy pass under the command-line name "openmp-opt-cgscc",
// declaring CallGraphWrapperPass as a dependency (runOnSCC queries it via
// getAnalysis<CallGraphWrapperPass>()).
INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
                      "OpenMP specific optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
                    "OpenMP specific optimizations", false, false)
26949548b74aSJohannes Doerfert 
2695b2ad63d3SJoseph Huber Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
2696b2ad63d3SJoseph Huber   return new OpenMPOptCGSCCLegacyPass();
2697b2ad63d3SJoseph Huber }
2698