19548b74aSJohannes Doerfert //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===// 29548b74aSJohannes Doerfert // 39548b74aSJohannes Doerfert // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 49548b74aSJohannes Doerfert // See https://llvm.org/LICENSE.txt for license information. 59548b74aSJohannes Doerfert // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 69548b74aSJohannes Doerfert // 79548b74aSJohannes Doerfert //===----------------------------------------------------------------------===// 89548b74aSJohannes Doerfert // 99548b74aSJohannes Doerfert // OpenMP specific optimizations: 109548b74aSJohannes Doerfert // 119548b74aSJohannes Doerfert // - Deduplication of runtime calls, e.g., omp_get_thread_num. 12ca1560daSJoseph Huber // - Replacing globalized device memory with stack memory. 13ca1560daSJoseph Huber // - Replacing globalized device memory with shared memory. 14b910a109SJoseph Huber // - Parallel region merging. 15b910a109SJoseph Huber // - Transforming generic-mode device kernels to SPMD mode. 16b910a109SJoseph Huber // - Specializing the state machine for generic-mode device kernels. 
179548b74aSJohannes Doerfert // 189548b74aSJohannes Doerfert //===----------------------------------------------------------------------===// 199548b74aSJohannes Doerfert 209548b74aSJohannes Doerfert #include "llvm/Transforms/IPO/OpenMPOpt.h" 219548b74aSJohannes Doerfert 229548b74aSJohannes Doerfert #include "llvm/ADT/EnumeratedArray.h" 2318283125SJoseph Huber #include "llvm/ADT/PostOrderIterator.h" 249548b74aSJohannes Doerfert #include "llvm/ADT/Statistic.h" 259548b74aSJohannes Doerfert #include "llvm/Analysis/CallGraph.h" 269548b74aSJohannes Doerfert #include "llvm/Analysis/CallGraphSCCPass.h" 274d4ea9acSHuber, Joseph #include "llvm/Analysis/OptimizationRemarkEmitter.h" 283a6bfcf2SGiorgis Georgakoudis #include "llvm/Analysis/ValueTracking.h" 299548b74aSJohannes Doerfert #include "llvm/Frontend/OpenMP/OMPConstants.h" 30e28936f6SJohannes Doerfert #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" 31d9659bf6SJohannes Doerfert #include "llvm/IR/Assumptions.h" 32d9659bf6SJohannes Doerfert #include "llvm/IR/DiagnosticInfo.h" 33514c033dSJohannes Doerfert #include "llvm/IR/GlobalValue.h" 34d9659bf6SJohannes Doerfert #include "llvm/IR/Instruction.h" 3568abc3d2SJoseph Huber #include "llvm/IR/IntrinsicInst.h" 369548b74aSJohannes Doerfert #include "llvm/InitializePasses.h" 379548b74aSJohannes Doerfert #include "llvm/Support/CommandLine.h" 389548b74aSJohannes Doerfert #include "llvm/Transforms/IPO.h" 397cfd267cSsstefan1 #include "llvm/Transforms/IPO/Attributor.h" 403a6bfcf2SGiorgis Georgakoudis #include "llvm/Transforms/Utils/BasicBlockUtils.h" 419548b74aSJohannes Doerfert #include "llvm/Transforms/Utils/CallGraphUpdater.h" 4297517055SGiorgis Georgakoudis #include "llvm/Transforms/Utils/CodeExtractor.h" 439548b74aSJohannes Doerfert 449548b74aSJohannes Doerfert using namespace llvm; 459548b74aSJohannes Doerfert using namespace omp; 469548b74aSJohannes Doerfert 479548b74aSJohannes Doerfert #define DEBUG_TYPE "openmp-opt" 489548b74aSJohannes Doerfert 499548b74aSJohannes Doerfert 
static cl::opt<bool> DisableOpenMPOptimizations( 509548b74aSJohannes Doerfert "openmp-opt-disable", cl::ZeroOrMore, 519548b74aSJohannes Doerfert cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, 529548b74aSJohannes Doerfert cl::init(false)); 539548b74aSJohannes Doerfert 543a6bfcf2SGiorgis Georgakoudis static cl::opt<bool> EnableParallelRegionMerging( 553a6bfcf2SGiorgis Georgakoudis "openmp-opt-enable-merging", cl::ZeroOrMore, 563a6bfcf2SGiorgis Georgakoudis cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, 573a6bfcf2SGiorgis Georgakoudis cl::init(false)); 583a6bfcf2SGiorgis Georgakoudis 594a668604SJoseph Huber static cl::opt<bool> 604a668604SJoseph Huber DisableInternalization("openmp-opt-disable-internalization", cl::ZeroOrMore, 614a668604SJoseph Huber cl::desc("Disable function internalization."), 624a668604SJoseph Huber cl::Hidden, cl::init(false)); 634a668604SJoseph Huber 640f426935Ssstefan1 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false), 650f426935Ssstefan1 cl::Hidden); 66e8039ad4SJohannes Doerfert static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels", 67e8039ad4SJohannes Doerfert cl::init(false), cl::Hidden); 680f426935Ssstefan1 69496f8e5bSHamilton Tobon Mosquera static cl::opt<bool> HideMemoryTransferLatency( 70496f8e5bSHamilton Tobon Mosquera "openmp-hide-memory-transfer-latency", 71496f8e5bSHamilton Tobon Mosquera cl::desc("[WIP] Tries to hide the latency of host to device memory" 72496f8e5bSHamilton Tobon Mosquera " transfers"), 73496f8e5bSHamilton Tobon Mosquera cl::Hidden, cl::init(false)); 74496f8e5bSHamilton Tobon Mosquera 75cd0dd8ecSJoseph Huber static cl::opt<bool> DisableOpenMPOptDeglobalization( 76cd0dd8ecSJoseph Huber "openmp-opt-disable-deglobalization", cl::ZeroOrMore, 77cd0dd8ecSJoseph Huber cl::desc("Disable OpenMP optimizations involving deglobalization."), 78cd0dd8ecSJoseph Huber cl::Hidden, cl::init(false)); 79cd0dd8ecSJoseph Huber 80cd0dd8ecSJoseph Huber 
static cl::opt<bool> DisableOpenMPOptSPMDization( 81cd0dd8ecSJoseph Huber "openmp-opt-disable-spmdization", cl::ZeroOrMore, 82cd0dd8ecSJoseph Huber cl::desc("Disable OpenMP optimizations involving SPMD-ization."), 83cd0dd8ecSJoseph Huber cl::Hidden, cl::init(false)); 84cd0dd8ecSJoseph Huber 85cd0dd8ecSJoseph Huber static cl::opt<bool> DisableOpenMPOptFolding( 86cd0dd8ecSJoseph Huber "openmp-opt-disable-folding", cl::ZeroOrMore, 87cd0dd8ecSJoseph Huber cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden, 88cd0dd8ecSJoseph Huber cl::init(false)); 89cd0dd8ecSJoseph Huber 90cd0dd8ecSJoseph Huber static cl::opt<bool> DisableOpenMPOptStateMachineRewrite( 91cd0dd8ecSJoseph Huber "openmp-opt-disable-state-machine-rewrite", cl::ZeroOrMore, 92cd0dd8ecSJoseph Huber cl::desc("Disable OpenMP optimizations that replace the state machine."), 93cd0dd8ecSJoseph Huber cl::Hidden, cl::init(false)); 94cd0dd8ecSJoseph Huber 959548b74aSJohannes Doerfert STATISTIC(NumOpenMPRuntimeCallsDeduplicated, 969548b74aSJohannes Doerfert "Number of OpenMP runtime calls deduplicated"); 9755eb714aSRoman Lebedev STATISTIC(NumOpenMPParallelRegionsDeleted, 9855eb714aSRoman Lebedev "Number of OpenMP parallel regions deleted"); 999548b74aSJohannes Doerfert STATISTIC(NumOpenMPRuntimeFunctionsIdentified, 1009548b74aSJohannes Doerfert "Number of OpenMP runtime functions identified"); 1019548b74aSJohannes Doerfert STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified, 1029548b74aSJohannes Doerfert "Number of OpenMP runtime function uses identified"); 103e8039ad4SJohannes Doerfert STATISTIC(NumOpenMPTargetRegionKernels, 104e8039ad4SJohannes Doerfert "Number of OpenMP target region entry points (=kernels) identified"); 105514c033dSJohannes Doerfert STATISTIC(NumOpenMPTargetRegionKernelsSPMD, 106514c033dSJohannes Doerfert "Number of OpenMP target region entry points (=kernels) executed in " 107514c033dSJohannes Doerfert "SPMD-mode instead of generic-mode"); 108d9659bf6SJohannes Doerfert 
STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine, 109d9659bf6SJohannes Doerfert "Number of OpenMP target region entry points (=kernels) executed in " 110d9659bf6SJohannes Doerfert "generic-mode without a state machines"); 111d9659bf6SJohannes Doerfert STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback, 112d9659bf6SJohannes Doerfert "Number of OpenMP target region entry points (=kernels) executed in " 113d9659bf6SJohannes Doerfert "generic-mode with customized state machines with fallback"); 114d9659bf6SJohannes Doerfert STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback, 115d9659bf6SJohannes Doerfert "Number of OpenMP target region entry points (=kernels) executed in " 116d9659bf6SJohannes Doerfert "generic-mode with customized state machines without fallback"); 1175b0581aeSJohannes Doerfert STATISTIC( 1185b0581aeSJohannes Doerfert NumOpenMPParallelRegionsReplacedInGPUStateMachine, 1195b0581aeSJohannes Doerfert "Number of OpenMP parallel regions replaced with ID in GPU state machines"); 1203a6bfcf2SGiorgis Georgakoudis STATISTIC(NumOpenMPParallelRegionsMerged, 1213a6bfcf2SGiorgis Georgakoudis "Number of OpenMP parallel regions merged"); 1226fc51c9fSJoseph Huber STATISTIC(NumBytesMovedToSharedMemory, 1236fc51c9fSJoseph Huber "Amount of memory pushed to shared memory"); 1249548b74aSJohannes Doerfert 125263c4a3cSrathod-sahaab #if !defined(NDEBUG) 1269548b74aSJohannes Doerfert static constexpr auto TAG = "[" DEBUG_TYPE "]"; 127a50c0b0dSMikael Holmen #endif 1289548b74aSJohannes Doerfert 1299548b74aSJohannes Doerfert namespace { 1309548b74aSJohannes Doerfert 1316fc51c9fSJoseph Huber enum class AddressSpace : unsigned { 1326fc51c9fSJoseph Huber Generic = 0, 1336fc51c9fSJoseph Huber Global = 1, 1346fc51c9fSJoseph Huber Shared = 3, 1356fc51c9fSJoseph Huber Constant = 4, 1366fc51c9fSJoseph Huber Local = 5, 1376fc51c9fSJoseph Huber }; 1386fc51c9fSJoseph Huber 1396fc51c9fSJoseph Huber struct AAHeapToShared; 1406fc51c9fSJoseph 
Huber 141b8235d2bSsstefan1 struct AAICVTracker; 142b8235d2bSsstefan1 1437cfd267cSsstefan1 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for 1447cfd267cSsstefan1 /// Attributor runs. 1457cfd267cSsstefan1 struct OMPInformationCache : public InformationCache { 1467cfd267cSsstefan1 OMPInformationCache(Module &M, AnalysisGetter &AG, 147624d34afSJohannes Doerfert BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC, 148e8039ad4SJohannes Doerfert SmallPtrSetImpl<Kernel> &Kernels) 149624d34afSJohannes Doerfert : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M), 150624d34afSJohannes Doerfert Kernels(Kernels) { 151624d34afSJohannes Doerfert 15261238d26Ssstefan1 OMPBuilder.initialize(); 1539548b74aSJohannes Doerfert initializeRuntimeFunctions(); 1540f426935Ssstefan1 initializeInternalControlVars(); 1559548b74aSJohannes Doerfert } 1569548b74aSJohannes Doerfert 1570f426935Ssstefan1 /// Generic information that describes an internal control variable. 1580f426935Ssstefan1 struct InternalControlVarInfo { 1590f426935Ssstefan1 /// The kind, as described by InternalControlVar enum. 1600f426935Ssstefan1 InternalControlVar Kind; 1610f426935Ssstefan1 1620f426935Ssstefan1 /// The name of the ICV. 1630f426935Ssstefan1 StringRef Name; 1640f426935Ssstefan1 1650f426935Ssstefan1 /// Environment variable associated with this ICV. 1660f426935Ssstefan1 StringRef EnvVarName; 1670f426935Ssstefan1 1680f426935Ssstefan1 /// Initial value kind. 1690f426935Ssstefan1 ICVInitValue InitKind; 1700f426935Ssstefan1 1710f426935Ssstefan1 /// Initial value. 1720f426935Ssstefan1 ConstantInt *InitValue; 1730f426935Ssstefan1 1740f426935Ssstefan1 /// Setter RTL function associated with this ICV. 1750f426935Ssstefan1 RuntimeFunction Setter; 1760f426935Ssstefan1 1770f426935Ssstefan1 /// Getter RTL function associated with this ICV. 
1780f426935Ssstefan1 RuntimeFunction Getter; 1790f426935Ssstefan1 1800f426935Ssstefan1 /// RTL Function corresponding to the override clause of this ICV 1810f426935Ssstefan1 RuntimeFunction Clause; 1820f426935Ssstefan1 }; 1830f426935Ssstefan1 1849548b74aSJohannes Doerfert /// Generic information that describes a runtime function 1859548b74aSJohannes Doerfert struct RuntimeFunctionInfo { 1868855fec3SJohannes Doerfert 1879548b74aSJohannes Doerfert /// The kind, as described by the RuntimeFunction enum. 1889548b74aSJohannes Doerfert RuntimeFunction Kind; 1899548b74aSJohannes Doerfert 1909548b74aSJohannes Doerfert /// The name of the function. 1919548b74aSJohannes Doerfert StringRef Name; 1929548b74aSJohannes Doerfert 1939548b74aSJohannes Doerfert /// Flag to indicate a variadic function. 1949548b74aSJohannes Doerfert bool IsVarArg; 1959548b74aSJohannes Doerfert 1969548b74aSJohannes Doerfert /// The return type of the function. 1979548b74aSJohannes Doerfert Type *ReturnType; 1989548b74aSJohannes Doerfert 1999548b74aSJohannes Doerfert /// The argument types of the function. 2009548b74aSJohannes Doerfert SmallVector<Type *, 8> ArgumentTypes; 2019548b74aSJohannes Doerfert 2029548b74aSJohannes Doerfert /// The declaration if available. 203f09f4b26SJohannes Doerfert Function *Declaration = nullptr; 2049548b74aSJohannes Doerfert 2059548b74aSJohannes Doerfert /// Uses of this runtime function per function containing the use. 2068855fec3SJohannes Doerfert using UseVector = SmallVector<Use *, 16>; 2078855fec3SJohannes Doerfert 208b8235d2bSsstefan1 /// Clear UsesMap for runtime function. 209b8235d2bSsstefan1 void clearUsesMap() { UsesMap.clear(); } 210b8235d2bSsstefan1 21154bd3751SJohannes Doerfert /// Boolean conversion that is true if the runtime function was found. 21254bd3751SJohannes Doerfert operator bool() const { return Declaration; } 21354bd3751SJohannes Doerfert 2148855fec3SJohannes Doerfert /// Return the vector of uses in function \p F. 
2158855fec3SJohannes Doerfert UseVector &getOrCreateUseVector(Function *F) { 216b8235d2bSsstefan1 std::shared_ptr<UseVector> &UV = UsesMap[F]; 2178855fec3SJohannes Doerfert if (!UV) 218b8235d2bSsstefan1 UV = std::make_shared<UseVector>(); 2198855fec3SJohannes Doerfert return *UV; 2208855fec3SJohannes Doerfert } 2218855fec3SJohannes Doerfert 2228855fec3SJohannes Doerfert /// Return the vector of uses in function \p F or `nullptr` if there are 2238855fec3SJohannes Doerfert /// none. 2248855fec3SJohannes Doerfert const UseVector *getUseVector(Function &F) const { 22595e57072SDavid Blaikie auto I = UsesMap.find(&F); 22695e57072SDavid Blaikie if (I != UsesMap.end()) 22795e57072SDavid Blaikie return I->second.get(); 22895e57072SDavid Blaikie return nullptr; 2298855fec3SJohannes Doerfert } 2308855fec3SJohannes Doerfert 2318855fec3SJohannes Doerfert /// Return how many functions contain uses of this runtime function. 2328855fec3SJohannes Doerfert size_t getNumFunctionsWithUses() const { return UsesMap.size(); } 2339548b74aSJohannes Doerfert 2349548b74aSJohannes Doerfert /// Return the number of arguments (or the minimal number for variadic 2359548b74aSJohannes Doerfert /// functions). 2369548b74aSJohannes Doerfert size_t getNumArgs() const { return ArgumentTypes.size(); } 2379548b74aSJohannes Doerfert 2389548b74aSJohannes Doerfert /// Run the callback \p CB on each use and forget the use if the result is 2399548b74aSJohannes Doerfert /// true. The callback will be fed the function in which the use was 2409548b74aSJohannes Doerfert /// encountered as second argument. 
241624d34afSJohannes Doerfert void foreachUse(SmallVectorImpl<Function *> &SCC, 242624d34afSJohannes Doerfert function_ref<bool(Use &, Function &)> CB) { 243624d34afSJohannes Doerfert for (Function *F : SCC) 244624d34afSJohannes Doerfert foreachUse(CB, F); 245e099c7b6Ssstefan1 } 246e099c7b6Ssstefan1 247e099c7b6Ssstefan1 /// Run the callback \p CB on each use within the function \p F and forget 248e099c7b6Ssstefan1 /// the use if the result is true. 249624d34afSJohannes Doerfert void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) { 2508855fec3SJohannes Doerfert SmallVector<unsigned, 8> ToBeDeleted; 2519548b74aSJohannes Doerfert ToBeDeleted.clear(); 252e099c7b6Ssstefan1 2538855fec3SJohannes Doerfert unsigned Idx = 0; 254624d34afSJohannes Doerfert UseVector &UV = getOrCreateUseVector(F); 255e099c7b6Ssstefan1 2568855fec3SJohannes Doerfert for (Use *U : UV) { 257e099c7b6Ssstefan1 if (CB(*U, *F)) 2588855fec3SJohannes Doerfert ToBeDeleted.push_back(Idx); 2598855fec3SJohannes Doerfert ++Idx; 2608855fec3SJohannes Doerfert } 2618855fec3SJohannes Doerfert 2628855fec3SJohannes Doerfert // Remove the to-be-deleted indices in reverse order as prior 263b726c557SJohannes Doerfert // modifications will not modify the smaller indices. 2648855fec3SJohannes Doerfert while (!ToBeDeleted.empty()) { 2658855fec3SJohannes Doerfert unsigned Idx = ToBeDeleted.pop_back_val(); 2668855fec3SJohannes Doerfert UV[Idx] = UV.back(); 2678855fec3SJohannes Doerfert UV.pop_back(); 2689548b74aSJohannes Doerfert } 2699548b74aSJohannes Doerfert } 2708855fec3SJohannes Doerfert 2718855fec3SJohannes Doerfert private: 2728855fec3SJohannes Doerfert /// Map from functions to all uses of this runtime function contained in 2738855fec3SJohannes Doerfert /// them. 274b8235d2bSsstefan1 DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap; 275d9659bf6SJohannes Doerfert 276d9659bf6SJohannes Doerfert public: 277d9659bf6SJohannes Doerfert /// Iterators for the uses of this runtime function. 
278d9659bf6SJohannes Doerfert decltype(UsesMap)::iterator begin() { return UsesMap.begin(); } 279d9659bf6SJohannes Doerfert decltype(UsesMap)::iterator end() { return UsesMap.end(); } 2809548b74aSJohannes Doerfert }; 2819548b74aSJohannes Doerfert 2827cfd267cSsstefan1 /// An OpenMP-IR-Builder instance 2837cfd267cSsstefan1 OpenMPIRBuilder OMPBuilder; 2847cfd267cSsstefan1 2857cfd267cSsstefan1 /// Map from runtime function kind to the runtime function description. 2867cfd267cSsstefan1 EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction, 2877cfd267cSsstefan1 RuntimeFunction::OMPRTL___last> 2887cfd267cSsstefan1 RFIs; 2897cfd267cSsstefan1 290d9659bf6SJohannes Doerfert /// Map from function declarations/definitions to their runtime enum type. 291d9659bf6SJohannes Doerfert DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap; 292d9659bf6SJohannes Doerfert 2930f426935Ssstefan1 /// Map from ICV kind to the ICV description. 2940f426935Ssstefan1 EnumeratedArray<InternalControlVarInfo, InternalControlVar, 2950f426935Ssstefan1 InternalControlVar::ICV___last> 2960f426935Ssstefan1 ICVs; 2970f426935Ssstefan1 2980f426935Ssstefan1 /// Helper to initialize all internal control variable information for those 2990f426935Ssstefan1 /// defined in OMPKinds.def. 
3000f426935Ssstefan1 void initializeInternalControlVars() { 3010f426935Ssstefan1 #define ICV_RT_SET(_Name, RTL) \ 3020f426935Ssstefan1 { \ 3030f426935Ssstefan1 auto &ICV = ICVs[_Name]; \ 3040f426935Ssstefan1 ICV.Setter = RTL; \ 3050f426935Ssstefan1 } 3060f426935Ssstefan1 #define ICV_RT_GET(Name, RTL) \ 3070f426935Ssstefan1 { \ 3080f426935Ssstefan1 auto &ICV = ICVs[Name]; \ 3090f426935Ssstefan1 ICV.Getter = RTL; \ 3100f426935Ssstefan1 } 3110f426935Ssstefan1 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \ 3120f426935Ssstefan1 { \ 3130f426935Ssstefan1 auto &ICV = ICVs[Enum]; \ 3140f426935Ssstefan1 ICV.Name = _Name; \ 3150f426935Ssstefan1 ICV.Kind = Enum; \ 3160f426935Ssstefan1 ICV.InitKind = Init; \ 3170f426935Ssstefan1 ICV.EnvVarName = _EnvVarName; \ 3180f426935Ssstefan1 switch (ICV.InitKind) { \ 319951e43f3Ssstefan1 case ICV_IMPLEMENTATION_DEFINED: \ 3200f426935Ssstefan1 ICV.InitValue = nullptr; \ 3210f426935Ssstefan1 break; \ 322951e43f3Ssstefan1 case ICV_ZERO: \ 3236aab27baSsstefan1 ICV.InitValue = ConstantInt::get( \ 3246aab27baSsstefan1 Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \ 3250f426935Ssstefan1 break; \ 326951e43f3Ssstefan1 case ICV_FALSE: \ 3276aab27baSsstefan1 ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \ 3280f426935Ssstefan1 break; \ 329951e43f3Ssstefan1 case ICV_LAST: \ 3300f426935Ssstefan1 break; \ 3310f426935Ssstefan1 } \ 3320f426935Ssstefan1 } 3330f426935Ssstefan1 #include "llvm/Frontend/OpenMP/OMPKinds.def" 3340f426935Ssstefan1 } 3350f426935Ssstefan1 3367cfd267cSsstefan1 /// Returns true if the function declaration \p F matches the runtime 3377cfd267cSsstefan1 /// function types, that is, return type \p RTFRetType, and argument types 3387cfd267cSsstefan1 /// \p RTFArgTypes. 
3397cfd267cSsstefan1 static bool declMatchesRTFTypes(Function *F, Type *RTFRetType, 3407cfd267cSsstefan1 SmallVector<Type *, 8> &RTFArgTypes) { 3417cfd267cSsstefan1 // TODO: We should output information to the user (under debug output 3427cfd267cSsstefan1 // and via remarks). 3437cfd267cSsstefan1 3447cfd267cSsstefan1 if (!F) 3457cfd267cSsstefan1 return false; 3467cfd267cSsstefan1 if (F->getReturnType() != RTFRetType) 3477cfd267cSsstefan1 return false; 3487cfd267cSsstefan1 if (F->arg_size() != RTFArgTypes.size()) 3497cfd267cSsstefan1 return false; 3507cfd267cSsstefan1 3517cfd267cSsstefan1 auto RTFTyIt = RTFArgTypes.begin(); 3527cfd267cSsstefan1 for (Argument &Arg : F->args()) { 3537cfd267cSsstefan1 if (Arg.getType() != *RTFTyIt) 3547cfd267cSsstefan1 return false; 3557cfd267cSsstefan1 3567cfd267cSsstefan1 ++RTFTyIt; 3577cfd267cSsstefan1 } 3587cfd267cSsstefan1 3597cfd267cSsstefan1 return true; 3607cfd267cSsstefan1 } 3617cfd267cSsstefan1 362b726c557SJohannes Doerfert // Helper to collect all uses of the declaration in the UsesMap. 363b8235d2bSsstefan1 unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) { 3647cfd267cSsstefan1 unsigned NumUses = 0; 3657cfd267cSsstefan1 if (!RFI.Declaration) 3667cfd267cSsstefan1 return NumUses; 3677cfd267cSsstefan1 OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration); 3687cfd267cSsstefan1 369b8235d2bSsstefan1 if (CollectStats) { 3707cfd267cSsstefan1 NumOpenMPRuntimeFunctionsIdentified += 1; 3717cfd267cSsstefan1 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses(); 372b8235d2bSsstefan1 } 3737cfd267cSsstefan1 3747cfd267cSsstefan1 // TODO: We directly convert uses into proper calls and unknown uses. 
3757cfd267cSsstefan1 for (Use &U : RFI.Declaration->uses()) { 3767cfd267cSsstefan1 if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) { 3777cfd267cSsstefan1 if (ModuleSlice.count(UserI->getFunction())) { 3787cfd267cSsstefan1 RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U); 3797cfd267cSsstefan1 ++NumUses; 3807cfd267cSsstefan1 } 3817cfd267cSsstefan1 } else { 3827cfd267cSsstefan1 RFI.getOrCreateUseVector(nullptr).push_back(&U); 3837cfd267cSsstefan1 ++NumUses; 3847cfd267cSsstefan1 } 3857cfd267cSsstefan1 } 3867cfd267cSsstefan1 return NumUses; 387b8235d2bSsstefan1 } 3887cfd267cSsstefan1 38997517055SGiorgis Georgakoudis // Helper function to recollect uses of a runtime function. 39097517055SGiorgis Georgakoudis void recollectUsesForFunction(RuntimeFunction RTF) { 39197517055SGiorgis Georgakoudis auto &RFI = RFIs[RTF]; 392b8235d2bSsstefan1 RFI.clearUsesMap(); 393b8235d2bSsstefan1 collectUses(RFI, /*CollectStats*/ false); 394b8235d2bSsstefan1 } 39597517055SGiorgis Georgakoudis 39697517055SGiorgis Georgakoudis // Helper function to recollect uses of all runtime functions. 39797517055SGiorgis Georgakoudis void recollectUses() { 39897517055SGiorgis Georgakoudis for (int Idx = 0; Idx < RFIs.size(); ++Idx) 39997517055SGiorgis Georgakoudis recollectUsesForFunction(static_cast<RuntimeFunction>(Idx)); 400b8235d2bSsstefan1 } 401b8235d2bSsstefan1 402b8235d2bSsstefan1 /// Helper to initialize all runtime function information for those defined 403b8235d2bSsstefan1 /// in OpenMPKinds.def. 404b8235d2bSsstefan1 void initializeRuntimeFunctions() { 4057cfd267cSsstefan1 Module &M = *((*ModuleSlice.begin())->getParent()); 4067cfd267cSsstefan1 4076aab27baSsstefan1 // Helper macros for handling __VA_ARGS__ in OMP_RTL 4086aab27baSsstefan1 #define OMP_TYPE(VarName, ...) \ 4096aab27baSsstefan1 Type *VarName = OMPBuilder.VarName; \ 4106aab27baSsstefan1 (void)VarName; 4116aab27baSsstefan1 4126aab27baSsstefan1 #define OMP_ARRAY_TYPE(VarName, ...) 
\ 4136aab27baSsstefan1 ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \ 4146aab27baSsstefan1 (void)VarName##Ty; \ 4156aab27baSsstefan1 PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \ 4166aab27baSsstefan1 (void)VarName##PtrTy; 4176aab27baSsstefan1 4186aab27baSsstefan1 #define OMP_FUNCTION_TYPE(VarName, ...) \ 4196aab27baSsstefan1 FunctionType *VarName = OMPBuilder.VarName; \ 4206aab27baSsstefan1 (void)VarName; \ 4216aab27baSsstefan1 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \ 4226aab27baSsstefan1 (void)VarName##Ptr; 4236aab27baSsstefan1 4246aab27baSsstefan1 #define OMP_STRUCT_TYPE(VarName, ...) \ 4256aab27baSsstefan1 StructType *VarName = OMPBuilder.VarName; \ 4266aab27baSsstefan1 (void)VarName; \ 4276aab27baSsstefan1 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \ 4286aab27baSsstefan1 (void)VarName##Ptr; 4296aab27baSsstefan1 4307cfd267cSsstefan1 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \ 4317cfd267cSsstefan1 { \ 4327cfd267cSsstefan1 SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \ 4337cfd267cSsstefan1 Function *F = M.getFunction(_Name); \ 434eef6601bSJoseph Huber RTLFunctions.insert(F); \ 4356aab27baSsstefan1 if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \ 436d9659bf6SJohannes Doerfert RuntimeFunctionIDMap[F] = _Enum; \ 43716206d17SJoseph Huber F->removeFnAttr(Attribute::NoInline); \ 4387cfd267cSsstefan1 auto &RFI = RFIs[_Enum]; \ 4397cfd267cSsstefan1 RFI.Kind = _Enum; \ 4407cfd267cSsstefan1 RFI.Name = _Name; \ 4417cfd267cSsstefan1 RFI.IsVarArg = _IsVarArg; \ 4426aab27baSsstefan1 RFI.ReturnType = OMPBuilder._ReturnType; \ 4437cfd267cSsstefan1 RFI.ArgumentTypes = std::move(ArgsTypes); \ 4447cfd267cSsstefan1 RFI.Declaration = F; \ 445b8235d2bSsstefan1 unsigned NumUses = collectUses(RFI); \ 4467cfd267cSsstefan1 (void)NumUses; \ 4477cfd267cSsstefan1 LLVM_DEBUG({ \ 4487cfd267cSsstefan1 dbgs() << TAG << RFI.Name << (RFI.Declaration ? 
"" : " not") \ 4497cfd267cSsstefan1 << " found\n"; \ 4507cfd267cSsstefan1 if (RFI.Declaration) \ 4517cfd267cSsstefan1 dbgs() << TAG << "-> got " << NumUses << " uses in " \ 4527cfd267cSsstefan1 << RFI.getNumFunctionsWithUses() \ 4537cfd267cSsstefan1 << " different functions.\n"; \ 4547cfd267cSsstefan1 }); \ 4557cfd267cSsstefan1 } \ 4567cfd267cSsstefan1 } 4577cfd267cSsstefan1 #include "llvm/Frontend/OpenMP/OMPKinds.def" 4587cfd267cSsstefan1 4597cfd267cSsstefan1 // TODO: We should attach the attributes defined in OMPKinds.def. 4607cfd267cSsstefan1 } 461e8039ad4SJohannes Doerfert 462e8039ad4SJohannes Doerfert /// Collection of known kernels (\see Kernel) in the module. 463e8039ad4SJohannes Doerfert SmallPtrSetImpl<Kernel> &Kernels; 464eef6601bSJoseph Huber 465eef6601bSJoseph Huber /// Collection of known OpenMP runtime functions.. 466eef6601bSJoseph Huber DenseSet<const Function *> RTLFunctions; 4677cfd267cSsstefan1 }; 4687cfd267cSsstefan1 469d9659bf6SJohannes Doerfert template <typename Ty, bool InsertInvalidates = true> 4701a7f7790SShilei Tian struct BooleanStateWithSetVector : public BooleanState { 4711a7f7790SShilei Tian bool contains(const Ty &Elem) const { return Set.contains(Elem); } 4721a7f7790SShilei Tian bool insert(const Ty &Elem) { 473d9659bf6SJohannes Doerfert if (InsertInvalidates) 474d9659bf6SJohannes Doerfert BooleanState::indicatePessimisticFixpoint(); 475d9659bf6SJohannes Doerfert return Set.insert(Elem); 476d9659bf6SJohannes Doerfert } 477d9659bf6SJohannes Doerfert 4781a7f7790SShilei Tian const Ty &operator[](int Idx) const { return Set[Idx]; } 4791a7f7790SShilei Tian bool operator==(const BooleanStateWithSetVector &RHS) const { 480d9659bf6SJohannes Doerfert return BooleanState::operator==(RHS) && Set == RHS.Set; 481d9659bf6SJohannes Doerfert } 4821a7f7790SShilei Tian bool operator!=(const BooleanStateWithSetVector &RHS) const { 483d9659bf6SJohannes Doerfert return !(*this == RHS); 484d9659bf6SJohannes Doerfert } 485d9659bf6SJohannes Doerfert 
486d9659bf6SJohannes Doerfert bool empty() const { return Set.empty(); } 487d9659bf6SJohannes Doerfert size_t size() const { return Set.size(); } 488d9659bf6SJohannes Doerfert 489d9659bf6SJohannes Doerfert /// "Clamp" this state with \p RHS. 4901a7f7790SShilei Tian BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) { 491d9659bf6SJohannes Doerfert BooleanState::operator^=(RHS); 492d9659bf6SJohannes Doerfert Set.insert(RHS.Set.begin(), RHS.Set.end()); 493d9659bf6SJohannes Doerfert return *this; 494d9659bf6SJohannes Doerfert } 495d9659bf6SJohannes Doerfert 496d9659bf6SJohannes Doerfert private: 497d9659bf6SJohannes Doerfert /// A set to keep track of elements. 4981a7f7790SShilei Tian SetVector<Ty> Set; 499d9659bf6SJohannes Doerfert 500d9659bf6SJohannes Doerfert public: 501d9659bf6SJohannes Doerfert typename decltype(Set)::iterator begin() { return Set.begin(); } 502d9659bf6SJohannes Doerfert typename decltype(Set)::iterator end() { return Set.end(); } 503d9659bf6SJohannes Doerfert typename decltype(Set)::const_iterator begin() const { return Set.begin(); } 504d9659bf6SJohannes Doerfert typename decltype(Set)::const_iterator end() const { return Set.end(); } 505d9659bf6SJohannes Doerfert }; 506d9659bf6SJohannes Doerfert 5071a7f7790SShilei Tian template <typename Ty, bool InsertInvalidates = true> 5081a7f7790SShilei Tian using BooleanStateWithPtrSetVector = 5091a7f7790SShilei Tian BooleanStateWithSetVector<Ty *, InsertInvalidates>; 5101a7f7790SShilei Tian 511d9659bf6SJohannes Doerfert struct KernelInfoState : AbstractState { 512d9659bf6SJohannes Doerfert /// Flag to track if we reached a fixpoint. 513d9659bf6SJohannes Doerfert bool IsAtFixpoint = false; 514d9659bf6SJohannes Doerfert 515d9659bf6SJohannes Doerfert /// The parallel regions (identified by the outlined parallel functions) that 516d9659bf6SJohannes Doerfert /// can be reached from the associated function. 
517d9659bf6SJohannes Doerfert BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false> 518d9659bf6SJohannes Doerfert ReachedKnownParallelRegions; 519d9659bf6SJohannes Doerfert 520d9659bf6SJohannes Doerfert /// State to track what parallel region we might reach. 521d9659bf6SJohannes Doerfert BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions; 522d9659bf6SJohannes Doerfert 523514c033dSJohannes Doerfert /// State to track if we are in SPMD-mode, assumed or know, and why we decided 524e8439ec8SGiorgis Georgakoudis /// we cannot be. If it is assumed, then RequiresFullRuntime should also be 525e8439ec8SGiorgis Georgakoudis /// false. 52629a3e3ddSGiorgis Georgakoudis BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker; 527514c033dSJohannes Doerfert 528d9659bf6SJohannes Doerfert /// The __kmpc_target_init call in this kernel, if any. If we find more than 529d9659bf6SJohannes Doerfert /// one we abort as the kernel is malformed. 530d9659bf6SJohannes Doerfert CallBase *KernelInitCB = nullptr; 531d9659bf6SJohannes Doerfert 532d9659bf6SJohannes Doerfert /// The __kmpc_target_deinit call in this kernel, if any. If we find more than 533d9659bf6SJohannes Doerfert /// one we abort as the kernel is malformed. 534d9659bf6SJohannes Doerfert CallBase *KernelDeinitCB = nullptr; 535d9659bf6SJohannes Doerfert 536ca662297SShilei Tian /// Flag to indicate if the associated function is a kernel entry. 537ca662297SShilei Tian bool IsKernelEntry = false; 538ca662297SShilei Tian 539ca662297SShilei Tian /// State to track what kernel entries can reach the associated function. 540ca662297SShilei Tian BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries; 541ca662297SShilei Tian 542e97e0a4fSShilei Tian /// State to indicate if we can track parallel level of the associated 543e97e0a4fSShilei Tian /// function. We will give up tracking if we encounter unknown caller or the 544e97e0a4fSShilei Tian /// caller is __kmpc_parallel_51. 
545e97e0a4fSShilei Tian BooleanStateWithSetVector<uint8_t> ParallelLevels; 546e97e0a4fSShilei Tian 547d9659bf6SJohannes Doerfert /// Abstract State interface 548d9659bf6SJohannes Doerfert ///{ 549d9659bf6SJohannes Doerfert 550d9659bf6SJohannes Doerfert KernelInfoState() {} 551d9659bf6SJohannes Doerfert KernelInfoState(bool BestState) { 552d9659bf6SJohannes Doerfert if (!BestState) 553d9659bf6SJohannes Doerfert indicatePessimisticFixpoint(); 554d9659bf6SJohannes Doerfert } 555d9659bf6SJohannes Doerfert 556d9659bf6SJohannes Doerfert /// See AbstractState::isValidState(...) 557d9659bf6SJohannes Doerfert bool isValidState() const override { return true; } 558d9659bf6SJohannes Doerfert 559d9659bf6SJohannes Doerfert /// See AbstractState::isAtFixpoint(...) 560d9659bf6SJohannes Doerfert bool isAtFixpoint() const override { return IsAtFixpoint; } 561d9659bf6SJohannes Doerfert 562d9659bf6SJohannes Doerfert /// See AbstractState::indicatePessimisticFixpoint(...) 563d9659bf6SJohannes Doerfert ChangeStatus indicatePessimisticFixpoint() override { 564d9659bf6SJohannes Doerfert IsAtFixpoint = true; 565514c033dSJohannes Doerfert SPMDCompatibilityTracker.indicatePessimisticFixpoint(); 566d9659bf6SJohannes Doerfert ReachedUnknownParallelRegions.indicatePessimisticFixpoint(); 567d9659bf6SJohannes Doerfert return ChangeStatus::CHANGED; 568d9659bf6SJohannes Doerfert } 569d9659bf6SJohannes Doerfert 570d9659bf6SJohannes Doerfert /// See AbstractState::indicateOptimisticFixpoint(...) 
571d9659bf6SJohannes Doerfert ChangeStatus indicateOptimisticFixpoint() override { 572d9659bf6SJohannes Doerfert IsAtFixpoint = true; 573d9659bf6SJohannes Doerfert return ChangeStatus::UNCHANGED; 574d9659bf6SJohannes Doerfert } 575d9659bf6SJohannes Doerfert 576d9659bf6SJohannes Doerfert /// Return the assumed state 577d9659bf6SJohannes Doerfert KernelInfoState &getAssumed() { return *this; } 578d9659bf6SJohannes Doerfert const KernelInfoState &getAssumed() const { return *this; } 579d9659bf6SJohannes Doerfert 580d9659bf6SJohannes Doerfert bool operator==(const KernelInfoState &RHS) const { 581514c033dSJohannes Doerfert if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker) 582514c033dSJohannes Doerfert return false; 583d9659bf6SJohannes Doerfert if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions) 584d9659bf6SJohannes Doerfert return false; 585d9659bf6SJohannes Doerfert if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions) 586d9659bf6SJohannes Doerfert return false; 587ca662297SShilei Tian if (ReachingKernelEntries != RHS.ReachingKernelEntries) 588ca662297SShilei Tian return false; 589d9659bf6SJohannes Doerfert return true; 590d9659bf6SJohannes Doerfert } 591d9659bf6SJohannes Doerfert 592d9659bf6SJohannes Doerfert /// Return empty set as the best state of potential values. 593d9659bf6SJohannes Doerfert static KernelInfoState getBestState() { return KernelInfoState(true); } 594d9659bf6SJohannes Doerfert 595d9659bf6SJohannes Doerfert static KernelInfoState getBestState(KernelInfoState &KIS) { 596d9659bf6SJohannes Doerfert return getBestState(); 597d9659bf6SJohannes Doerfert } 598d9659bf6SJohannes Doerfert 599d9659bf6SJohannes Doerfert /// Return full set as the worst state of potential values. 600d9659bf6SJohannes Doerfert static KernelInfoState getWorstState() { return KernelInfoState(false); } 601d9659bf6SJohannes Doerfert 602d9659bf6SJohannes Doerfert /// "Clamp" this state with \p KIS. 
603d9659bf6SJohannes Doerfert KernelInfoState operator^=(const KernelInfoState &KIS) { 604d9659bf6SJohannes Doerfert // Do not merge two different _init and _deinit call sites. 605d9659bf6SJohannes Doerfert if (KIS.KernelInitCB) { 606d9659bf6SJohannes Doerfert if (KernelInitCB && KernelInitCB != KIS.KernelInitCB) 607d9659bf6SJohannes Doerfert indicatePessimisticFixpoint(); 608d9659bf6SJohannes Doerfert KernelInitCB = KIS.KernelInitCB; 609d9659bf6SJohannes Doerfert } 610d9659bf6SJohannes Doerfert if (KIS.KernelDeinitCB) { 611d9659bf6SJohannes Doerfert if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB) 612d9659bf6SJohannes Doerfert indicatePessimisticFixpoint(); 613d9659bf6SJohannes Doerfert KernelDeinitCB = KIS.KernelDeinitCB; 614d9659bf6SJohannes Doerfert } 615514c033dSJohannes Doerfert SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker; 616d9659bf6SJohannes Doerfert ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions; 617d9659bf6SJohannes Doerfert ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions; 618d9659bf6SJohannes Doerfert return *this; 619d9659bf6SJohannes Doerfert } 620d9659bf6SJohannes Doerfert 621d9659bf6SJohannes Doerfert KernelInfoState operator&=(const KernelInfoState &KIS) { 622d9659bf6SJohannes Doerfert return (*this ^= KIS); 623d9659bf6SJohannes Doerfert } 624d9659bf6SJohannes Doerfert 625d9659bf6SJohannes Doerfert ///} 626d9659bf6SJohannes Doerfert }; 627d9659bf6SJohannes Doerfert 6288931add6SHamilton Tobon Mosquera /// Used to map the values physically (in the IR) stored in an offload 6298931add6SHamilton Tobon Mosquera /// array, to a vector in memory. 6308931add6SHamilton Tobon Mosquera struct OffloadArray { 6318931add6SHamilton Tobon Mosquera /// Physical array (in the IR). 6328931add6SHamilton Tobon Mosquera AllocaInst *Array = nullptr; 6338931add6SHamilton Tobon Mosquera /// Mapped values. 
6348931add6SHamilton Tobon Mosquera SmallVector<Value *, 8> StoredValues; 6358931add6SHamilton Tobon Mosquera /// Last stores made in the offload array. 6368931add6SHamilton Tobon Mosquera SmallVector<StoreInst *, 8> LastAccesses; 6378931add6SHamilton Tobon Mosquera 6388931add6SHamilton Tobon Mosquera OffloadArray() = default; 6398931add6SHamilton Tobon Mosquera 6408931add6SHamilton Tobon Mosquera /// Initializes the OffloadArray with the values stored in \p Array before 6418931add6SHamilton Tobon Mosquera /// instruction \p Before is reached. Returns false if the initialization 6428931add6SHamilton Tobon Mosquera /// fails. 6438931add6SHamilton Tobon Mosquera /// This MUST be used immediately after the construction of the object. 6448931add6SHamilton Tobon Mosquera bool initialize(AllocaInst &Array, Instruction &Before) { 6458931add6SHamilton Tobon Mosquera if (!Array.getAllocatedType()->isArrayTy()) 6468931add6SHamilton Tobon Mosquera return false; 6478931add6SHamilton Tobon Mosquera 6488931add6SHamilton Tobon Mosquera if (!getValues(Array, Before)) 6498931add6SHamilton Tobon Mosquera return false; 6508931add6SHamilton Tobon Mosquera 6518931add6SHamilton Tobon Mosquera this->Array = &Array; 6528931add6SHamilton Tobon Mosquera return true; 6538931add6SHamilton Tobon Mosquera } 6548931add6SHamilton Tobon Mosquera 655da8bec47SJoseph Huber static const unsigned DeviceIDArgNum = 1; 656da8bec47SJoseph Huber static const unsigned BasePtrsArgNum = 3; 657da8bec47SJoseph Huber static const unsigned PtrsArgNum = 4; 658da8bec47SJoseph Huber static const unsigned SizesArgNum = 5; 6591d3d9b9cSHamilton Tobon Mosquera 6608931add6SHamilton Tobon Mosquera private: 6618931add6SHamilton Tobon Mosquera /// Traverses the BasicBlock where \p Array is, collecting the stores made to 6628931add6SHamilton Tobon Mosquera /// \p Array, leaving StoredValues with the values stored before the 6638931add6SHamilton Tobon Mosquera /// instruction \p Before is reached. 
6648931add6SHamilton Tobon Mosquera bool getValues(AllocaInst &Array, Instruction &Before) { 6658931add6SHamilton Tobon Mosquera // Initialize container. 666d08d490aSJohannes Doerfert const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements(); 6678931add6SHamilton Tobon Mosquera StoredValues.assign(NumValues, nullptr); 6688931add6SHamilton Tobon Mosquera LastAccesses.assign(NumValues, nullptr); 6698931add6SHamilton Tobon Mosquera 6708931add6SHamilton Tobon Mosquera // TODO: This assumes the instruction \p Before is in the same 6718931add6SHamilton Tobon Mosquera // BasicBlock as Array. Make it general, for any control flow graph. 6728931add6SHamilton Tobon Mosquera BasicBlock *BB = Array.getParent(); 6738931add6SHamilton Tobon Mosquera if (BB != Before.getParent()) 6748931add6SHamilton Tobon Mosquera return false; 6758931add6SHamilton Tobon Mosquera 6768931add6SHamilton Tobon Mosquera const DataLayout &DL = Array.getModule()->getDataLayout(); 6778931add6SHamilton Tobon Mosquera const unsigned int PointerSize = DL.getPointerSize(); 6788931add6SHamilton Tobon Mosquera 6798931add6SHamilton Tobon Mosquera for (Instruction &I : *BB) { 6808931add6SHamilton Tobon Mosquera if (&I == &Before) 6818931add6SHamilton Tobon Mosquera break; 6828931add6SHamilton Tobon Mosquera 6838931add6SHamilton Tobon Mosquera if (!isa<StoreInst>(&I)) 6848931add6SHamilton Tobon Mosquera continue; 6858931add6SHamilton Tobon Mosquera 6868931add6SHamilton Tobon Mosquera auto *S = cast<StoreInst>(&I); 6878931add6SHamilton Tobon Mosquera int64_t Offset = -1; 688d08d490aSJohannes Doerfert auto *Dst = 689d08d490aSJohannes Doerfert GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL); 6908931add6SHamilton Tobon Mosquera if (Dst == &Array) { 6918931add6SHamilton Tobon Mosquera int64_t Idx = Offset / PointerSize; 6928931add6SHamilton Tobon Mosquera StoredValues[Idx] = getUnderlyingObject(S->getValueOperand()); 6938931add6SHamilton Tobon Mosquera LastAccesses[Idx] = S; 
6948931add6SHamilton Tobon Mosquera } 6958931add6SHamilton Tobon Mosquera } 6968931add6SHamilton Tobon Mosquera 6978931add6SHamilton Tobon Mosquera return isFilled(); 6988931add6SHamilton Tobon Mosquera } 6998931add6SHamilton Tobon Mosquera 7008931add6SHamilton Tobon Mosquera /// Returns true if all values in StoredValues and 7018931add6SHamilton Tobon Mosquera /// LastAccesses are not nullptrs. 7028931add6SHamilton Tobon Mosquera bool isFilled() { 7038931add6SHamilton Tobon Mosquera const unsigned NumValues = StoredValues.size(); 7048931add6SHamilton Tobon Mosquera for (unsigned I = 0; I < NumValues; ++I) { 7058931add6SHamilton Tobon Mosquera if (!StoredValues[I] || !LastAccesses[I]) 7068931add6SHamilton Tobon Mosquera return false; 7078931add6SHamilton Tobon Mosquera } 7088931add6SHamilton Tobon Mosquera 7098931add6SHamilton Tobon Mosquera return true; 7108931add6SHamilton Tobon Mosquera } 7118931add6SHamilton Tobon Mosquera }; 7128931add6SHamilton Tobon Mosquera 7137cfd267cSsstefan1 struct OpenMPOpt { 7147cfd267cSsstefan1 7157cfd267cSsstefan1 using OptimizationRemarkGetter = 7167cfd267cSsstefan1 function_ref<OptimizationRemarkEmitter &(Function *)>; 7177cfd267cSsstefan1 7187cfd267cSsstefan1 OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater, 7197cfd267cSsstefan1 OptimizationRemarkGetter OREGetter, 720b8235d2bSsstefan1 OMPInformationCache &OMPInfoCache, Attributor &A) 72177b79d79SMehdi Amini : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater), 722b8235d2bSsstefan1 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {} 7237cfd267cSsstefan1 724a2281419SJoseph Huber /// Check if any remarks are enabled for openmp-opt 725a2281419SJoseph Huber bool remarksEnabled() { 726a2281419SJoseph Huber auto &Ctx = M.getContext(); 727a2281419SJoseph Huber return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE); 728a2281419SJoseph Huber } 729a2281419SJoseph Huber 7309548b74aSJohannes Doerfert /// Run all OpenMP optimizations on the 
underlying SCC/ModuleSlice. 731b2ad63d3SJoseph Huber bool run(bool IsModulePass) { 73254bd3751SJohannes Doerfert if (SCC.empty()) 73354bd3751SJohannes Doerfert return false; 73454bd3751SJohannes Doerfert 7359548b74aSJohannes Doerfert bool Changed = false; 7369548b74aSJohannes Doerfert 7379548b74aSJohannes Doerfert LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size() 73877b79d79SMehdi Amini << " functions in a slice with " 73977b79d79SMehdi Amini << OMPInfoCache.ModuleSlice.size() << " functions\n"); 7409548b74aSJohannes Doerfert 741b2ad63d3SJoseph Huber if (IsModulePass) { 742d9659bf6SJohannes Doerfert Changed |= runAttributor(IsModulePass); 74318283125SJoseph Huber 7446fc51c9fSJoseph Huber // Recollect uses, in case Attributor deleted any. 7456fc51c9fSJoseph Huber OMPInfoCache.recollectUses(); 7466fc51c9fSJoseph Huber 747be2b5696SJohannes Doerfert // TODO: This should be folded into buildCustomStateMachine. 748be2b5696SJohannes Doerfert Changed |= rewriteDeviceCodeStateMachine(); 749be2b5696SJohannes Doerfert 750b2ad63d3SJoseph Huber if (remarksEnabled()) 751b2ad63d3SJoseph Huber analysisGlobalization(); 752b2ad63d3SJoseph Huber } else { 753e8039ad4SJohannes Doerfert if (PrintICVValues) 754e8039ad4SJohannes Doerfert printICVs(); 755e8039ad4SJohannes Doerfert if (PrintOpenMPKernels) 756e8039ad4SJohannes Doerfert printKernels(); 757e8039ad4SJohannes Doerfert 758d9659bf6SJohannes Doerfert Changed |= runAttributor(IsModulePass); 759e8039ad4SJohannes Doerfert 760e8039ad4SJohannes Doerfert // Recollect uses, in case Attributor deleted any. 
761e8039ad4SJohannes Doerfert OMPInfoCache.recollectUses(); 762e8039ad4SJohannes Doerfert 763e8039ad4SJohannes Doerfert Changed |= deleteParallelRegions(); 764d9659bf6SJohannes Doerfert 765496f8e5bSHamilton Tobon Mosquera if (HideMemoryTransferLatency) 766496f8e5bSHamilton Tobon Mosquera Changed |= hideMemTransfersLatency(); 7673a6bfcf2SGiorgis Georgakoudis Changed |= deduplicateRuntimeCalls(); 7683a6bfcf2SGiorgis Georgakoudis if (EnableParallelRegionMerging) { 7693a6bfcf2SGiorgis Georgakoudis if (mergeParallelRegions()) { 7703a6bfcf2SGiorgis Georgakoudis deduplicateRuntimeCalls(); 7713a6bfcf2SGiorgis Georgakoudis Changed = true; 7723a6bfcf2SGiorgis Georgakoudis } 7733a6bfcf2SGiorgis Georgakoudis } 774b2ad63d3SJoseph Huber } 775e8039ad4SJohannes Doerfert 776e8039ad4SJohannes Doerfert return Changed; 777e8039ad4SJohannes Doerfert } 778e8039ad4SJohannes Doerfert 7790f426935Ssstefan1 /// Print initial ICV values for testing. 7800f426935Ssstefan1 /// FIXME: This should be done from the Attributor once it is added. 781e8039ad4SJohannes Doerfert void printICVs() const { 782cb9cfa0dSsstefan1 InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel, 783cb9cfa0dSsstefan1 ICV_proc_bind}; 7840f426935Ssstefan1 7850f426935Ssstefan1 for (Function *F : OMPInfoCache.ModuleSlice) { 7860f426935Ssstefan1 for (auto ICV : ICVs) { 7870f426935Ssstefan1 auto ICVInfo = OMPInfoCache.ICVs[ICV]; 7882db182ffSJoseph Huber auto Remark = [&](OptimizationRemarkAnalysis ORA) { 7892db182ffSJoseph Huber return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name) 7900f426935Ssstefan1 << " Value: " 7910f426935Ssstefan1 << (ICVInfo.InitValue 79261cdaf66SSimon Pilgrim ? 
toString(ICVInfo.InitValue->getValue(), 10, true) 7930f426935Ssstefan1 : "IMPLEMENTATION_DEFINED"); 7940f426935Ssstefan1 }; 7950f426935Ssstefan1 7962db182ffSJoseph Huber emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark); 7970f426935Ssstefan1 } 7980f426935Ssstefan1 } 7990f426935Ssstefan1 } 8000f426935Ssstefan1 801e8039ad4SJohannes Doerfert /// Print OpenMP GPU kernels for testing. 802e8039ad4SJohannes Doerfert void printKernels() const { 803e8039ad4SJohannes Doerfert for (Function *F : SCC) { 804e8039ad4SJohannes Doerfert if (!OMPInfoCache.Kernels.count(F)) 805e8039ad4SJohannes Doerfert continue; 806b8235d2bSsstefan1 8072db182ffSJoseph Huber auto Remark = [&](OptimizationRemarkAnalysis ORA) { 8082db182ffSJoseph Huber return ORA << "OpenMP GPU kernel " 809e8039ad4SJohannes Doerfert << ore::NV("OpenMPGPUKernel", F->getName()) << "\n"; 810e8039ad4SJohannes Doerfert }; 811b8235d2bSsstefan1 8122db182ffSJoseph Huber emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark); 813e8039ad4SJohannes Doerfert } 8149548b74aSJohannes Doerfert } 8159548b74aSJohannes Doerfert 8167cfd267cSsstefan1 /// Return the call if \p U is a callee use in a regular call. If \p RFI is 8177cfd267cSsstefan1 /// given it has to be the callee or a nullptr is returned. 8187cfd267cSsstefan1 static CallInst *getCallIfRegularCall( 8197cfd267cSsstefan1 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) { 8207cfd267cSsstefan1 CallInst *CI = dyn_cast<CallInst>(U.getUser()); 8217cfd267cSsstefan1 if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() && 822c4b1fe05SJohannes Doerfert (!RFI || 823c4b1fe05SJohannes Doerfert (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration))) 8247cfd267cSsstefan1 return CI; 8257cfd267cSsstefan1 return nullptr; 8267cfd267cSsstefan1 } 8277cfd267cSsstefan1 8287cfd267cSsstefan1 /// Return the call if \p V is a regular call. If \p RFI is given it has to be 8297cfd267cSsstefan1 /// the callee or a nullptr is returned. 
8307cfd267cSsstefan1 static CallInst *getCallIfRegularCall( 8317cfd267cSsstefan1 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) { 8327cfd267cSsstefan1 CallInst *CI = dyn_cast<CallInst>(&V); 8337cfd267cSsstefan1 if (CI && !CI->hasOperandBundles() && 834c4b1fe05SJohannes Doerfert (!RFI || 835c4b1fe05SJohannes Doerfert (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration))) 8367cfd267cSsstefan1 return CI; 8377cfd267cSsstefan1 return nullptr; 8387cfd267cSsstefan1 } 8397cfd267cSsstefan1 8409548b74aSJohannes Doerfert private: 8413a6bfcf2SGiorgis Georgakoudis /// Merge parallel regions when it is safe. 8423a6bfcf2SGiorgis Georgakoudis bool mergeParallelRegions() { 8433a6bfcf2SGiorgis Georgakoudis const unsigned CallbackCalleeOperand = 2; 8443a6bfcf2SGiorgis Georgakoudis const unsigned CallbackFirstArgOperand = 3; 8453a6bfcf2SGiorgis Georgakoudis using InsertPointTy = OpenMPIRBuilder::InsertPointTy; 8463a6bfcf2SGiorgis Georgakoudis 8473a6bfcf2SGiorgis Georgakoudis // Check if there are any __kmpc_fork_call calls to merge. 8483a6bfcf2SGiorgis Georgakoudis OMPInformationCache::RuntimeFunctionInfo &RFI = 8493a6bfcf2SGiorgis Georgakoudis OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call]; 8503a6bfcf2SGiorgis Georgakoudis 8513a6bfcf2SGiorgis Georgakoudis if (!RFI.Declaration) 8523a6bfcf2SGiorgis Georgakoudis return false; 8533a6bfcf2SGiorgis Georgakoudis 85497517055SGiorgis Georgakoudis // Unmergable calls that prevent merging a parallel region. 
85597517055SGiorgis Georgakoudis OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = { 85697517055SGiorgis Georgakoudis OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind], 85797517055SGiorgis Georgakoudis OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads], 85897517055SGiorgis Georgakoudis }; 8593a6bfcf2SGiorgis Georgakoudis 8603a6bfcf2SGiorgis Georgakoudis bool Changed = false; 8613a6bfcf2SGiorgis Georgakoudis LoopInfo *LI = nullptr; 8623a6bfcf2SGiorgis Georgakoudis DominatorTree *DT = nullptr; 8633a6bfcf2SGiorgis Georgakoudis 8643a6bfcf2SGiorgis Georgakoudis SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap; 8653a6bfcf2SGiorgis Georgakoudis 8663a6bfcf2SGiorgis Georgakoudis BasicBlock *StartBB = nullptr, *EndBB = nullptr; 8673a6bfcf2SGiorgis Georgakoudis auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, 8683a6bfcf2SGiorgis Georgakoudis BasicBlock &ContinuationIP) { 8693a6bfcf2SGiorgis Georgakoudis BasicBlock *CGStartBB = CodeGenIP.getBlock(); 8703a6bfcf2SGiorgis Georgakoudis BasicBlock *CGEndBB = 8713a6bfcf2SGiorgis Georgakoudis SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI); 8723a6bfcf2SGiorgis Georgakoudis assert(StartBB != nullptr && "StartBB should not be null"); 8733a6bfcf2SGiorgis Georgakoudis CGStartBB->getTerminator()->setSuccessor(0, StartBB); 8743a6bfcf2SGiorgis Georgakoudis assert(EndBB != nullptr && "EndBB should not be null"); 8753a6bfcf2SGiorgis Georgakoudis EndBB->getTerminator()->setSuccessor(0, CGEndBB); 8763a6bfcf2SGiorgis Georgakoudis }; 8773a6bfcf2SGiorgis Georgakoudis 878240dd924SAlex Zinenko auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &, 879240dd924SAlex Zinenko Value &Inner, Value *&ReplacementValue) -> InsertPointTy { 880240dd924SAlex Zinenko ReplacementValue = &Inner; 8813a6bfcf2SGiorgis Georgakoudis return CodeGenIP; 8823a6bfcf2SGiorgis Georgakoudis }; 8833a6bfcf2SGiorgis Georgakoudis 8843a6bfcf2SGiorgis Georgakoudis auto FiniCB = [&](InsertPointTy 
CodeGenIP) {}; 8853a6bfcf2SGiorgis Georgakoudis 88697517055SGiorgis Georgakoudis /// Create a sequential execution region within a merged parallel region, 88797517055SGiorgis Georgakoudis /// encapsulated in a master construct with a barrier for synchronization. 88897517055SGiorgis Georgakoudis auto CreateSequentialRegion = [&](Function *OuterFn, 88997517055SGiorgis Georgakoudis BasicBlock *OuterPredBB, 89097517055SGiorgis Georgakoudis Instruction *SeqStartI, 89197517055SGiorgis Georgakoudis Instruction *SeqEndI) { 89297517055SGiorgis Georgakoudis // Isolate the instructions of the sequential region to a separate 89397517055SGiorgis Georgakoudis // block. 89497517055SGiorgis Georgakoudis BasicBlock *ParentBB = SeqStartI->getParent(); 89597517055SGiorgis Georgakoudis BasicBlock *SeqEndBB = 89697517055SGiorgis Georgakoudis SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI); 89797517055SGiorgis Georgakoudis BasicBlock *SeqAfterBB = 89897517055SGiorgis Georgakoudis SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI); 89997517055SGiorgis Georgakoudis BasicBlock *SeqStartBB = 90097517055SGiorgis Georgakoudis SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged"); 90197517055SGiorgis Georgakoudis 90297517055SGiorgis Georgakoudis assert(ParentBB->getUniqueSuccessor() == SeqStartBB && 90397517055SGiorgis Georgakoudis "Expected a different CFG"); 90497517055SGiorgis Georgakoudis const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc(); 90597517055SGiorgis Georgakoudis ParentBB->getTerminator()->eraseFromParent(); 90697517055SGiorgis Georgakoudis 90797517055SGiorgis Georgakoudis auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, 90897517055SGiorgis Georgakoudis BasicBlock &ContinuationIP) { 90997517055SGiorgis Georgakoudis BasicBlock *CGStartBB = CodeGenIP.getBlock(); 91097517055SGiorgis Georgakoudis BasicBlock *CGEndBB = 91197517055SGiorgis Georgakoudis SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI); 91297517055SGiorgis 
Georgakoudis assert(SeqStartBB != nullptr && "SeqStartBB should not be null"); 91397517055SGiorgis Georgakoudis CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB); 91497517055SGiorgis Georgakoudis assert(SeqEndBB != nullptr && "SeqEndBB should not be null"); 91597517055SGiorgis Georgakoudis SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB); 91697517055SGiorgis Georgakoudis }; 91797517055SGiorgis Georgakoudis auto FiniCB = [&](InsertPointTy CodeGenIP) {}; 91897517055SGiorgis Georgakoudis 91997517055SGiorgis Georgakoudis // Find outputs from the sequential region to outside users and 92097517055SGiorgis Georgakoudis // broadcast their values to them. 92197517055SGiorgis Georgakoudis for (Instruction &I : *SeqStartBB) { 92297517055SGiorgis Georgakoudis SmallPtrSet<Instruction *, 4> OutsideUsers; 92397517055SGiorgis Georgakoudis for (User *Usr : I.users()) { 92497517055SGiorgis Georgakoudis Instruction &UsrI = *cast<Instruction>(Usr); 92597517055SGiorgis Georgakoudis // Ignore outputs to LT intrinsics, code extraction for the merged 92697517055SGiorgis Georgakoudis // parallel region will fix them. 92797517055SGiorgis Georgakoudis if (UsrI.isLifetimeStartOrEnd()) 92897517055SGiorgis Georgakoudis continue; 92997517055SGiorgis Georgakoudis 93097517055SGiorgis Georgakoudis if (UsrI.getParent() != SeqStartBB) 93197517055SGiorgis Georgakoudis OutsideUsers.insert(&UsrI); 93297517055SGiorgis Georgakoudis } 93397517055SGiorgis Georgakoudis 93497517055SGiorgis Georgakoudis if (OutsideUsers.empty()) 93597517055SGiorgis Georgakoudis continue; 93697517055SGiorgis Georgakoudis 93797517055SGiorgis Georgakoudis // Emit an alloca in the outer region to store the broadcasted 93897517055SGiorgis Georgakoudis // value. 
93997517055SGiorgis Georgakoudis const DataLayout &DL = M.getDataLayout(); 94097517055SGiorgis Georgakoudis AllocaInst *AllocaI = new AllocaInst( 94197517055SGiorgis Georgakoudis I.getType(), DL.getAllocaAddrSpace(), nullptr, 94297517055SGiorgis Georgakoudis I.getName() + ".seq.output.alloc", &OuterFn->front().front()); 94397517055SGiorgis Georgakoudis 94497517055SGiorgis Georgakoudis // Emit a store instruction in the sequential BB to update the 94597517055SGiorgis Georgakoudis // value. 94697517055SGiorgis Georgakoudis new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()); 94797517055SGiorgis Georgakoudis 94897517055SGiorgis Georgakoudis // Emit a load instruction and replace the use of the output value 94997517055SGiorgis Georgakoudis // with it. 95097517055SGiorgis Georgakoudis for (Instruction *UsrI : OutsideUsers) { 9515b70c12fSJohannes Doerfert LoadInst *LoadI = new LoadInst( 9525b70c12fSJohannes Doerfert I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI); 95397517055SGiorgis Georgakoudis UsrI->replaceUsesOfWith(&I, LoadI); 95497517055SGiorgis Georgakoudis } 95597517055SGiorgis Georgakoudis } 95697517055SGiorgis Georgakoudis 95797517055SGiorgis Georgakoudis OpenMPIRBuilder::LocationDescription Loc( 95897517055SGiorgis Georgakoudis InsertPointTy(ParentBB, ParentBB->end()), DL); 95997517055SGiorgis Georgakoudis InsertPointTy SeqAfterIP = 96097517055SGiorgis Georgakoudis OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB); 96197517055SGiorgis Georgakoudis 96297517055SGiorgis Georgakoudis OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel); 96397517055SGiorgis Georgakoudis 96497517055SGiorgis Georgakoudis BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock()); 96597517055SGiorgis Georgakoudis 96697517055SGiorgis Georgakoudis LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn 96797517055SGiorgis Georgakoudis << "\n"); 96897517055SGiorgis Georgakoudis }; 96997517055SGiorgis Georgakoudis 9703a6bfcf2SGiorgis 
Georgakoudis // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all 9713a6bfcf2SGiorgis Georgakoudis // contained in BB and only separated by instructions that can be 9723a6bfcf2SGiorgis Georgakoudis // redundantly executed in parallel. The block BB is split before the first 9733a6bfcf2SGiorgis Georgakoudis // call (in MergableCIs) and after the last so the entire region we merge 9743a6bfcf2SGiorgis Georgakoudis // into a single parallel region is contained in a single basic block 9753a6bfcf2SGiorgis Georgakoudis // without any other instructions. We use the OpenMPIRBuilder to outline 9763a6bfcf2SGiorgis Georgakoudis // that block and call the resulting function via __kmpc_fork_call. 9773a6bfcf2SGiorgis Georgakoudis auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) { 9783a6bfcf2SGiorgis Georgakoudis // TODO: Change the interface to allow single CIs expanded, e.g, to 9793a6bfcf2SGiorgis Georgakoudis // include an outer loop. 9803a6bfcf2SGiorgis Georgakoudis assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs"); 9813a6bfcf2SGiorgis Georgakoudis 9823a6bfcf2SGiorgis Georgakoudis auto Remark = [&](OptimizationRemark OR) { 983eef6601bSJoseph Huber OR << "Parallel region merged with parallel region" 984eef6601bSJoseph Huber << (MergableCIs.size() > 2 ? 
"s" : "") << " at "; 98523b0ab2aSKazu Hirata for (auto *CI : llvm::drop_begin(MergableCIs)) { 9863a6bfcf2SGiorgis Georgakoudis OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc()); 9873a6bfcf2SGiorgis Georgakoudis if (CI != MergableCIs.back()) 9883a6bfcf2SGiorgis Georgakoudis OR << ", "; 9893a6bfcf2SGiorgis Georgakoudis } 990eef6601bSJoseph Huber return OR << "."; 9913a6bfcf2SGiorgis Georgakoudis }; 9923a6bfcf2SGiorgis Georgakoudis 9932c31d5ebSJoseph Huber emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark); 9943a6bfcf2SGiorgis Georgakoudis 9953a6bfcf2SGiorgis Georgakoudis Function *OriginalFn = BB->getParent(); 9963a6bfcf2SGiorgis Georgakoudis LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size() 9973a6bfcf2SGiorgis Georgakoudis << " parallel regions in " << OriginalFn->getName() 9983a6bfcf2SGiorgis Georgakoudis << "\n"); 9993a6bfcf2SGiorgis Georgakoudis 10003a6bfcf2SGiorgis Georgakoudis // Isolate the calls to merge in a separate block. 10013a6bfcf2SGiorgis Georgakoudis EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI); 10023a6bfcf2SGiorgis Georgakoudis BasicBlock *AfterBB = 10033a6bfcf2SGiorgis Georgakoudis SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI); 10043a6bfcf2SGiorgis Georgakoudis StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr, 10053a6bfcf2SGiorgis Georgakoudis "omp.par.merged"); 10063a6bfcf2SGiorgis Georgakoudis 10073a6bfcf2SGiorgis Georgakoudis assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG"); 10083a6bfcf2SGiorgis Georgakoudis const DebugLoc DL = BB->getTerminator()->getDebugLoc(); 10093a6bfcf2SGiorgis Georgakoudis BB->getTerminator()->eraseFromParent(); 10103a6bfcf2SGiorgis Georgakoudis 101197517055SGiorgis Georgakoudis // Create sequential regions for sequential instructions that are 101297517055SGiorgis Georgakoudis // in-between mergable parallel regions. 
101397517055SGiorgis Georgakoudis for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1; 101497517055SGiorgis Georgakoudis It != End; ++It) { 101597517055SGiorgis Georgakoudis Instruction *ForkCI = *It; 101697517055SGiorgis Georgakoudis Instruction *NextForkCI = *(It + 1); 101797517055SGiorgis Georgakoudis 101897517055SGiorgis Georgakoudis // Continue if there are not in-between instructions. 101997517055SGiorgis Georgakoudis if (ForkCI->getNextNode() == NextForkCI) 102097517055SGiorgis Georgakoudis continue; 102197517055SGiorgis Georgakoudis 102297517055SGiorgis Georgakoudis CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(), 102397517055SGiorgis Georgakoudis NextForkCI->getPrevNode()); 102497517055SGiorgis Georgakoudis } 102597517055SGiorgis Georgakoudis 10263a6bfcf2SGiorgis Georgakoudis OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()), 10273a6bfcf2SGiorgis Georgakoudis DL); 10283a6bfcf2SGiorgis Georgakoudis IRBuilder<>::InsertPoint AllocaIP( 10293a6bfcf2SGiorgis Georgakoudis &OriginalFn->getEntryBlock(), 10303a6bfcf2SGiorgis Georgakoudis OriginalFn->getEntryBlock().getFirstInsertionPt()); 10313a6bfcf2SGiorgis Georgakoudis // Create the merged parallel region with default proc binding, to 10323a6bfcf2SGiorgis Georgakoudis // avoid overriding binding settings, and without explicit cancellation. 1033e5dba2d7SMichael Kruse InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel( 10343a6bfcf2SGiorgis Georgakoudis Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr, 10353a6bfcf2SGiorgis Georgakoudis OMP_PROC_BIND_default, /* IsCancellable */ false); 10363a6bfcf2SGiorgis Georgakoudis BranchInst::Create(AfterBB, AfterIP.getBlock()); 10373a6bfcf2SGiorgis Georgakoudis 10383a6bfcf2SGiorgis Georgakoudis // Perform the actual outlining. 
1039b1191206SMichael Kruse OMPInfoCache.OMPBuilder.finalize(OriginalFn, 1040b1191206SMichael Kruse /* AllowExtractorSinking */ true); 10413a6bfcf2SGiorgis Georgakoudis 10423a6bfcf2SGiorgis Georgakoudis Function *OutlinedFn = MergableCIs.front()->getCaller(); 10433a6bfcf2SGiorgis Georgakoudis 10443a6bfcf2SGiorgis Georgakoudis // Replace the __kmpc_fork_call calls with direct calls to the outlined 10453a6bfcf2SGiorgis Georgakoudis // callbacks. 10463a6bfcf2SGiorgis Georgakoudis SmallVector<Value *, 8> Args; 10473a6bfcf2SGiorgis Georgakoudis for (auto *CI : MergableCIs) { 10483a6bfcf2SGiorgis Georgakoudis Value *Callee = 10493a6bfcf2SGiorgis Georgakoudis CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts(); 10503a6bfcf2SGiorgis Georgakoudis FunctionType *FT = 10513a6bfcf2SGiorgis Georgakoudis cast<FunctionType>(Callee->getType()->getPointerElementType()); 10523a6bfcf2SGiorgis Georgakoudis Args.clear(); 10533a6bfcf2SGiorgis Georgakoudis Args.push_back(OutlinedFn->getArg(0)); 10543a6bfcf2SGiorgis Georgakoudis Args.push_back(OutlinedFn->getArg(1)); 10553a6bfcf2SGiorgis Georgakoudis for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands(); 10563a6bfcf2SGiorgis Georgakoudis U < E; ++U) 10573a6bfcf2SGiorgis Georgakoudis Args.push_back(CI->getArgOperand(U)); 10583a6bfcf2SGiorgis Georgakoudis 10593a6bfcf2SGiorgis Georgakoudis CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI); 10603a6bfcf2SGiorgis Georgakoudis if (CI->getDebugLoc()) 10613a6bfcf2SGiorgis Georgakoudis NewCI->setDebugLoc(CI->getDebugLoc()); 10623a6bfcf2SGiorgis Georgakoudis 10633a6bfcf2SGiorgis Georgakoudis // Forward parameter attributes from the callback to the callee. 
10643a6bfcf2SGiorgis Georgakoudis for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands(); 10653a6bfcf2SGiorgis Georgakoudis U < E; ++U) 1066*80ea2bb5SArthur Eubanks for (const Attribute &A : CI->getAttributes().getParamAttrs(U)) 10673a6bfcf2SGiorgis Georgakoudis NewCI->addParamAttr( 10683a6bfcf2SGiorgis Georgakoudis U - (CallbackFirstArgOperand - CallbackCalleeOperand), A); 10693a6bfcf2SGiorgis Georgakoudis 10703a6bfcf2SGiorgis Georgakoudis // Emit an explicit barrier to replace the implicit fork-join barrier. 10713a6bfcf2SGiorgis Georgakoudis if (CI != MergableCIs.back()) { 10723a6bfcf2SGiorgis Georgakoudis // TODO: Remove barrier if the merged parallel region includes the 10733a6bfcf2SGiorgis Georgakoudis // 'nowait' clause. 1074e5dba2d7SMichael Kruse OMPInfoCache.OMPBuilder.createBarrier( 10753a6bfcf2SGiorgis Georgakoudis InsertPointTy(NewCI->getParent(), 10763a6bfcf2SGiorgis Georgakoudis NewCI->getNextNode()->getIterator()), 10773a6bfcf2SGiorgis Georgakoudis OMPD_parallel); 10783a6bfcf2SGiorgis Georgakoudis } 10793a6bfcf2SGiorgis Georgakoudis 10803a6bfcf2SGiorgis Georgakoudis CI->eraseFromParent(); 10813a6bfcf2SGiorgis Georgakoudis } 10823a6bfcf2SGiorgis Georgakoudis 10833a6bfcf2SGiorgis Georgakoudis assert(OutlinedFn != OriginalFn && "Outlining failed"); 10847fea561eSArthur Eubanks CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn); 10853a6bfcf2SGiorgis Georgakoudis CGUpdater.reanalyzeFunction(*OriginalFn); 10863a6bfcf2SGiorgis Georgakoudis 10873a6bfcf2SGiorgis Georgakoudis NumOpenMPParallelRegionsMerged += MergableCIs.size(); 10883a6bfcf2SGiorgis Georgakoudis 10893a6bfcf2SGiorgis Georgakoudis return true; 10903a6bfcf2SGiorgis Georgakoudis }; 10913a6bfcf2SGiorgis Georgakoudis 10923a6bfcf2SGiorgis Georgakoudis // Helper function that identifes sequences of 10933a6bfcf2SGiorgis Georgakoudis // __kmpc_fork_call uses in a basic block. 
10943a6bfcf2SGiorgis Georgakoudis auto DetectPRsCB = [&](Use &U, Function &F) { 10953a6bfcf2SGiorgis Georgakoudis CallInst *CI = getCallIfRegularCall(U, &RFI); 10963a6bfcf2SGiorgis Georgakoudis BB2PRMap[CI->getParent()].insert(CI); 10973a6bfcf2SGiorgis Georgakoudis 10983a6bfcf2SGiorgis Georgakoudis return false; 10993a6bfcf2SGiorgis Georgakoudis }; 11003a6bfcf2SGiorgis Georgakoudis 11013a6bfcf2SGiorgis Georgakoudis BB2PRMap.clear(); 11023a6bfcf2SGiorgis Georgakoudis RFI.foreachUse(SCC, DetectPRsCB); 11033a6bfcf2SGiorgis Georgakoudis SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector; 11043a6bfcf2SGiorgis Georgakoudis // Find mergable parallel regions within a basic block that are 11053a6bfcf2SGiorgis Georgakoudis // safe to merge, that is any in-between instructions can safely 11063a6bfcf2SGiorgis Georgakoudis // execute in parallel after merging. 11073a6bfcf2SGiorgis Georgakoudis // TODO: support merging across basic-blocks. 11083a6bfcf2SGiorgis Georgakoudis for (auto &It : BB2PRMap) { 11093a6bfcf2SGiorgis Georgakoudis auto &CIs = It.getSecond(); 11103a6bfcf2SGiorgis Georgakoudis if (CIs.size() < 2) 11113a6bfcf2SGiorgis Georgakoudis continue; 11123a6bfcf2SGiorgis Georgakoudis 11133a6bfcf2SGiorgis Georgakoudis BasicBlock *BB = It.getFirst(); 11143a6bfcf2SGiorgis Georgakoudis SmallVector<CallInst *, 4> MergableCIs; 11153a6bfcf2SGiorgis Georgakoudis 111697517055SGiorgis Georgakoudis /// Returns true if the instruction is mergable, false otherwise. 111797517055SGiorgis Georgakoudis /// A terminator instruction is unmergable by definition since merging 111897517055SGiorgis Georgakoudis /// works within a BB. Instructions before the mergable region are 111997517055SGiorgis Georgakoudis /// mergable if they are not calls to OpenMP runtime functions that may 112097517055SGiorgis Georgakoudis /// set different execution parameters for subsequent parallel regions. 
112197517055SGiorgis Georgakoudis /// Instructions in-between parallel regions are mergable if they are not 112297517055SGiorgis Georgakoudis /// calls to any non-intrinsic function since that may call a non-mergable 112397517055SGiorgis Georgakoudis /// OpenMP runtime function. 112497517055SGiorgis Georgakoudis auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) { 112597517055SGiorgis Georgakoudis // We do not merge across BBs, hence return false (unmergable) if the 112697517055SGiorgis Georgakoudis // instruction is a terminator. 112797517055SGiorgis Georgakoudis if (I.isTerminator()) 112897517055SGiorgis Georgakoudis return false; 112997517055SGiorgis Georgakoudis 113097517055SGiorgis Georgakoudis if (!isa<CallInst>(&I)) 113197517055SGiorgis Georgakoudis return true; 113297517055SGiorgis Georgakoudis 113397517055SGiorgis Georgakoudis CallInst *CI = cast<CallInst>(&I); 113497517055SGiorgis Georgakoudis if (IsBeforeMergableRegion) { 113597517055SGiorgis Georgakoudis Function *CalledFunction = CI->getCalledFunction(); 113697517055SGiorgis Georgakoudis if (!CalledFunction) 113797517055SGiorgis Georgakoudis return false; 113897517055SGiorgis Georgakoudis // Return false (unmergable) if the call before the parallel 113997517055SGiorgis Georgakoudis // region calls an explicit affinity (proc_bind) or number of 114097517055SGiorgis Georgakoudis // threads (num_threads) compiler-generated function. Those settings 114197517055SGiorgis Georgakoudis // may be incompatible with following parallel regions. 114297517055SGiorgis Georgakoudis // TODO: ICV tracking to detect compatibility. 
114397517055SGiorgis Georgakoudis for (const auto &RFI : UnmergableCallsInfo) { 114497517055SGiorgis Georgakoudis if (CalledFunction == RFI.Declaration) 114597517055SGiorgis Georgakoudis return false; 114697517055SGiorgis Georgakoudis } 114797517055SGiorgis Georgakoudis } else { 114897517055SGiorgis Georgakoudis // Return false (unmergable) if there is a call instruction 114997517055SGiorgis Georgakoudis // in-between parallel regions when it is not an intrinsic. It 115097517055SGiorgis Georgakoudis // may call an unmergable OpenMP runtime function in its callpath. 115197517055SGiorgis Georgakoudis // TODO: Keep track of possible OpenMP calls in the callpath. 115297517055SGiorgis Georgakoudis if (!isa<IntrinsicInst>(CI)) 115397517055SGiorgis Georgakoudis return false; 115497517055SGiorgis Georgakoudis } 115597517055SGiorgis Georgakoudis 115697517055SGiorgis Georgakoudis return true; 115797517055SGiorgis Georgakoudis }; 11583a6bfcf2SGiorgis Georgakoudis // Find maximal number of parallel region CIs that are safe to merge. 115997517055SGiorgis Georgakoudis for (auto It = BB->begin(), End = BB->end(); It != End;) { 116097517055SGiorgis Georgakoudis Instruction &I = *It; 116197517055SGiorgis Georgakoudis ++It; 116297517055SGiorgis Georgakoudis 11633a6bfcf2SGiorgis Georgakoudis if (CIs.count(&I)) { 11643a6bfcf2SGiorgis Georgakoudis MergableCIs.push_back(cast<CallInst>(&I)); 11653a6bfcf2SGiorgis Georgakoudis continue; 11663a6bfcf2SGiorgis Georgakoudis } 11673a6bfcf2SGiorgis Georgakoudis 116897517055SGiorgis Georgakoudis // Continue expanding if the instruction is mergable. 116997517055SGiorgis Georgakoudis if (IsMergable(I, MergableCIs.empty())) 11703a6bfcf2SGiorgis Georgakoudis continue; 11713a6bfcf2SGiorgis Georgakoudis 117297517055SGiorgis Georgakoudis // Forward the instruction iterator to skip the next parallel region 117397517055SGiorgis Georgakoudis // since there is an unmergable instruction which can affect it. 
117497517055SGiorgis Georgakoudis for (; It != End; ++It) { 117597517055SGiorgis Georgakoudis Instruction &SkipI = *It; 117697517055SGiorgis Georgakoudis if (CIs.count(&SkipI)) { 117797517055SGiorgis Georgakoudis LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI 117897517055SGiorgis Georgakoudis << " due to " << I << "\n"); 117997517055SGiorgis Georgakoudis ++It; 118097517055SGiorgis Georgakoudis break; 118197517055SGiorgis Georgakoudis } 118297517055SGiorgis Georgakoudis } 118397517055SGiorgis Georgakoudis 118497517055SGiorgis Georgakoudis // Store mergable regions found. 11853a6bfcf2SGiorgis Georgakoudis if (MergableCIs.size() > 1) { 11863a6bfcf2SGiorgis Georgakoudis MergableCIsVector.push_back(MergableCIs); 11873a6bfcf2SGiorgis Georgakoudis LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size() 11883a6bfcf2SGiorgis Georgakoudis << " parallel regions in block " << BB->getName() 11893a6bfcf2SGiorgis Georgakoudis << " of function " << BB->getParent()->getName() 11903a6bfcf2SGiorgis Georgakoudis << "\n";); 11913a6bfcf2SGiorgis Georgakoudis } 11923a6bfcf2SGiorgis Georgakoudis 11933a6bfcf2SGiorgis Georgakoudis MergableCIs.clear(); 11943a6bfcf2SGiorgis Georgakoudis } 11953a6bfcf2SGiorgis Georgakoudis 11963a6bfcf2SGiorgis Georgakoudis if (!MergableCIsVector.empty()) { 11973a6bfcf2SGiorgis Georgakoudis Changed = true; 11983a6bfcf2SGiorgis Georgakoudis 11993a6bfcf2SGiorgis Georgakoudis for (auto &MergableCIs : MergableCIsVector) 12003a6bfcf2SGiorgis Georgakoudis Merge(MergableCIs, BB); 1201b2ad63d3SJoseph Huber MergableCIsVector.clear(); 12023a6bfcf2SGiorgis Georgakoudis } 12033a6bfcf2SGiorgis Georgakoudis } 12043a6bfcf2SGiorgis Georgakoudis 12053a6bfcf2SGiorgis Georgakoudis if (Changed) { 120697517055SGiorgis Georgakoudis /// Re-collect use for fork calls, emitted barrier calls, and 120797517055SGiorgis Georgakoudis /// any emitted master/end_master calls. 
120897517055SGiorgis Georgakoudis OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call); 120997517055SGiorgis Georgakoudis OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier); 121097517055SGiorgis Georgakoudis OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master); 121197517055SGiorgis Georgakoudis OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master); 12123a6bfcf2SGiorgis Georgakoudis } 12133a6bfcf2SGiorgis Georgakoudis 12143a6bfcf2SGiorgis Georgakoudis return Changed; 12153a6bfcf2SGiorgis Georgakoudis } 12163a6bfcf2SGiorgis Georgakoudis 12179d38f98dSJohannes Doerfert /// Try to delete parallel regions if possible. 1218e565db49SJohannes Doerfert bool deleteParallelRegions() { 1219e565db49SJohannes Doerfert const unsigned CallbackCalleeOperand = 2; 1220e565db49SJohannes Doerfert 12217cfd267cSsstefan1 OMPInformationCache::RuntimeFunctionInfo &RFI = 12227cfd267cSsstefan1 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call]; 12237cfd267cSsstefan1 1224e565db49SJohannes Doerfert if (!RFI.Declaration) 1225e565db49SJohannes Doerfert return false; 1226e565db49SJohannes Doerfert 1227e565db49SJohannes Doerfert bool Changed = false; 1228e565db49SJohannes Doerfert auto DeleteCallCB = [&](Use &U, Function &) { 1229e565db49SJohannes Doerfert CallInst *CI = getCallIfRegularCall(U); 1230e565db49SJohannes Doerfert if (!CI) 1231e565db49SJohannes Doerfert return false; 1232e565db49SJohannes Doerfert auto *Fn = dyn_cast<Function>( 1233e565db49SJohannes Doerfert CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts()); 1234e565db49SJohannes Doerfert if (!Fn) 1235e565db49SJohannes Doerfert return false; 1236e565db49SJohannes Doerfert if (!Fn->onlyReadsMemory()) 1237e565db49SJohannes Doerfert return false; 1238e565db49SJohannes Doerfert if (!Fn->hasFnAttribute(Attribute::WillReturn)) 1239e565db49SJohannes Doerfert return false; 1240e565db49SJohannes Doerfert 1241e565db49SJohannes Doerfert LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in " 
1242e565db49SJohannes Doerfert << CI->getCaller()->getName() << "\n"); 12434d4ea9acSHuber, Joseph 12444d4ea9acSHuber, Joseph auto Remark = [&](OptimizationRemark OR) { 1245eef6601bSJoseph Huber return OR << "Removing parallel region with no side-effects."; 12464d4ea9acSHuber, Joseph }; 12472c31d5ebSJoseph Huber emitRemark<OptimizationRemark>(CI, "OMP160", Remark); 12484d4ea9acSHuber, Joseph 1249e565db49SJohannes Doerfert CGUpdater.removeCallSite(*CI); 1250e565db49SJohannes Doerfert CI->eraseFromParent(); 1251e565db49SJohannes Doerfert Changed = true; 125255eb714aSRoman Lebedev ++NumOpenMPParallelRegionsDeleted; 1253e565db49SJohannes Doerfert return true; 1254e565db49SJohannes Doerfert }; 1255e565db49SJohannes Doerfert 1256624d34afSJohannes Doerfert RFI.foreachUse(SCC, DeleteCallCB); 1257e565db49SJohannes Doerfert 1258e565db49SJohannes Doerfert return Changed; 1259e565db49SJohannes Doerfert } 1260e565db49SJohannes Doerfert 1261b726c557SJohannes Doerfert /// Try to eliminate runtime calls by reusing existing ones. 
12629548b74aSJohannes Doerfert bool deduplicateRuntimeCalls() { 12639548b74aSJohannes Doerfert bool Changed = false; 12649548b74aSJohannes Doerfert 1265e28936f6SJohannes Doerfert RuntimeFunction DeduplicableRuntimeCallIDs[] = { 1266e28936f6SJohannes Doerfert OMPRTL_omp_get_num_threads, 1267e28936f6SJohannes Doerfert OMPRTL_omp_in_parallel, 1268e28936f6SJohannes Doerfert OMPRTL_omp_get_cancellation, 1269e28936f6SJohannes Doerfert OMPRTL_omp_get_thread_limit, 1270e28936f6SJohannes Doerfert OMPRTL_omp_get_supported_active_levels, 1271e28936f6SJohannes Doerfert OMPRTL_omp_get_level, 1272e28936f6SJohannes Doerfert OMPRTL_omp_get_ancestor_thread_num, 1273e28936f6SJohannes Doerfert OMPRTL_omp_get_team_size, 1274e28936f6SJohannes Doerfert OMPRTL_omp_get_active_level, 1275e28936f6SJohannes Doerfert OMPRTL_omp_in_final, 1276e28936f6SJohannes Doerfert OMPRTL_omp_get_proc_bind, 1277e28936f6SJohannes Doerfert OMPRTL_omp_get_num_places, 1278e28936f6SJohannes Doerfert OMPRTL_omp_get_num_procs, 1279e28936f6SJohannes Doerfert OMPRTL_omp_get_place_num, 1280e28936f6SJohannes Doerfert OMPRTL_omp_get_partition_num_places, 1281e28936f6SJohannes Doerfert OMPRTL_omp_get_partition_place_nums}; 1282e28936f6SJohannes Doerfert 1283bc93c2d7SMarek Kurdej // Global-tid is handled separately. 
12849548b74aSJohannes Doerfert SmallSetVector<Value *, 16> GTIdArgs; 12859548b74aSJohannes Doerfert collectGlobalThreadIdArguments(GTIdArgs); 12869548b74aSJohannes Doerfert LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size() 12879548b74aSJohannes Doerfert << " global thread ID arguments\n"); 12889548b74aSJohannes Doerfert 12899548b74aSJohannes Doerfert for (Function *F : SCC) { 1290e28936f6SJohannes Doerfert for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs) 12914e29d256Sserge-sans-paille Changed |= deduplicateRuntimeCalls( 12924e29d256Sserge-sans-paille *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]); 1293e28936f6SJohannes Doerfert 1294e28936f6SJohannes Doerfert // __kmpc_global_thread_num is special as we can replace it with an 1295e28936f6SJohannes Doerfert // argument in enough cases to make it worth trying. 12969548b74aSJohannes Doerfert Value *GTIdArg = nullptr; 12979548b74aSJohannes Doerfert for (Argument &Arg : F->args()) 12989548b74aSJohannes Doerfert if (GTIdArgs.count(&Arg)) { 12999548b74aSJohannes Doerfert GTIdArg = &Arg; 13009548b74aSJohannes Doerfert break; 13019548b74aSJohannes Doerfert } 13029548b74aSJohannes Doerfert Changed |= deduplicateRuntimeCalls( 13037cfd267cSsstefan1 *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg); 13049548b74aSJohannes Doerfert } 13059548b74aSJohannes Doerfert 13069548b74aSJohannes Doerfert return Changed; 13079548b74aSJohannes Doerfert } 13089548b74aSJohannes Doerfert 1309496f8e5bSHamilton Tobon Mosquera /// Tries to hide the latency of runtime calls that involve host to 1310496f8e5bSHamilton Tobon Mosquera /// device memory transfers by splitting them into their "issue" and "wait" 1311496f8e5bSHamilton Tobon Mosquera /// versions. The "issue" is moved upwards as much as possible. The "wait" is 1312496f8e5bSHamilton Tobon Mosquera /// moved downards as much as possible. The "issue" issues the memory transfer 1313496f8e5bSHamilton Tobon Mosquera /// asynchronously, returning a handle. 
The "wait" waits in the returned 1314496f8e5bSHamilton Tobon Mosquera /// handle for the memory transfer to finish. 1315496f8e5bSHamilton Tobon Mosquera bool hideMemTransfersLatency() { 1316496f8e5bSHamilton Tobon Mosquera auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper]; 1317496f8e5bSHamilton Tobon Mosquera bool Changed = false; 1318496f8e5bSHamilton Tobon Mosquera auto SplitMemTransfers = [&](Use &U, Function &Decl) { 1319496f8e5bSHamilton Tobon Mosquera auto *RTCall = getCallIfRegularCall(U, &RFI); 1320496f8e5bSHamilton Tobon Mosquera if (!RTCall) 1321496f8e5bSHamilton Tobon Mosquera return false; 1322496f8e5bSHamilton Tobon Mosquera 13238931add6SHamilton Tobon Mosquera OffloadArray OffloadArrays[3]; 13248931add6SHamilton Tobon Mosquera if (!getValuesInOffloadArrays(*RTCall, OffloadArrays)) 13258931add6SHamilton Tobon Mosquera return false; 13268931add6SHamilton Tobon Mosquera 13278931add6SHamilton Tobon Mosquera LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays)); 13288931add6SHamilton Tobon Mosquera 1329bd2fa181SHamilton Tobon Mosquera // TODO: Check if can be moved upwards. 
1330bd2fa181SHamilton Tobon Mosquera bool WasSplit = false; 1331bd2fa181SHamilton Tobon Mosquera Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall); 1332bd2fa181SHamilton Tobon Mosquera if (WaitMovementPoint) 1333bd2fa181SHamilton Tobon Mosquera WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint); 1334bd2fa181SHamilton Tobon Mosquera 1335496f8e5bSHamilton Tobon Mosquera Changed |= WasSplit; 1336496f8e5bSHamilton Tobon Mosquera return WasSplit; 1337496f8e5bSHamilton Tobon Mosquera }; 1338496f8e5bSHamilton Tobon Mosquera RFI.foreachUse(SCC, SplitMemTransfers); 1339496f8e5bSHamilton Tobon Mosquera 1340496f8e5bSHamilton Tobon Mosquera return Changed; 1341496f8e5bSHamilton Tobon Mosquera } 1342496f8e5bSHamilton Tobon Mosquera 1343a2281419SJoseph Huber void analysisGlobalization() { 13446fc51c9fSJoseph Huber auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; 134582453e75SJoseph Huber 134682453e75SJoseph Huber auto CheckGlobalization = [&](Use &U, Function &Decl) { 1347a2281419SJoseph Huber if (CallInst *CI = getCallIfRegularCall(U, &RFI)) { 134844feacc7SJoseph Huber auto Remark = [&](OptimizationRemarkMissed ORM) { 134944feacc7SJoseph Huber return ORM 1350a2281419SJoseph Huber << "Found thread data sharing on the GPU. " 1351a2281419SJoseph Huber << "Expect degraded performance due to data globalization."; 1352a2281419SJoseph Huber }; 13532c31d5ebSJoseph Huber emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark); 1354a2281419SJoseph Huber } 1355a2281419SJoseph Huber 1356a2281419SJoseph Huber return false; 1357a2281419SJoseph Huber }; 1358a2281419SJoseph Huber 135982453e75SJoseph Huber RFI.foreachUse(SCC, CheckGlobalization); 136082453e75SJoseph Huber } 1361a2281419SJoseph Huber 13628931add6SHamilton Tobon Mosquera /// Maps the values stored in the offload arrays passed as arguments to 13638931add6SHamilton Tobon Mosquera /// \p RuntimeCall into the offload arrays in \p OAs. 
13648931add6SHamilton Tobon Mosquera bool getValuesInOffloadArrays(CallInst &RuntimeCall, 13658931add6SHamilton Tobon Mosquera MutableArrayRef<OffloadArray> OAs) { 13668931add6SHamilton Tobon Mosquera assert(OAs.size() == 3 && "Need space for three offload arrays!"); 13678931add6SHamilton Tobon Mosquera 13688931add6SHamilton Tobon Mosquera // A runtime call that involves memory offloading looks something like: 13698931add6SHamilton Tobon Mosquera // call void @__tgt_target_data_begin_mapper(arg0, arg1, 13708931add6SHamilton Tobon Mosquera // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes, 13718931add6SHamilton Tobon Mosquera // ...) 13728931add6SHamilton Tobon Mosquera // So, the idea is to access the allocas that allocate space for these 13738931add6SHamilton Tobon Mosquera // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes. 13748931add6SHamilton Tobon Mosquera // Therefore: 13758931add6SHamilton Tobon Mosquera // i8** %offload_baseptrs. 13761d3d9b9cSHamilton Tobon Mosquera Value *BasePtrsArg = 13771d3d9b9cSHamilton Tobon Mosquera RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum); 13788931add6SHamilton Tobon Mosquera // i8** %offload_ptrs. 13791d3d9b9cSHamilton Tobon Mosquera Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum); 13808931add6SHamilton Tobon Mosquera // i8** %offload_sizes. 13811d3d9b9cSHamilton Tobon Mosquera Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum); 13828931add6SHamilton Tobon Mosquera 13838931add6SHamilton Tobon Mosquera // Get values stored in **offload_baseptrs. 
13848931add6SHamilton Tobon Mosquera auto *V = getUnderlyingObject(BasePtrsArg); 13858931add6SHamilton Tobon Mosquera if (!isa<AllocaInst>(V)) 13868931add6SHamilton Tobon Mosquera return false; 13878931add6SHamilton Tobon Mosquera auto *BasePtrsArray = cast<AllocaInst>(V); 13888931add6SHamilton Tobon Mosquera if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall)) 13898931add6SHamilton Tobon Mosquera return false; 13908931add6SHamilton Tobon Mosquera 13918931add6SHamilton Tobon Mosquera // Get values stored in **offload_baseptrs. 13928931add6SHamilton Tobon Mosquera V = getUnderlyingObject(PtrsArg); 13938931add6SHamilton Tobon Mosquera if (!isa<AllocaInst>(V)) 13948931add6SHamilton Tobon Mosquera return false; 13958931add6SHamilton Tobon Mosquera auto *PtrsArray = cast<AllocaInst>(V); 13968931add6SHamilton Tobon Mosquera if (!OAs[1].initialize(*PtrsArray, RuntimeCall)) 13978931add6SHamilton Tobon Mosquera return false; 13988931add6SHamilton Tobon Mosquera 13998931add6SHamilton Tobon Mosquera // Get values stored in **offload_sizes. 14008931add6SHamilton Tobon Mosquera V = getUnderlyingObject(SizesArg); 14018931add6SHamilton Tobon Mosquera // If it's a [constant] global array don't analyze it. 14028931add6SHamilton Tobon Mosquera if (isa<GlobalValue>(V)) 14038931add6SHamilton Tobon Mosquera return isa<Constant>(V); 14048931add6SHamilton Tobon Mosquera if (!isa<AllocaInst>(V)) 14058931add6SHamilton Tobon Mosquera return false; 14068931add6SHamilton Tobon Mosquera 14078931add6SHamilton Tobon Mosquera auto *SizesArray = cast<AllocaInst>(V); 14088931add6SHamilton Tobon Mosquera if (!OAs[2].initialize(*SizesArray, RuntimeCall)) 14098931add6SHamilton Tobon Mosquera return false; 14108931add6SHamilton Tobon Mosquera 14118931add6SHamilton Tobon Mosquera return true; 14128931add6SHamilton Tobon Mosquera } 14138931add6SHamilton Tobon Mosquera 14148931add6SHamilton Tobon Mosquera /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG. 
14158931add6SHamilton Tobon Mosquera /// For now this is a way to test that the function getValuesInOffloadArrays 14168931add6SHamilton Tobon Mosquera /// is working properly. 14178931add6SHamilton Tobon Mosquera /// TODO: Move this to a unittest when unittests are available for OpenMPOpt. 14188931add6SHamilton Tobon Mosquera void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) { 14198931add6SHamilton Tobon Mosquera assert(OAs.size() == 3 && "There are three offload arrays to debug!"); 14208931add6SHamilton Tobon Mosquera 14218931add6SHamilton Tobon Mosquera LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n"); 14228931add6SHamilton Tobon Mosquera std::string ValuesStr; 14238931add6SHamilton Tobon Mosquera raw_string_ostream Printer(ValuesStr); 14248931add6SHamilton Tobon Mosquera std::string Separator = " --- "; 14258931add6SHamilton Tobon Mosquera 14268931add6SHamilton Tobon Mosquera for (auto *BP : OAs[0].StoredValues) { 14278931add6SHamilton Tobon Mosquera BP->print(Printer); 14288931add6SHamilton Tobon Mosquera Printer << Separator; 14298931add6SHamilton Tobon Mosquera } 14308931add6SHamilton Tobon Mosquera LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n"); 14318931add6SHamilton Tobon Mosquera ValuesStr.clear(); 14328931add6SHamilton Tobon Mosquera 14338931add6SHamilton Tobon Mosquera for (auto *P : OAs[1].StoredValues) { 14348931add6SHamilton Tobon Mosquera P->print(Printer); 14358931add6SHamilton Tobon Mosquera Printer << Separator; 14368931add6SHamilton Tobon Mosquera } 14378931add6SHamilton Tobon Mosquera LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n"); 14388931add6SHamilton Tobon Mosquera ValuesStr.clear(); 14398931add6SHamilton Tobon Mosquera 14408931add6SHamilton Tobon Mosquera for (auto *S : OAs[2].StoredValues) { 14418931add6SHamilton Tobon Mosquera S->print(Printer); 14428931add6SHamilton Tobon Mosquera Printer << Separator; 14438931add6SHamilton Tobon Mosquera } 14448931add6SHamilton 
Tobon Mosquera LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n"); 14458931add6SHamilton Tobon Mosquera } 14468931add6SHamilton Tobon Mosquera 1447bd2fa181SHamilton Tobon Mosquera /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be 1448bd2fa181SHamilton Tobon Mosquera /// moved. Returns nullptr if the movement is not possible, or not worth it. 1449bd2fa181SHamilton Tobon Mosquera Instruction *canBeMovedDownwards(CallInst &RuntimeCall) { 1450bd2fa181SHamilton Tobon Mosquera // FIXME: This traverses only the BasicBlock where RuntimeCall is. 1451bd2fa181SHamilton Tobon Mosquera // Make it traverse the CFG. 1452bd2fa181SHamilton Tobon Mosquera 1453bd2fa181SHamilton Tobon Mosquera Instruction *CurrentI = &RuntimeCall; 1454bd2fa181SHamilton Tobon Mosquera bool IsWorthIt = false; 1455bd2fa181SHamilton Tobon Mosquera while ((CurrentI = CurrentI->getNextNode())) { 1456bd2fa181SHamilton Tobon Mosquera 1457bd2fa181SHamilton Tobon Mosquera // TODO: Once we detect the regions to be offloaded we should use the 1458bd2fa181SHamilton Tobon Mosquera // alias analysis manager to check if CurrentI may modify one of 1459bd2fa181SHamilton Tobon Mosquera // the offloaded regions. 1460bd2fa181SHamilton Tobon Mosquera if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) { 1461bd2fa181SHamilton Tobon Mosquera if (IsWorthIt) 1462bd2fa181SHamilton Tobon Mosquera return CurrentI; 1463bd2fa181SHamilton Tobon Mosquera 1464bd2fa181SHamilton Tobon Mosquera return nullptr; 1465bd2fa181SHamilton Tobon Mosquera } 1466bd2fa181SHamilton Tobon Mosquera 1467bd2fa181SHamilton Tobon Mosquera // FIXME: For now if we move it over anything without side effect 1468bd2fa181SHamilton Tobon Mosquera // is worth it. 1469bd2fa181SHamilton Tobon Mosquera IsWorthIt = true; 1470bd2fa181SHamilton Tobon Mosquera } 1471bd2fa181SHamilton Tobon Mosquera 1472bd2fa181SHamilton Tobon Mosquera // Return end of BasicBlock. 
1473bd2fa181SHamilton Tobon Mosquera return RuntimeCall.getParent()->getTerminator(); 1474bd2fa181SHamilton Tobon Mosquera } 1475bd2fa181SHamilton Tobon Mosquera 1476496f8e5bSHamilton Tobon Mosquera /// Splits \p RuntimeCall into its "issue" and "wait" counterparts. 1477bd2fa181SHamilton Tobon Mosquera bool splitTargetDataBeginRTC(CallInst &RuntimeCall, 1478bd2fa181SHamilton Tobon Mosquera Instruction &WaitMovementPoint) { 1479bd31abc1SHamilton Tobon Mosquera // Create stack allocated handle (__tgt_async_info) at the beginning of the 1480bd31abc1SHamilton Tobon Mosquera // function. Used for storing information of the async transfer, allowing to 1481bd31abc1SHamilton Tobon Mosquera // wait on it later. 1482496f8e5bSHamilton Tobon Mosquera auto &IRBuilder = OMPInfoCache.OMPBuilder; 1483bd31abc1SHamilton Tobon Mosquera auto *F = RuntimeCall.getCaller(); 1484bd31abc1SHamilton Tobon Mosquera Instruction *FirstInst = &(F->getEntryBlock().front()); 1485bd31abc1SHamilton Tobon Mosquera AllocaInst *Handle = new AllocaInst( 1486bd31abc1SHamilton Tobon Mosquera IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst); 1487bd31abc1SHamilton Tobon Mosquera 1488496f8e5bSHamilton Tobon Mosquera // Add "issue" runtime call declaration: 1489496f8e5bSHamilton Tobon Mosquera // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32, 1490496f8e5bSHamilton Tobon Mosquera // i8**, i8**, i64*, i64*) 1491496f8e5bSHamilton Tobon Mosquera FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction( 1492496f8e5bSHamilton Tobon Mosquera M, OMPRTL___tgt_target_data_begin_mapper_issue); 1493496f8e5bSHamilton Tobon Mosquera 1494496f8e5bSHamilton Tobon Mosquera // Change RuntimeCall call site for its asynchronous version. 
149597e55cfeSJoseph Huber SmallVector<Value *, 16> Args; 1496bd2fa181SHamilton Tobon Mosquera for (auto &Arg : RuntimeCall.args()) 1497496f8e5bSHamilton Tobon Mosquera Args.push_back(Arg.get()); 1498bd31abc1SHamilton Tobon Mosquera Args.push_back(Handle); 1499496f8e5bSHamilton Tobon Mosquera 1500496f8e5bSHamilton Tobon Mosquera CallInst *IssueCallsite = 1501bd31abc1SHamilton Tobon Mosquera CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall); 1502bd2fa181SHamilton Tobon Mosquera RuntimeCall.eraseFromParent(); 1503496f8e5bSHamilton Tobon Mosquera 1504496f8e5bSHamilton Tobon Mosquera // Add "wait" runtime call declaration: 1505496f8e5bSHamilton Tobon Mosquera // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info) 1506496f8e5bSHamilton Tobon Mosquera FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction( 1507496f8e5bSHamilton Tobon Mosquera M, OMPRTL___tgt_target_data_begin_mapper_wait); 1508496f8e5bSHamilton Tobon Mosquera 1509496f8e5bSHamilton Tobon Mosquera Value *WaitParams[2] = { 1510da8bec47SJoseph Huber IssueCallsite->getArgOperand( 1511da8bec47SJoseph Huber OffloadArray::DeviceIDArgNum), // device_id. 1512bd31abc1SHamilton Tobon Mosquera Handle // handle to wait on. 1513496f8e5bSHamilton Tobon Mosquera }; 1514bd2fa181SHamilton Tobon Mosquera CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint); 1515496f8e5bSHamilton Tobon Mosquera 1516496f8e5bSHamilton Tobon Mosquera return true; 1517496f8e5bSHamilton Tobon Mosquera } 1518496f8e5bSHamilton Tobon Mosquera 1519dc3b5b00SJohannes Doerfert static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent, 1520dc3b5b00SJohannes Doerfert bool GlobalOnly, bool &SingleChoice) { 1521dc3b5b00SJohannes Doerfert if (CurrentIdent == NextIdent) 1522dc3b5b00SJohannes Doerfert return CurrentIdent; 1523dc3b5b00SJohannes Doerfert 1524396b7253SJohannes Doerfert // TODO: Figure out how to actually combine multiple debug locations. 
For 1525dc3b5b00SJohannes Doerfert // now we just keep an existing one if there is a single choice. 1526dc3b5b00SJohannes Doerfert if (!GlobalOnly || isa<GlobalValue>(NextIdent)) { 1527dc3b5b00SJohannes Doerfert SingleChoice = !CurrentIdent; 1528dc3b5b00SJohannes Doerfert return NextIdent; 1529dc3b5b00SJohannes Doerfert } 1530396b7253SJohannes Doerfert return nullptr; 1531396b7253SJohannes Doerfert } 1532396b7253SJohannes Doerfert 1533396b7253SJohannes Doerfert /// Return an `struct ident_t*` value that represents the ones used in the 1534396b7253SJohannes Doerfert /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not 1535396b7253SJohannes Doerfert /// return a local `struct ident_t*`. For now, if we cannot find a suitable 1536396b7253SJohannes Doerfert /// return value we create one from scratch. We also do not yet combine 1537396b7253SJohannes Doerfert /// information, e.g., the source locations, see combinedIdentStruct. 15387cfd267cSsstefan1 Value * 15397cfd267cSsstefan1 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI, 15407cfd267cSsstefan1 Function &F, bool GlobalOnly) { 1541dc3b5b00SJohannes Doerfert bool SingleChoice = true; 1542396b7253SJohannes Doerfert Value *Ident = nullptr; 1543396b7253SJohannes Doerfert auto CombineIdentStruct = [&](Use &U, Function &Caller) { 1544396b7253SJohannes Doerfert CallInst *CI = getCallIfRegularCall(U, &RFI); 1545396b7253SJohannes Doerfert if (!CI || &F != &Caller) 1546396b7253SJohannes Doerfert return false; 1547396b7253SJohannes Doerfert Ident = combinedIdentStruct(Ident, CI->getArgOperand(0), 1548dc3b5b00SJohannes Doerfert /* GlobalOnly */ true, SingleChoice); 1549396b7253SJohannes Doerfert return false; 1550396b7253SJohannes Doerfert }; 1551624d34afSJohannes Doerfert RFI.foreachUse(SCC, CombineIdentStruct); 1552396b7253SJohannes Doerfert 1553dc3b5b00SJohannes Doerfert if (!Ident || !SingleChoice) { 1554396b7253SJohannes Doerfert // The IRBuilder uses the insertion block to 
get to the module, this is 1555396b7253SJohannes Doerfert // unfortunate but we work around it for now. 15567cfd267cSsstefan1 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock()) 15577cfd267cSsstefan1 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy( 1558396b7253SJohannes Doerfert &F.getEntryBlock(), F.getEntryBlock().begin())); 1559396b7253SJohannes Doerfert // Create a fallback location if non was found. 1560396b7253SJohannes Doerfert // TODO: Use the debug locations of the calls instead. 15617cfd267cSsstefan1 Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(); 15627cfd267cSsstefan1 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc); 1563396b7253SJohannes Doerfert } 1564396b7253SJohannes Doerfert return Ident; 1565396b7253SJohannes Doerfert } 1566396b7253SJohannes Doerfert 1567b726c557SJohannes Doerfert /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or 15689548b74aSJohannes Doerfert /// \p ReplVal if given. 15697cfd267cSsstefan1 bool deduplicateRuntimeCalls(Function &F, 15707cfd267cSsstefan1 OMPInformationCache::RuntimeFunctionInfo &RFI, 15719548b74aSJohannes Doerfert Value *ReplVal = nullptr) { 15728855fec3SJohannes Doerfert auto *UV = RFI.getUseVector(F); 15738855fec3SJohannes Doerfert if (!UV || UV->size() + (ReplVal != nullptr) < 2) 1574b1fbf438SRoman Lebedev return false; 1575b1fbf438SRoman Lebedev 15767cfd267cSsstefan1 LLVM_DEBUG( 15777cfd267cSsstefan1 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name 15787cfd267cSsstefan1 << (ReplVal ? " with an existing value\n" : "\n") << "\n"); 15797cfd267cSsstefan1 1580ab3da5ddSMichael Liao assert((!ReplVal || (isa<Argument>(ReplVal) && 1581ab3da5ddSMichael Liao cast<Argument>(ReplVal)->getParent() == &F)) && 15829548b74aSJohannes Doerfert "Unexpected replacement value!"); 1583396b7253SJohannes Doerfert 1584396b7253SJohannes Doerfert // TODO: Use dominance to find a good position instead. 
15856aab27baSsstefan1 auto CanBeMoved = [this](CallBase &CB) { 1586396b7253SJohannes Doerfert unsigned NumArgs = CB.getNumArgOperands(); 1587396b7253SJohannes Doerfert if (NumArgs == 0) 1588396b7253SJohannes Doerfert return true; 15896aab27baSsstefan1 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr) 1590396b7253SJohannes Doerfert return false; 1591396b7253SJohannes Doerfert for (unsigned u = 1; u < NumArgs; ++u) 1592396b7253SJohannes Doerfert if (isa<Instruction>(CB.getArgOperand(u))) 1593396b7253SJohannes Doerfert return false; 1594396b7253SJohannes Doerfert return true; 1595396b7253SJohannes Doerfert }; 1596396b7253SJohannes Doerfert 15979548b74aSJohannes Doerfert if (!ReplVal) { 15988855fec3SJohannes Doerfert for (Use *U : *UV) 15999548b74aSJohannes Doerfert if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) { 1600396b7253SJohannes Doerfert if (!CanBeMoved(*CI)) 1601396b7253SJohannes Doerfert continue; 16024d4ea9acSHuber, Joseph 1603f97de4cbSGiorgis Georgakoudis // If the function is a kernel, dedup will move 1604f97de4cbSGiorgis Georgakoudis // the runtime call right after the kernel init callsite. Otherwise, 1605f97de4cbSGiorgis Georgakoudis // it will move it to the beginning of the caller function. 
1606f97de4cbSGiorgis Georgakoudis if (isKernel(F)) { 1607f97de4cbSGiorgis Georgakoudis auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; 1608f97de4cbSGiorgis Georgakoudis auto *KernelInitUV = KernelInitRFI.getUseVector(F); 1609f97de4cbSGiorgis Georgakoudis 1610f97de4cbSGiorgis Georgakoudis if (KernelInitUV->empty()) 1611f97de4cbSGiorgis Georgakoudis continue; 1612f97de4cbSGiorgis Georgakoudis 1613f97de4cbSGiorgis Georgakoudis assert(KernelInitUV->size() == 1 && 1614f97de4cbSGiorgis Georgakoudis "Expected a single __kmpc_target_init in kernel\n"); 1615f97de4cbSGiorgis Georgakoudis 1616f97de4cbSGiorgis Georgakoudis CallInst *KernelInitCI = 1617f97de4cbSGiorgis Georgakoudis getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI); 1618f97de4cbSGiorgis Georgakoudis assert(KernelInitCI && 1619f97de4cbSGiorgis Georgakoudis "Expected a call to __kmpc_target_init in kernel\n"); 1620f97de4cbSGiorgis Georgakoudis 1621f97de4cbSGiorgis Georgakoudis CI->moveAfter(KernelInitCI); 1622f97de4cbSGiorgis Georgakoudis } else 16239548b74aSJohannes Doerfert CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt()); 16249548b74aSJohannes Doerfert ReplVal = CI; 16259548b74aSJohannes Doerfert break; 16269548b74aSJohannes Doerfert } 16279548b74aSJohannes Doerfert if (!ReplVal) 16289548b74aSJohannes Doerfert return false; 16299548b74aSJohannes Doerfert } 16309548b74aSJohannes Doerfert 1631396b7253SJohannes Doerfert // If we use a call as a replacement value we need to make sure the ident is 1632396b7253SJohannes Doerfert // valid at the new location. For now we just pick a global one, either 1633396b7253SJohannes Doerfert // existing and used by one of the calls, or created from scratch. 
1634396b7253SJohannes Doerfert if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) { 1635396b7253SJohannes Doerfert if (CI->getNumArgOperands() > 0 && 16366aab27baSsstefan1 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) { 1637396b7253SJohannes Doerfert Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F, 1638396b7253SJohannes Doerfert /* GlobalOnly */ true); 1639396b7253SJohannes Doerfert CI->setArgOperand(0, Ident); 1640396b7253SJohannes Doerfert } 1641396b7253SJohannes Doerfert } 1642396b7253SJohannes Doerfert 16439548b74aSJohannes Doerfert bool Changed = false; 16449548b74aSJohannes Doerfert auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) { 16459548b74aSJohannes Doerfert CallInst *CI = getCallIfRegularCall(U, &RFI); 16469548b74aSJohannes Doerfert if (!CI || CI == ReplVal || &F != &Caller) 16479548b74aSJohannes Doerfert return false; 16489548b74aSJohannes Doerfert assert(CI->getCaller() == &F && "Unexpected call!"); 16494d4ea9acSHuber, Joseph 16504d4ea9acSHuber, Joseph auto Remark = [&](OptimizationRemark OR) { 16514d4ea9acSHuber, Joseph return OR << "OpenMP runtime call " 1652eef6601bSJoseph Huber << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated."; 16534d4ea9acSHuber, Joseph }; 1654eef6601bSJoseph Huber if (CI->getDebugLoc()) 16552c31d5ebSJoseph Huber emitRemark<OptimizationRemark>(CI, "OMP170", Remark); 1656eef6601bSJoseph Huber else 16572c31d5ebSJoseph Huber emitRemark<OptimizationRemark>(&F, "OMP170", Remark); 16584d4ea9acSHuber, Joseph 16599548b74aSJohannes Doerfert CGUpdater.removeCallSite(*CI); 16609548b74aSJohannes Doerfert CI->replaceAllUsesWith(ReplVal); 16619548b74aSJohannes Doerfert CI->eraseFromParent(); 16629548b74aSJohannes Doerfert ++NumOpenMPRuntimeCallsDeduplicated; 16639548b74aSJohannes Doerfert Changed = true; 16649548b74aSJohannes Doerfert return true; 16659548b74aSJohannes Doerfert }; 1666624d34afSJohannes Doerfert RFI.foreachUse(SCC, ReplaceAndDeleteCB); 16679548b74aSJohannes Doerfert 
16689548b74aSJohannes Doerfert return Changed; 16699548b74aSJohannes Doerfert } 16709548b74aSJohannes Doerfert 16719548b74aSJohannes Doerfert /// Collect arguments that represent the global thread id in \p GTIdArgs. 16729548b74aSJohannes Doerfert void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> >IdArgs) { 16739548b74aSJohannes Doerfert // TODO: Below we basically perform a fixpoint iteration with a pessimistic 16749548b74aSJohannes Doerfert // initialization. We could define an AbstractAttribute instead and 16759548b74aSJohannes Doerfert // run the Attributor here once it can be run as an SCC pass. 16769548b74aSJohannes Doerfert 16779548b74aSJohannes Doerfert // Helper to check the argument \p ArgNo at all call sites of \p F for 16789548b74aSJohannes Doerfert // a GTId. 16799548b74aSJohannes Doerfert auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) { 16809548b74aSJohannes Doerfert if (!F.hasLocalLinkage()) 16819548b74aSJohannes Doerfert return false; 16829548b74aSJohannes Doerfert for (Use &U : F.uses()) { 16839548b74aSJohannes Doerfert if (CallInst *CI = getCallIfRegularCall(U)) { 16849548b74aSJohannes Doerfert Value *ArgOp = CI->getArgOperand(ArgNo); 16859548b74aSJohannes Doerfert if (CI == &RefCI || GTIdArgs.count(ArgOp) || 16867cfd267cSsstefan1 getCallIfRegularCall( 16877cfd267cSsstefan1 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num])) 16889548b74aSJohannes Doerfert continue; 16899548b74aSJohannes Doerfert } 16909548b74aSJohannes Doerfert return false; 16919548b74aSJohannes Doerfert } 16929548b74aSJohannes Doerfert return true; 16939548b74aSJohannes Doerfert }; 16949548b74aSJohannes Doerfert 16959548b74aSJohannes Doerfert // Helper to identify uses of a GTId as GTId arguments. 
16969548b74aSJohannes Doerfert auto AddUserArgs = [&](Value >Id) { 16979548b74aSJohannes Doerfert for (Use &U : GTId.uses()) 16989548b74aSJohannes Doerfert if (CallInst *CI = dyn_cast<CallInst>(U.getUser())) 16999548b74aSJohannes Doerfert if (CI->isArgOperand(&U)) 17009548b74aSJohannes Doerfert if (Function *Callee = CI->getCalledFunction()) 17019548b74aSJohannes Doerfert if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI)) 17029548b74aSJohannes Doerfert GTIdArgs.insert(Callee->getArg(U.getOperandNo())); 17039548b74aSJohannes Doerfert }; 17049548b74aSJohannes Doerfert 17059548b74aSJohannes Doerfert // The argument users of __kmpc_global_thread_num calls are GTIds. 17067cfd267cSsstefan1 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI = 17077cfd267cSsstefan1 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]; 17087cfd267cSsstefan1 1709624d34afSJohannes Doerfert GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) { 17108855fec3SJohannes Doerfert if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI)) 17119548b74aSJohannes Doerfert AddUserArgs(*CI); 17128855fec3SJohannes Doerfert return false; 17138855fec3SJohannes Doerfert }); 17149548b74aSJohannes Doerfert 17159548b74aSJohannes Doerfert // Transitively search for more arguments by looking at the users of the 17169548b74aSJohannes Doerfert // ones we know already. During the search the GTIdArgs vector is extended 17179548b74aSJohannes Doerfert // so we cannot cache the size nor can we use a range based for. 17189548b74aSJohannes Doerfert for (unsigned u = 0; u < GTIdArgs.size(); ++u) 17199548b74aSJohannes Doerfert AddUserArgs(*GTIdArgs[u]); 17209548b74aSJohannes Doerfert } 17219548b74aSJohannes Doerfert 17225b0581aeSJohannes Doerfert /// Kernel (=GPU) optimizations and utility functions 17235b0581aeSJohannes Doerfert /// 17245b0581aeSJohannes Doerfert ///{{ 17255b0581aeSJohannes Doerfert 17265b0581aeSJohannes Doerfert /// Check if \p F is a kernel, hence entry point for target offloading. 
17275b0581aeSJohannes Doerfert bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); } 17285b0581aeSJohannes Doerfert 17295b0581aeSJohannes Doerfert /// Cache to remember the unique kernel for a function. 17305b0581aeSJohannes Doerfert DenseMap<Function *, Optional<Kernel>> UniqueKernelMap; 17315b0581aeSJohannes Doerfert 17325b0581aeSJohannes Doerfert /// Find the unique kernel that will execute \p F, if any. 17335b0581aeSJohannes Doerfert Kernel getUniqueKernelFor(Function &F); 17345b0581aeSJohannes Doerfert 17355b0581aeSJohannes Doerfert /// Find the unique kernel that will execute \p I, if any. 17365b0581aeSJohannes Doerfert Kernel getUniqueKernelFor(Instruction &I) { 17375b0581aeSJohannes Doerfert return getUniqueKernelFor(*I.getFunction()); 17385b0581aeSJohannes Doerfert } 17395b0581aeSJohannes Doerfert 17405b0581aeSJohannes Doerfert /// Rewrite the device (=GPU) code state machine create in non-SPMD mode in 17415b0581aeSJohannes Doerfert /// the cases we can avoid taking the address of a function. 17425b0581aeSJohannes Doerfert bool rewriteDeviceCodeStateMachine(); 17435b0581aeSJohannes Doerfert 17445b0581aeSJohannes Doerfert /// 17455b0581aeSJohannes Doerfert ///}} 17465b0581aeSJohannes Doerfert 17474d4ea9acSHuber, Joseph /// Emit a remark generically 17484d4ea9acSHuber, Joseph /// 17494d4ea9acSHuber, Joseph /// This template function can be used to generically emit a remark. 
The 17504d4ea9acSHuber, Joseph /// RemarkKind should be one of the following: 17514d4ea9acSHuber, Joseph /// - OptimizationRemark to indicate a successful optimization attempt 17524d4ea9acSHuber, Joseph /// - OptimizationRemarkMissed to report a failed optimization attempt 17534d4ea9acSHuber, Joseph /// - OptimizationRemarkAnalysis to provide additional information about an 17544d4ea9acSHuber, Joseph /// optimization attempt 17554d4ea9acSHuber, Joseph /// 17564d4ea9acSHuber, Joseph /// The remark is built using a callback function provided by the caller that 17574d4ea9acSHuber, Joseph /// takes a RemarkKind as input and returns a RemarkKind. 17582db182ffSJoseph Huber template <typename RemarkKind, typename RemarkCallBack> 17592db182ffSJoseph Huber void emitRemark(Instruction *I, StringRef RemarkName, 1760e8039ad4SJohannes Doerfert RemarkCallBack &&RemarkCB) const { 17612db182ffSJoseph Huber Function *F = I->getParent()->getParent(); 17624d4ea9acSHuber, Joseph auto &ORE = OREGetter(F); 17634d4ea9acSHuber, Joseph 17642c31d5ebSJoseph Huber if (RemarkName.startswith("OMP")) 17652c31d5ebSJoseph Huber ORE.emit([&]() { 17662c31d5ebSJoseph Huber return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)) 17672c31d5ebSJoseph Huber << " [" << RemarkName << "]"; 17682c31d5ebSJoseph Huber }); 17692c31d5ebSJoseph Huber else 17702c31d5ebSJoseph Huber ORE.emit( 17712c31d5ebSJoseph Huber [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); }); 17724d4ea9acSHuber, Joseph } 17734d4ea9acSHuber, Joseph 17742db182ffSJoseph Huber /// Emit a remark on a function. 
17752db182ffSJoseph Huber template <typename RemarkKind, typename RemarkCallBack> 17762db182ffSJoseph Huber void emitRemark(Function *F, StringRef RemarkName, 17772db182ffSJoseph Huber RemarkCallBack &&RemarkCB) const { 17780f426935Ssstefan1 auto &ORE = OREGetter(F); 17790f426935Ssstefan1 17802c31d5ebSJoseph Huber if (RemarkName.startswith("OMP")) 17812c31d5ebSJoseph Huber ORE.emit([&]() { 17822c31d5ebSJoseph Huber return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)) 17832c31d5ebSJoseph Huber << " [" << RemarkName << "]"; 17842c31d5ebSJoseph Huber }); 17852c31d5ebSJoseph Huber else 17862c31d5ebSJoseph Huber ORE.emit( 17872c31d5ebSJoseph Huber [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); }); 17880f426935Ssstefan1 } 17890f426935Ssstefan1 179058725c12SJoseph Huber /// RAII struct to temporarily change an RTL function's linkage to external. 179158725c12SJoseph Huber /// This prevents it from being mistakenly removed by other optimizations. 179258725c12SJoseph Huber struct ExternalizationRAII { 179358725c12SJoseph Huber ExternalizationRAII(OMPInformationCache &OMPInfoCache, 179458725c12SJoseph Huber RuntimeFunction RFKind) 1795e757a3b0SJoseph Huber : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) { 179658725c12SJoseph Huber if (!Declaration) 179758725c12SJoseph Huber return; 179858725c12SJoseph Huber 179958725c12SJoseph Huber LinkageType = Declaration->getLinkage(); 180058725c12SJoseph Huber Declaration->setLinkage(GlobalValue::ExternalLinkage); 180158725c12SJoseph Huber } 180258725c12SJoseph Huber 180358725c12SJoseph Huber ~ExternalizationRAII() { 180458725c12SJoseph Huber if (!Declaration) 180558725c12SJoseph Huber return; 180658725c12SJoseph Huber 180758725c12SJoseph Huber Declaration->setLinkage(LinkageType); 180858725c12SJoseph Huber } 180958725c12SJoseph Huber 181058725c12SJoseph Huber Function *Declaration; 181158725c12SJoseph Huber GlobalValue::LinkageTypes LinkageType; 181258725c12SJoseph Huber }; 181358725c12SJoseph Huber 
/// The underlying module.
Module &M;

/// The SCC we are operating on.
SmallVectorImpl<Function *> &SCC;

/// Call graph updater; call sites are removed through it before they are
/// erased (e.g. during runtime call deduplication).
CallGraphUpdater &CGUpdater;

/// Callback to get an OptimizationRemarkEmitter from a Function *
OptimizationRemarkGetter OREGetter;

/// OpenMP-specific information cache. Also used for Attributor runs.
OMPInformationCache &OMPInfoCache;

/// Attributor instance.
Attributor &A;

/// Helper function to run Attributor on SCC.
///
/// \returns true if the Attributor run reported a change; false if it made
///          no change or if the SCC is empty.
bool runAttributor(bool IsModulePass) {
  if (SCC.empty())
    return false;

  // Temporarily make these function have external linkage so the Attributor
  // doesn't remove them when we try to look them up later.
  // The original linkage is restored when the guards go out of scope at the
  // end of this function.
  ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
  ExternalizationRAII EndParallel(OMPInfoCache,
                                  OMPRTL___kmpc_kernel_end_parallel);
  ExternalizationRAII BarrierSPMD(OMPInfoCache,
                                  OMPRTL___kmpc_barrier_simple_spmd);

  registerAAs(IsModulePass);

  ChangeStatus Changed = A.run();

  LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
                    << " functions, result: " << Changed << ".\n");

  return Changed == ChangeStatus::CHANGED;
}

void registerFoldRuntimeCall(RuntimeFunction RF);

/// Populate the Attributor with abstract attribute opportunities in the
/// function.
void registerAAs(bool IsModulePass);
};

/// Determine the unique kernel that (transitively) executes \p F.
///
/// \returns \p F itself if it is a kernel. Returns nullptr if \p F is outside
///          the module slice, is not local (callers may be unknown), or is
///          reached from zero or several kernels, or through uses other than
///          direct calls, equality compares, and __kmpc_parallel_51 call
///          operands. The result is memoized in UniqueKernelMap.
Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
  if (!OMPInfoCache.ModuleSlice.count(&F))
    return nullptr;

  // Use a scope to keep the lifetime of the CachedKernel short.
  // The recursive queries further below may insert into UniqueKernelMap,
  // which can invalidate references into the DenseMap.
  {
    Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
    if (CachedKernel)
      return *CachedKernel;

    // TODO: We should use an AA to create an (optimistic and callback
    //       call-aware) call graph. For now we stick to simple patterns that
    //       are less powerful, basically the worst fixpoint.
    if (isKernel(F)) {
      CachedKernel = Kernel(&F);
      return *CachedKernel;
    }

    // Seed the cache with "no unique kernel" before the uses are inspected.
    // A recursive query for F (e.g. through self-recursion) then returns
    // nullptr instead of recursing forever.
    CachedKernel = nullptr;
    if (!F.hasLocalLinkage()) {

      // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA << "Potentially unknown OpenMP target region caller.";
      };
      emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);

      return nullptr;
    }
  }

  // Map a single use of F to the unique kernel of the function containing
  // the user, or nullptr for uses that are not understood.
  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
    if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
      // Allow use in equality comparisons.
      if (Cmp->isEquality())
        return getUniqueKernelFor(*Cmp);
      return nullptr;
    }
    if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
      // Allow direct calls.
      if (CB->isCallee(&U))
        return getUniqueKernelFor(*CB);

      OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
          OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
      // Allow the use in __kmpc_parallel_51 calls.
      if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
        return getUniqueKernelFor(*CB);
      return nullptr;
    }
    // Disallow every other use.
    return nullptr;
  };

  // TODO: In the future we want to track more than just a unique kernel.
  // A nullptr entry (unknown use) makes the set size != 1, so any unknown
  // use, or two different kernels, yields "no unique kernel".
  SmallPtrSet<Kernel, 2> PotentialKernels;
  OMPInformationCache::foreachUse(F, [&](const Use &U) {
    PotentialKernels.insert(GetUniqueKernelForUse(U));
  });

  Kernel K = nullptr;
  if (PotentialKernels.size() == 1)
    K = *PotentialKernels.begin();

  // Cache the result.
  UniqueKernelMap[&F] = K;

  return K;
}

/// For each function in the SCC used as the parallel-region wrapper of a
/// __kmpc_parallel_51 call, replace its state-machine uses with the address
/// of a fresh private "<name>.ID" global. These are equality compares plus
/// the one __kmpc_parallel_51 operand. The rewrite happens only if the
/// function has exactly one direct call, no other uses, and a unique kernel.
///
/// \returns true if any function's uses were rewritten.
bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
  OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
      OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];

  bool Changed = false;
  if (!KernelParallelRFI)
    return Changed;

  // If we have disabled state machine changes, exit
  if (DisableOpenMPOptStateMachineRewrite)
    return Changed;

  for (Function *F : SCC) {

    // Check if the function is a use in a __kmpc_parallel_51 call at
    // all.
    bool UnknownUse = false;
    bool KernelParallelUse = false;
    unsigned NumDirectCalls = 0;

    // Classify every use of F: direct calls are counted, while equality-
    // compare users and the first __kmpc_parallel_51 wrapper operand are
    // recorded for replacement. Anything else marks F as having an unknown
    // use.
    SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
    OMPInformationCache::foreachUse(*F, [&](Use &U) {
      if (auto *CB = dyn_cast<CallBase>(U.getUser()))
        if (CB->isCallee(&U)) {
          ++NumDirectCalls;
          return;
        }

      if (isa<ICmpInst>(U.getUser())) {
        ToBeReplacedStateMachineUses.push_back(&U);
        return;
      }

      // Find wrapper functions that represent parallel kernels.
      // NOTE(review): operand 6 is assumed to be the wrapper-function
      // parameter of __kmpc_parallel_51; confirm against the runtime
      // signature.
      CallInst *CI =
          OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
      const unsigned int WrapperFunctionArgNo = 6;
      if (!KernelParallelUse && CI &&
          CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
        KernelParallelUse = true;
        ToBeReplacedStateMachineUses.push_back(&U);
        return;
      }
      UnknownUse = true;
    });

    // Do not emit a remark if we haven't seen a __kmpc_parallel_51
    // use.
    if (!KernelParallelUse)
      continue;

    // If this ever hits, we should investigate.
    // TODO: Checking the number of uses is not a necessary restriction and
    // should be lifted.
    if (UnknownUse || NumDirectCalls != 1 ||
        ToBeReplacedStateMachineUses.size() > 2) {
      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA << "Parallel region is used in "
                   << (UnknownUse ? "unknown" : "unexpected")
                   << " ways. Will not attempt to rewrite the state machine.";
      };
      emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
      continue;
    }

    // Even if we have __kmpc_parallel_51 calls, we (for now) give
    // up if the function is not called from a unique kernel.
    Kernel K = getUniqueKernelFor(*F);
    if (!K) {
      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA << "Parallel region is not called from a unique kernel. "
                      "Will not attempt to rewrite the state machine.";
      };
      emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
      continue;
    }

    // We now know F is a parallel body function called only from the kernel K.
    // We also identified the state machine uses in which we replace the
    // function pointer by a new global symbol for identification purposes. This
    // ensures only direct calls to the function are left.

    Module &M = *F->getParent();
    Type *Int8Ty = Type::getInt8Ty(M.getContext());

    auto *ID = new GlobalVariable(
        M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
        UndefValue::get(Int8Ty), F->getName() + ".ID");

    // Bitcast keeps each replaced operand's original type.
    for (Use *U : ToBeReplacedStateMachineUses)
      U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));

    ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;

    Changed = true;
  }

  return Changed;
}

/// Abstract Attribute for tracking ICV values.
2034b8235d2bSsstefan1 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> { 2035b8235d2bSsstefan1 using Base = StateWrapper<BooleanState, AbstractAttribute>; 2036b8235d2bSsstefan1 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {} 2037b8235d2bSsstefan1 20385dfd7cc4Ssstefan1 void initialize(Attributor &A) override { 20395dfd7cc4Ssstefan1 Function *F = getAnchorScope(); 20405dfd7cc4Ssstefan1 if (!F || !A.isFunctionIPOAmendable(*F)) 20415dfd7cc4Ssstefan1 indicatePessimisticFixpoint(); 20425dfd7cc4Ssstefan1 } 20435dfd7cc4Ssstefan1 2044b8235d2bSsstefan1 /// Returns true if value is assumed to be tracked. 2045b8235d2bSsstefan1 bool isAssumedTracked() const { return getAssumed(); } 2046b8235d2bSsstefan1 2047b8235d2bSsstefan1 /// Returns true if value is known to be tracked. 2048b8235d2bSsstefan1 bool isKnownTracked() const { return getAssumed(); } 2049b8235d2bSsstefan1 2050b8235d2bSsstefan1 /// Create an abstract attribute biew for the position \p IRP. 2051b8235d2bSsstefan1 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A); 2052b8235d2bSsstefan1 2053b8235d2bSsstefan1 /// Return the value with which \p I can be replaced for specific \p ICV. 20545dfd7cc4Ssstefan1 virtual Optional<Value *> getReplacementValue(InternalControlVar ICV, 20555dfd7cc4Ssstefan1 const Instruction *I, 20565dfd7cc4Ssstefan1 Attributor &A) const { 20575dfd7cc4Ssstefan1 return None; 20585dfd7cc4Ssstefan1 } 20595dfd7cc4Ssstefan1 20605dfd7cc4Ssstefan1 /// Return an assumed unique ICV value if a single candidate is found. If 20615dfd7cc4Ssstefan1 /// there cannot be one, return a nullptr. If it is not clear yet, return the 20625dfd7cc4Ssstefan1 /// Optional::NoneType. 20635dfd7cc4Ssstefan1 virtual Optional<Value *> 20645dfd7cc4Ssstefan1 getUniqueReplacementValue(InternalControlVar ICV) const = 0; 20655dfd7cc4Ssstefan1 20665dfd7cc4Ssstefan1 // Currently only nthreads is being tracked. 
20675dfd7cc4Ssstefan1 // this array will only grow with time. 20685dfd7cc4Ssstefan1 InternalControlVar TrackableICVs[1] = {ICV_nthreads}; 2069b8235d2bSsstefan1 2070b8235d2bSsstefan1 /// See AbstractAttribute::getName() 2071b8235d2bSsstefan1 const std::string getName() const override { return "AAICVTracker"; } 2072b8235d2bSsstefan1 2073233af895SLuofan Chen /// See AbstractAttribute::getIdAddr() 2074233af895SLuofan Chen const char *getIdAddr() const override { return &ID; } 2075233af895SLuofan Chen 2076233af895SLuofan Chen /// This function should return true if the type of the \p AA is AAICVTracker 2077233af895SLuofan Chen static bool classof(const AbstractAttribute *AA) { 2078233af895SLuofan Chen return (AA->getIdAddr() == &ID); 2079233af895SLuofan Chen } 2080233af895SLuofan Chen 2081b8235d2bSsstefan1 static const char ID; 2082b8235d2bSsstefan1 }; 2083b8235d2bSsstefan1 2084b8235d2bSsstefan1 struct AAICVTrackerFunction : public AAICVTracker { 2085b8235d2bSsstefan1 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A) 2086b8235d2bSsstefan1 : AAICVTracker(IRP, A) {} 2087b8235d2bSsstefan1 2088b8235d2bSsstefan1 // FIXME: come up with better string. 20895dfd7cc4Ssstefan1 const std::string getAsStr() const override { return "ICVTrackerFunction"; } 2090b8235d2bSsstefan1 2091b8235d2bSsstefan1 // FIXME: come up with some stats. 2092b8235d2bSsstefan1 void trackStatistics() const override {} 2093b8235d2bSsstefan1 20945dfd7cc4Ssstefan1 /// We don't manifest anything for this AA. 2095b8235d2bSsstefan1 ChangeStatus manifest(Attributor &A) override { 20965dfd7cc4Ssstefan1 return ChangeStatus::UNCHANGED; 2097b8235d2bSsstefan1 } 2098b8235d2bSsstefan1 2099b8235d2bSsstefan1 // Map of ICV to their values at specific program point. 
21005dfd7cc4Ssstefan1 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar, 2101b8235d2bSsstefan1 InternalControlVar::ICV___last> 21025dfd7cc4Ssstefan1 ICVReplacementValuesMap; 2103b8235d2bSsstefan1 2104b8235d2bSsstefan1 ChangeStatus updateImpl(Attributor &A) override { 2105b8235d2bSsstefan1 ChangeStatus HasChanged = ChangeStatus::UNCHANGED; 2106b8235d2bSsstefan1 2107b8235d2bSsstefan1 Function *F = getAnchorScope(); 2108b8235d2bSsstefan1 2109b8235d2bSsstefan1 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2110b8235d2bSsstefan1 2111b8235d2bSsstefan1 for (InternalControlVar ICV : TrackableICVs) { 2112b8235d2bSsstefan1 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter]; 2113b8235d2bSsstefan1 21145dfd7cc4Ssstefan1 auto &ValuesMap = ICVReplacementValuesMap[ICV]; 2115b8235d2bSsstefan1 auto TrackValues = [&](Use &U, Function &) { 2116b8235d2bSsstefan1 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U); 2117b8235d2bSsstefan1 if (!CI) 2118b8235d2bSsstefan1 return false; 2119b8235d2bSsstefan1 2120b8235d2bSsstefan1 // FIXME: handle setters with more that 1 arguments. 2121b8235d2bSsstefan1 /// Track new value. 21225dfd7cc4Ssstefan1 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second) 2123b8235d2bSsstefan1 HasChanged = ChangeStatus::CHANGED; 2124b8235d2bSsstefan1 2125b8235d2bSsstefan1 return false; 2126b8235d2bSsstefan1 }; 2127b8235d2bSsstefan1 21285dfd7cc4Ssstefan1 auto CallCheck = [&](Instruction &I) { 21295dfd7cc4Ssstefan1 Optional<Value *> ReplVal = getValueForCall(A, &I, ICV); 21305dfd7cc4Ssstefan1 if (ReplVal.hasValue() && 21315dfd7cc4Ssstefan1 ValuesMap.insert(std::make_pair(&I, *ReplVal)).second) 21325dfd7cc4Ssstefan1 HasChanged = ChangeStatus::CHANGED; 21335dfd7cc4Ssstefan1 21345dfd7cc4Ssstefan1 return true; 21355dfd7cc4Ssstefan1 }; 21365dfd7cc4Ssstefan1 21375dfd7cc4Ssstefan1 // Track all changes of an ICV. 
2138b8235d2bSsstefan1 SetterRFI.foreachUse(TrackValues, F); 21395dfd7cc4Ssstefan1 2140792aac98SJohannes Doerfert bool UsedAssumedInformation = false; 21415dfd7cc4Ssstefan1 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call}, 2142792aac98SJohannes Doerfert UsedAssumedInformation, 21435dfd7cc4Ssstefan1 /* CheckBBLivenessOnly */ true); 21445dfd7cc4Ssstefan1 21455dfd7cc4Ssstefan1 /// TODO: Figure out a way to avoid adding entry in 21465dfd7cc4Ssstefan1 /// ICVReplacementValuesMap 21475dfd7cc4Ssstefan1 Instruction *Entry = &F->getEntryBlock().front(); 21485dfd7cc4Ssstefan1 if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry)) 21495dfd7cc4Ssstefan1 ValuesMap.insert(std::make_pair(Entry, nullptr)); 2150b8235d2bSsstefan1 } 2151b8235d2bSsstefan1 2152b8235d2bSsstefan1 return HasChanged; 2153b8235d2bSsstefan1 } 2154b8235d2bSsstefan1 21555dfd7cc4Ssstefan1 /// Hepler to check if \p I is a call and get the value for it if it is 21565dfd7cc4Ssstefan1 /// unique. 21575dfd7cc4Ssstefan1 Optional<Value *> getValueForCall(Attributor &A, const Instruction *I, 21585dfd7cc4Ssstefan1 InternalControlVar &ICV) const { 2159b8235d2bSsstefan1 21605dfd7cc4Ssstefan1 const auto *CB = dyn_cast<CallBase>(I); 2161dcaec812SJohannes Doerfert if (!CB || CB->hasFnAttr("no_openmp") || 2162dcaec812SJohannes Doerfert CB->hasFnAttr("no_openmp_routines")) 21635dfd7cc4Ssstefan1 return None; 21645dfd7cc4Ssstefan1 2165b8235d2bSsstefan1 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2166b8235d2bSsstefan1 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter]; 21675dfd7cc4Ssstefan1 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter]; 21685dfd7cc4Ssstefan1 Function *CalledFunction = CB->getCalledFunction(); 2169b8235d2bSsstefan1 21704eef14f9SWei Wang // Indirect call, assume ICV changes. 
21714eef14f9SWei Wang if (CalledFunction == nullptr) 21724eef14f9SWei Wang return nullptr; 21735dfd7cc4Ssstefan1 if (CalledFunction == GetterRFI.Declaration) 21745dfd7cc4Ssstefan1 return None; 21755dfd7cc4Ssstefan1 if (CalledFunction == SetterRFI.Declaration) { 21765dfd7cc4Ssstefan1 if (ICVReplacementValuesMap[ICV].count(I)) 21775dfd7cc4Ssstefan1 return ICVReplacementValuesMap[ICV].lookup(I); 21785dfd7cc4Ssstefan1 21795dfd7cc4Ssstefan1 return nullptr; 21805dfd7cc4Ssstefan1 } 21815dfd7cc4Ssstefan1 21825dfd7cc4Ssstefan1 // Since we don't know, assume it changes the ICV. 21835dfd7cc4Ssstefan1 if (CalledFunction->isDeclaration()) 21845dfd7cc4Ssstefan1 return nullptr; 21855dfd7cc4Ssstefan1 21865b70c12fSJohannes Doerfert const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 21875b70c12fSJohannes Doerfert *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED); 21885dfd7cc4Ssstefan1 21895dfd7cc4Ssstefan1 if (ICVTrackingAA.isAssumedTracked()) 21905dfd7cc4Ssstefan1 return ICVTrackingAA.getUniqueReplacementValue(ICV); 21915dfd7cc4Ssstefan1 21925dfd7cc4Ssstefan1 // If we don't know, assume it changes. 21935dfd7cc4Ssstefan1 return nullptr; 21945dfd7cc4Ssstefan1 } 21955dfd7cc4Ssstefan1 21965dfd7cc4Ssstefan1 // We don't check unique value for a function, so return None. 21975dfd7cc4Ssstefan1 Optional<Value *> 21985dfd7cc4Ssstefan1 getUniqueReplacementValue(InternalControlVar ICV) const override { 21995dfd7cc4Ssstefan1 return None; 22005dfd7cc4Ssstefan1 } 22015dfd7cc4Ssstefan1 22025dfd7cc4Ssstefan1 /// Return the value with which \p I can be replaced for specific \p ICV. 
Optional<Value *> getReplacementValue(InternalControlVar ICV,
                                        const Instruction *I,
                                        Attributor &A) const override {
    // Fast path: \p I itself has a recorded value in the per-ICV map.
    const auto &ValuesMap = ICVReplacementValuesMap[ICV];
    if (ValuesMap.count(I))
      return ValuesMap.lookup(I);

    // Backward walk over the CFG. Each worklist entry is a starting point.
    // The instructions strictly before it in its block are scanned bottom-up.
    SmallVector<const Instruction *, 16> Worklist;
    SmallPtrSet<const Instruction *, 16> Visited;
    Worklist.push_back(I);

    // None means no value has been found on any path yet. A contained nullptr
    // means the ICV value is unknown.
    Optional<Value *> ReplVal;

    while (!Worklist.empty()) {
      const Instruction *CurrInst = Worklist.pop_back_val();
      if (!Visited.insert(CurrInst).second)
        continue;

      const BasicBlock *CurrBB = CurrInst->getParent();

      // Go up and look for all potential setters/calls that might change the
      // ICV.
      while ((CurrInst = CurrInst->getPrevNode())) {
        if (ValuesMap.count(CurrInst)) {
          // lookup() yields a plain Value * (possibly nullptr), so NewReplVal
          // always holds a value here.
          Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
          // Unknown value, track new.
          if (!ReplVal.hasValue()) {
            ReplVal = NewReplVal;
            break;
          }

          // If we found a new value, we can't know the icv value anymore.
          if (NewReplVal.hasValue())
            if (ReplVal != NewReplVal)
              return nullptr;

          // A recorded instruction ends the scan of this block either way.
          break;
        }

        // Not recorded: ask whether this instruction affects the ICV. None
        // means it does not, so keep scanning upwards.
        Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
        if (!NewReplVal.hasValue())
          continue;

        // Unknown value, track new.
        if (!ReplVal.hasValue()) {
          ReplVal = NewReplVal;
          break;
        }

        // NewReplVal holds a value here (checked above).
        // A different value means the ICV is not uniquely known anymore.
        // Unlike the recorded case above, an equal value does not stop the
        // scan of this block.
        if (ReplVal != NewReplVal)
          return nullptr;
      }

      // If we are in the same BB and we have a value, we are done.
      if (CurrBB == I->getParent() && ReplVal.hasValue())
        return ReplVal;

      // Go through all predecessors and add terminators for analysis.
      // NOTE(review): predecessors are enqueued even when a value was already
      // found in this block, and each scan starts at the terminator's previous
      // node, so the terminator itself is never inspected.
      for (const BasicBlock *Pred : predecessors(CurrBB))
        if (const Instruction *Terminator = Pred->getTerminator())
          Worklist.push_back(Terminator);
    }

    return ReplVal;
  }
};

/// Tracks, per ICV, the value the ICV holds at the returns of a function.
struct AAICVTrackerFunctionReturned : AAICVTracker {
  AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
      : AAICVTracker(IRP, A) {}

  // FIXME: come up with better string.
22775dfd7cc4Ssstefan1 const std::string getAsStr() const override { 22785dfd7cc4Ssstefan1 return "ICVTrackerFunctionReturned"; 22795dfd7cc4Ssstefan1 } 22805dfd7cc4Ssstefan1 22815dfd7cc4Ssstefan1 // FIXME: come up with some stats. 22825dfd7cc4Ssstefan1 void trackStatistics() const override {} 22835dfd7cc4Ssstefan1 22845dfd7cc4Ssstefan1 /// We don't manifest anything for this AA. 22855dfd7cc4Ssstefan1 ChangeStatus manifest(Attributor &A) override { 22865dfd7cc4Ssstefan1 return ChangeStatus::UNCHANGED; 22875dfd7cc4Ssstefan1 } 22885dfd7cc4Ssstefan1 22895dfd7cc4Ssstefan1 // Map of ICV to their values at specific program point. 22905dfd7cc4Ssstefan1 EnumeratedArray<Optional<Value *>, InternalControlVar, 22915dfd7cc4Ssstefan1 InternalControlVar::ICV___last> 22925dfd7cc4Ssstefan1 ICVReplacementValuesMap; 22935dfd7cc4Ssstefan1 22945dfd7cc4Ssstefan1 /// Return the value with which \p I can be replaced for specific \p ICV. 22955dfd7cc4Ssstefan1 Optional<Value *> 22965dfd7cc4Ssstefan1 getUniqueReplacementValue(InternalControlVar ICV) const override { 22975dfd7cc4Ssstefan1 return ICVReplacementValuesMap[ICV]; 22985dfd7cc4Ssstefan1 } 22995dfd7cc4Ssstefan1 23005dfd7cc4Ssstefan1 ChangeStatus updateImpl(Attributor &A) override { 23015dfd7cc4Ssstefan1 ChangeStatus Changed = ChangeStatus::UNCHANGED; 23025dfd7cc4Ssstefan1 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 23035b70c12fSJohannes Doerfert *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 23045dfd7cc4Ssstefan1 23055dfd7cc4Ssstefan1 if (!ICVTrackingAA.isAssumedTracked()) 23065dfd7cc4Ssstefan1 return indicatePessimisticFixpoint(); 23075dfd7cc4Ssstefan1 23085dfd7cc4Ssstefan1 for (InternalControlVar ICV : TrackableICVs) { 23095dfd7cc4Ssstefan1 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; 23105dfd7cc4Ssstefan1 Optional<Value *> UniqueICVValue; 23115dfd7cc4Ssstefan1 23125dfd7cc4Ssstefan1 auto CheckReturnInst = [&](Instruction &I) { 23135dfd7cc4Ssstefan1 Optional<Value *> NewReplVal = 
23145dfd7cc4Ssstefan1 ICVTrackingAA.getReplacementValue(ICV, &I, A); 23155dfd7cc4Ssstefan1 23165dfd7cc4Ssstefan1 // If we found a second ICV value there is no unique returned value. 23175dfd7cc4Ssstefan1 if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal) 23185dfd7cc4Ssstefan1 return false; 23195dfd7cc4Ssstefan1 23205dfd7cc4Ssstefan1 UniqueICVValue = NewReplVal; 23215dfd7cc4Ssstefan1 23225dfd7cc4Ssstefan1 return true; 23235dfd7cc4Ssstefan1 }; 23245dfd7cc4Ssstefan1 2325792aac98SJohannes Doerfert bool UsedAssumedInformation = false; 23265dfd7cc4Ssstefan1 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret}, 2327792aac98SJohannes Doerfert UsedAssumedInformation, 23285dfd7cc4Ssstefan1 /* CheckBBLivenessOnly */ true)) 23295dfd7cc4Ssstefan1 UniqueICVValue = nullptr; 23305dfd7cc4Ssstefan1 23315dfd7cc4Ssstefan1 if (UniqueICVValue == ReplVal) 23325dfd7cc4Ssstefan1 continue; 23335dfd7cc4Ssstefan1 23345dfd7cc4Ssstefan1 ReplVal = UniqueICVValue; 23355dfd7cc4Ssstefan1 Changed = ChangeStatus::CHANGED; 23365dfd7cc4Ssstefan1 } 23375dfd7cc4Ssstefan1 23385dfd7cc4Ssstefan1 return Changed; 23395dfd7cc4Ssstefan1 } 23405dfd7cc4Ssstefan1 }; 23415dfd7cc4Ssstefan1 23425dfd7cc4Ssstefan1 struct AAICVTrackerCallSite : AAICVTracker { 23435dfd7cc4Ssstefan1 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A) 23445dfd7cc4Ssstefan1 : AAICVTracker(IRP, A) {} 23455dfd7cc4Ssstefan1 23465dfd7cc4Ssstefan1 void initialize(Attributor &A) override { 23475dfd7cc4Ssstefan1 Function *F = getAnchorScope(); 23485dfd7cc4Ssstefan1 if (!F || !A.isFunctionIPOAmendable(*F)) 23495dfd7cc4Ssstefan1 indicatePessimisticFixpoint(); 23505dfd7cc4Ssstefan1 23515dfd7cc4Ssstefan1 // We only initialize this AA for getters, so we need to know which ICV it 23525dfd7cc4Ssstefan1 // gets. 
23535dfd7cc4Ssstefan1 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 23545dfd7cc4Ssstefan1 for (InternalControlVar ICV : TrackableICVs) { 23555dfd7cc4Ssstefan1 auto ICVInfo = OMPInfoCache.ICVs[ICV]; 23565dfd7cc4Ssstefan1 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter]; 23575dfd7cc4Ssstefan1 if (Getter.Declaration == getAssociatedFunction()) { 23585dfd7cc4Ssstefan1 AssociatedICV = ICVInfo.Kind; 23595dfd7cc4Ssstefan1 return; 23605dfd7cc4Ssstefan1 } 23615dfd7cc4Ssstefan1 } 23625dfd7cc4Ssstefan1 23635dfd7cc4Ssstefan1 /// Unknown ICV. 23645dfd7cc4Ssstefan1 indicatePessimisticFixpoint(); 23655dfd7cc4Ssstefan1 } 23665dfd7cc4Ssstefan1 23675dfd7cc4Ssstefan1 ChangeStatus manifest(Attributor &A) override { 23685dfd7cc4Ssstefan1 if (!ReplVal.hasValue() || !ReplVal.getValue()) 23695dfd7cc4Ssstefan1 return ChangeStatus::UNCHANGED; 23705dfd7cc4Ssstefan1 23715dfd7cc4Ssstefan1 A.changeValueAfterManifest(*getCtxI(), **ReplVal); 23725dfd7cc4Ssstefan1 A.deleteAfterManifest(*getCtxI()); 23735dfd7cc4Ssstefan1 23745dfd7cc4Ssstefan1 return ChangeStatus::CHANGED; 23755dfd7cc4Ssstefan1 } 23765dfd7cc4Ssstefan1 23775dfd7cc4Ssstefan1 // FIXME: come up with better string. 23785dfd7cc4Ssstefan1 const std::string getAsStr() const override { return "ICVTrackerCallSite"; } 23795dfd7cc4Ssstefan1 23805dfd7cc4Ssstefan1 // FIXME: come up with some stats. 23815dfd7cc4Ssstefan1 void trackStatistics() const override {} 23825dfd7cc4Ssstefan1 23835dfd7cc4Ssstefan1 InternalControlVar AssociatedICV; 23845dfd7cc4Ssstefan1 Optional<Value *> ReplVal; 23855dfd7cc4Ssstefan1 23865dfd7cc4Ssstefan1 ChangeStatus updateImpl(Attributor &A) override { 23875dfd7cc4Ssstefan1 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 23885b70c12fSJohannes Doerfert *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 23895dfd7cc4Ssstefan1 23905dfd7cc4Ssstefan1 // We don't have any information, so we assume it changes the ICV. 
23915dfd7cc4Ssstefan1 if (!ICVTrackingAA.isAssumedTracked()) 23925dfd7cc4Ssstefan1 return indicatePessimisticFixpoint(); 23935dfd7cc4Ssstefan1 23945dfd7cc4Ssstefan1 Optional<Value *> NewReplVal = 23955dfd7cc4Ssstefan1 ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A); 23965dfd7cc4Ssstefan1 23975dfd7cc4Ssstefan1 if (ReplVal == NewReplVal) 23985dfd7cc4Ssstefan1 return ChangeStatus::UNCHANGED; 23995dfd7cc4Ssstefan1 24005dfd7cc4Ssstefan1 ReplVal = NewReplVal; 24015dfd7cc4Ssstefan1 return ChangeStatus::CHANGED; 24025dfd7cc4Ssstefan1 } 24035dfd7cc4Ssstefan1 24045dfd7cc4Ssstefan1 // Return the value with which associated value can be replaced for specific 24055dfd7cc4Ssstefan1 // \p ICV. 24065dfd7cc4Ssstefan1 Optional<Value *> 24075dfd7cc4Ssstefan1 getUniqueReplacementValue(InternalControlVar ICV) const override { 24085dfd7cc4Ssstefan1 return ReplVal; 24095dfd7cc4Ssstefan1 } 24105dfd7cc4Ssstefan1 }; 24115dfd7cc4Ssstefan1 24125dfd7cc4Ssstefan1 struct AAICVTrackerCallSiteReturned : AAICVTracker { 24135dfd7cc4Ssstefan1 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A) 24145dfd7cc4Ssstefan1 : AAICVTracker(IRP, A) {} 24155dfd7cc4Ssstefan1 24165dfd7cc4Ssstefan1 // FIXME: come up with better string. 24175dfd7cc4Ssstefan1 const std::string getAsStr() const override { 24185dfd7cc4Ssstefan1 return "ICVTrackerCallSiteReturned"; 24195dfd7cc4Ssstefan1 } 24205dfd7cc4Ssstefan1 24215dfd7cc4Ssstefan1 // FIXME: come up with some stats. 24225dfd7cc4Ssstefan1 void trackStatistics() const override {} 24235dfd7cc4Ssstefan1 24245dfd7cc4Ssstefan1 /// We don't manifest anything for this AA. 24255dfd7cc4Ssstefan1 ChangeStatus manifest(Attributor &A) override { 24265dfd7cc4Ssstefan1 return ChangeStatus::UNCHANGED; 24275dfd7cc4Ssstefan1 } 24285dfd7cc4Ssstefan1 24295dfd7cc4Ssstefan1 // Map of ICV to their values at specific program point. 
24305dfd7cc4Ssstefan1 EnumeratedArray<Optional<Value *>, InternalControlVar, 24315dfd7cc4Ssstefan1 InternalControlVar::ICV___last> 24325dfd7cc4Ssstefan1 ICVReplacementValuesMap; 24335dfd7cc4Ssstefan1 24345dfd7cc4Ssstefan1 /// Return the value with which associated value can be replaced for specific 24355dfd7cc4Ssstefan1 /// \p ICV. 24365dfd7cc4Ssstefan1 Optional<Value *> 24375dfd7cc4Ssstefan1 getUniqueReplacementValue(InternalControlVar ICV) const override { 24385dfd7cc4Ssstefan1 return ICVReplacementValuesMap[ICV]; 24395dfd7cc4Ssstefan1 } 24405dfd7cc4Ssstefan1 24415dfd7cc4Ssstefan1 ChangeStatus updateImpl(Attributor &A) override { 24425dfd7cc4Ssstefan1 ChangeStatus Changed = ChangeStatus::UNCHANGED; 24435dfd7cc4Ssstefan1 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 24445b70c12fSJohannes Doerfert *this, IRPosition::returned(*getAssociatedFunction()), 24455b70c12fSJohannes Doerfert DepClassTy::REQUIRED); 24465dfd7cc4Ssstefan1 24475dfd7cc4Ssstefan1 // We don't have any information, so we assume it changes the ICV. 
24485dfd7cc4Ssstefan1 if (!ICVTrackingAA.isAssumedTracked()) 24495dfd7cc4Ssstefan1 return indicatePessimisticFixpoint(); 24505dfd7cc4Ssstefan1 24515dfd7cc4Ssstefan1 for (InternalControlVar ICV : TrackableICVs) { 24525dfd7cc4Ssstefan1 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; 24535dfd7cc4Ssstefan1 Optional<Value *> NewReplVal = 24545dfd7cc4Ssstefan1 ICVTrackingAA.getUniqueReplacementValue(ICV); 24555dfd7cc4Ssstefan1 24565dfd7cc4Ssstefan1 if (ReplVal == NewReplVal) 24575dfd7cc4Ssstefan1 continue; 24585dfd7cc4Ssstefan1 24595dfd7cc4Ssstefan1 ReplVal = NewReplVal; 24605dfd7cc4Ssstefan1 Changed = ChangeStatus::CHANGED; 24615dfd7cc4Ssstefan1 } 24625dfd7cc4Ssstefan1 return Changed; 24635dfd7cc4Ssstefan1 } 24649548b74aSJohannes Doerfert }; 246518283125SJoseph Huber 246618283125SJoseph Huber struct AAExecutionDomainFunction : public AAExecutionDomain { 246718283125SJoseph Huber AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A) 246818283125SJoseph Huber : AAExecutionDomain(IRP, A) {} 246918283125SJoseph Huber 247018283125SJoseph Huber const std::string getAsStr() const override { 247118283125SJoseph Huber return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) + 247218283125SJoseph Huber "/" + std::to_string(NumBBs) + " BBs thread 0 only."; 247318283125SJoseph Huber } 247418283125SJoseph Huber 247518283125SJoseph Huber /// See AbstractAttribute::trackStatistics(). 
247618283125SJoseph Huber void trackStatistics() const override {} 247718283125SJoseph Huber 247818283125SJoseph Huber void initialize(Attributor &A) override { 247918283125SJoseph Huber Function *F = getAnchorScope(); 248018283125SJoseph Huber for (const auto &BB : *F) 248118283125SJoseph Huber SingleThreadedBBs.insert(&BB); 248218283125SJoseph Huber NumBBs = SingleThreadedBBs.size(); 248318283125SJoseph Huber } 248418283125SJoseph Huber 248518283125SJoseph Huber ChangeStatus manifest(Attributor &A) override { 248618283125SJoseph Huber LLVM_DEBUG({ 248718283125SJoseph Huber for (const BasicBlock *BB : SingleThreadedBBs) 248818283125SJoseph Huber dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " " 248918283125SJoseph Huber << BB->getName() << " is executed by a single thread.\n"; 249018283125SJoseph Huber }); 249118283125SJoseph Huber return ChangeStatus::UNCHANGED; 249218283125SJoseph Huber } 249318283125SJoseph Huber 249418283125SJoseph Huber ChangeStatus updateImpl(Attributor &A) override; 249518283125SJoseph Huber 249618283125SJoseph Huber /// Check if an instruction is executed by a single thread. 24979a23e673SJohannes Doerfert bool isExecutedByInitialThreadOnly(const Instruction &I) const override { 24989a23e673SJohannes Doerfert return isExecutedByInitialThreadOnly(*I.getParent()); 249918283125SJoseph Huber } 250018283125SJoseph Huber 25019a23e673SJohannes Doerfert bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override { 25021cfdcae6SJoseph Huber return isValidState() && SingleThreadedBBs.contains(&BB); 250318283125SJoseph Huber } 250418283125SJoseph Huber 250518283125SJoseph Huber /// Set of basic blocks that are executed by a single thread. 250618283125SJoseph Huber DenseSet<const BasicBlock *> SingleThreadedBBs; 250718283125SJoseph Huber 250818283125SJoseph Huber /// Total number of basic blocks in this function. 
250918283125SJoseph Huber long unsigned NumBBs; 251018283125SJoseph Huber }; 251118283125SJoseph Huber 251218283125SJoseph Huber ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { 251318283125SJoseph Huber Function *F = getAnchorScope(); 251418283125SJoseph Huber ReversePostOrderTraversal<Function *> RPOT(F); 251518283125SJoseph Huber auto NumSingleThreadedBBs = SingleThreadedBBs.size(); 251618283125SJoseph Huber 251718283125SJoseph Huber bool AllCallSitesKnown; 251818283125SJoseph Huber auto PredForCallSite = [&](AbstractCallSite ACS) { 251918283125SJoseph Huber const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>( 252018283125SJoseph Huber *this, IRPosition::function(*ACS.getInstruction()->getFunction()), 252118283125SJoseph Huber DepClassTy::REQUIRED); 25221cfdcae6SJoseph Huber return ACS.isDirectCall() && 25231cfdcae6SJoseph Huber ExecutionDomainAA.isExecutedByInitialThreadOnly( 25249a23e673SJohannes Doerfert *ACS.getInstruction()); 252518283125SJoseph Huber }; 252618283125SJoseph Huber 252718283125SJoseph Huber if (!A.checkForAllCallSites(PredForCallSite, *this, 252818283125SJoseph Huber /* RequiresAllCallSites */ true, 252918283125SJoseph Huber AllCallSitesKnown)) 253018283125SJoseph Huber SingleThreadedBBs.erase(&F->getEntryBlock()); 253118283125SJoseph Huber 2532e2cfbfccSJohannes Doerfert auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2533e2cfbfccSJohannes Doerfert auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; 2534e2cfbfccSJohannes Doerfert 2535e2cfbfccSJohannes Doerfert // Check if the edge into the successor block compares the __kmpc_target_init 2536e2cfbfccSJohannes Doerfert // result with -1. If we are in non-SPMD-mode that signals only the main 2537e2cfbfccSJohannes Doerfert // thread will execute the edge. 
25386fc51c9fSJoseph Huber auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) { 253918283125SJoseph Huber if (!Edge || !Edge->isConditional()) 254018283125SJoseph Huber return false; 254118283125SJoseph Huber if (Edge->getSuccessor(0) != SuccessorBB) 254218283125SJoseph Huber return false; 254318283125SJoseph Huber 254418283125SJoseph Huber auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition()); 254518283125SJoseph Huber if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality()) 254618283125SJoseph Huber return false; 254718283125SJoseph Huber 254818283125SJoseph Huber ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1)); 2549e2cfbfccSJohannes Doerfert if (!C) 255018283125SJoseph Huber return false; 255118283125SJoseph Huber 2552e2cfbfccSJohannes Doerfert // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!) 2553e2cfbfccSJohannes Doerfert if (C->isAllOnesValue()) { 2554e2cfbfccSJohannes Doerfert auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0)); 2555c4b1fe05SJohannes Doerfert CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr; 2556c4b1fe05SJohannes Doerfert if (!CB) 2557e2cfbfccSJohannes Doerfert return false; 2558e2cfbfccSJohannes Doerfert const int InitIsSPMDArgNo = 1; 2559e2cfbfccSJohannes Doerfert auto *IsSPMDModeCI = 2560e2cfbfccSJohannes Doerfert dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo)); 2561e2cfbfccSJohannes Doerfert return IsSPMDModeCI && IsSPMDModeCI->isZero(); 2562e2cfbfccSJohannes Doerfert } 256318283125SJoseph Huber 256418283125SJoseph Huber return false; 256518283125SJoseph Huber }; 256618283125SJoseph Huber 256718283125SJoseph Huber // Merge all the predecessor states into the current basic block. A basic 256818283125SJoseph Huber // block is executed by a single thread if all of its predecessors are. 
256918283125SJoseph Huber auto MergePredecessorStates = [&](BasicBlock *BB) { 257018283125SJoseph Huber if (pred_begin(BB) == pred_end(BB)) 257118283125SJoseph Huber return SingleThreadedBBs.contains(BB); 257218283125SJoseph Huber 25736fc51c9fSJoseph Huber bool IsInitialThread = true; 257418283125SJoseph Huber for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB); 257518283125SJoseph Huber PredBB != PredEndBB; ++PredBB) { 25766fc51c9fSJoseph Huber if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()), 257718283125SJoseph Huber BB)) 25786fc51c9fSJoseph Huber IsInitialThread &= SingleThreadedBBs.contains(*PredBB); 257918283125SJoseph Huber } 258018283125SJoseph Huber 25816fc51c9fSJoseph Huber return IsInitialThread; 258218283125SJoseph Huber }; 258318283125SJoseph Huber 258418283125SJoseph Huber for (auto *BB : RPOT) { 258518283125SJoseph Huber if (!MergePredecessorStates(BB)) 258618283125SJoseph Huber SingleThreadedBBs.erase(BB); 258718283125SJoseph Huber } 258818283125SJoseph Huber 258918283125SJoseph Huber return (NumSingleThreadedBBs == SingleThreadedBBs.size()) 259018283125SJoseph Huber ? ChangeStatus::UNCHANGED 259118283125SJoseph Huber : ChangeStatus::CHANGED; 259218283125SJoseph Huber } 259318283125SJoseph Huber 25946fc51c9fSJoseph Huber /// Try to replace memory allocation calls called by a single thread with a 25956fc51c9fSJoseph Huber /// static buffer of shared memory. 25966fc51c9fSJoseph Huber struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> { 25976fc51c9fSJoseph Huber using Base = StateWrapper<BooleanState, AbstractAttribute>; 25986fc51c9fSJoseph Huber AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {} 25996fc51c9fSJoseph Huber 26006fc51c9fSJoseph Huber /// Create an abstract attribute view for the position \p IRP. 
static AAHeapToShared &createForPosition(const IRPosition &IRP,
                                           Attributor &A);

  /// Returns true if HeapToShared conversion is assumed to be possible.
  /// For AAHeapToSharedFunction this means \p CB is a __kmpc_alloc_shared
  /// call that is still a replacement candidate and the state is valid.
  virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;

  /// Returns true if HeapToShared conversion is assumed and the CB is a
  /// callsite to a free operation to be removed, i.e. the unique
  /// __kmpc_free_shared user of a candidate allocation.
  virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;

  /// See AbstractAttribute::getName().
  const std::string getName() const override { return "AAHeapToShared"; }

  /// See AbstractAttribute::getIdAddr().
  const char *getIdAddr() const override { return &ID; }

  /// This function should return true if the type of the \p AA is
  /// AAHeapToShared.
26196fc51c9fSJoseph Huber static bool classof(const AbstractAttribute *AA) { 26206fc51c9fSJoseph Huber return (AA->getIdAddr() == &ID); 26216fc51c9fSJoseph Huber } 26226fc51c9fSJoseph Huber 26236fc51c9fSJoseph Huber /// Unique ID (due to the unique address) 26246fc51c9fSJoseph Huber static const char ID; 26256fc51c9fSJoseph Huber }; 26266fc51c9fSJoseph Huber 26276fc51c9fSJoseph Huber struct AAHeapToSharedFunction : public AAHeapToShared { 26286fc51c9fSJoseph Huber AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A) 26296fc51c9fSJoseph Huber : AAHeapToShared(IRP, A) {} 26306fc51c9fSJoseph Huber 26316fc51c9fSJoseph Huber const std::string getAsStr() const override { 26326fc51c9fSJoseph Huber return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) + 26336fc51c9fSJoseph Huber " malloc calls eligible."; 26346fc51c9fSJoseph Huber } 26356fc51c9fSJoseph Huber 26366fc51c9fSJoseph Huber /// See AbstractAttribute::trackStatistics(). 26376fc51c9fSJoseph Huber void trackStatistics() const override {} 26386fc51c9fSJoseph Huber 2639f8c40ed8SGiorgis Georgakoudis /// This functions finds free calls that will be removed by the 2640f8c40ed8SGiorgis Georgakoudis /// HeapToShared transformation. 2641f8c40ed8SGiorgis Georgakoudis void findPotentialRemovedFreeCalls(Attributor &A) { 2642f8c40ed8SGiorgis Georgakoudis auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2643f8c40ed8SGiorgis Georgakoudis auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared]; 2644f8c40ed8SGiorgis Georgakoudis 2645f8c40ed8SGiorgis Georgakoudis PotentialRemovedFreeCalls.clear(); 2646f8c40ed8SGiorgis Georgakoudis // Update free call users of found malloc calls. 
2647f8c40ed8SGiorgis Georgakoudis for (CallBase *CB : MallocCalls) { 2648f8c40ed8SGiorgis Georgakoudis SmallVector<CallBase *, 4> FreeCalls; 2649f8c40ed8SGiorgis Georgakoudis for (auto *U : CB->users()) { 2650f8c40ed8SGiorgis Georgakoudis CallBase *C = dyn_cast<CallBase>(U); 2651f8c40ed8SGiorgis Georgakoudis if (C && C->getCalledFunction() == FreeRFI.Declaration) 2652f8c40ed8SGiorgis Georgakoudis FreeCalls.push_back(C); 2653f8c40ed8SGiorgis Georgakoudis } 2654f8c40ed8SGiorgis Georgakoudis 2655f8c40ed8SGiorgis Georgakoudis if (FreeCalls.size() != 1) 2656f8c40ed8SGiorgis Georgakoudis continue; 2657f8c40ed8SGiorgis Georgakoudis 2658f8c40ed8SGiorgis Georgakoudis PotentialRemovedFreeCalls.insert(FreeCalls.front()); 2659f8c40ed8SGiorgis Georgakoudis } 2660f8c40ed8SGiorgis Georgakoudis } 2661f8c40ed8SGiorgis Georgakoudis 26626fc51c9fSJoseph Huber void initialize(Attributor &A) override { 26636fc51c9fSJoseph Huber auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 26646fc51c9fSJoseph Huber auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; 26656fc51c9fSJoseph Huber 26666fc51c9fSJoseph Huber for (User *U : RFI.Declaration->users()) 26676fc51c9fSJoseph Huber if (CallBase *CB = dyn_cast<CallBase>(U)) 26686fc51c9fSJoseph Huber MallocCalls.insert(CB); 2669f8c40ed8SGiorgis Georgakoudis 2670f8c40ed8SGiorgis Georgakoudis findPotentialRemovedFreeCalls(A); 2671f8c40ed8SGiorgis Georgakoudis } 2672f8c40ed8SGiorgis Georgakoudis 2673eaab880eSGiorgis Georgakoudis bool isAssumedHeapToShared(CallBase &CB) const override { 2674f8c40ed8SGiorgis Georgakoudis return isValidState() && MallocCalls.count(&CB); 2675f8c40ed8SGiorgis Georgakoudis } 2676f8c40ed8SGiorgis Georgakoudis 2677eaab880eSGiorgis Georgakoudis bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override { 2678f8c40ed8SGiorgis Georgakoudis return isValidState() && PotentialRemovedFreeCalls.count(&CB); 26796fc51c9fSJoseph Huber } 26806fc51c9fSJoseph Huber 26816fc51c9fSJoseph Huber 
ChangeStatus manifest(Attributor &A) override { 26826fc51c9fSJoseph Huber if (MallocCalls.empty()) 26836fc51c9fSJoseph Huber return ChangeStatus::UNCHANGED; 26846fc51c9fSJoseph Huber 26856fc51c9fSJoseph Huber auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 26866fc51c9fSJoseph Huber auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared]; 26876fc51c9fSJoseph Huber 26886fc51c9fSJoseph Huber Function *F = getAnchorScope(); 26896fc51c9fSJoseph Huber auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this, 26906fc51c9fSJoseph Huber DepClassTy::OPTIONAL); 26916fc51c9fSJoseph Huber 26926fc51c9fSJoseph Huber ChangeStatus Changed = ChangeStatus::UNCHANGED; 26936fc51c9fSJoseph Huber for (CallBase *CB : MallocCalls) { 26946fc51c9fSJoseph Huber // Skip replacing this if HeapToStack has already claimed it. 2695c1c1fe93SJohannes Doerfert if (HS && HS->isAssumedHeapToStack(*CB)) 26966fc51c9fSJoseph Huber continue; 26976fc51c9fSJoseph Huber 26986fc51c9fSJoseph Huber // Find the unique free call to remove it. 
26996fc51c9fSJoseph Huber SmallVector<CallBase *, 4> FreeCalls; 27006fc51c9fSJoseph Huber for (auto *U : CB->users()) { 27016fc51c9fSJoseph Huber CallBase *C = dyn_cast<CallBase>(U); 27026fc51c9fSJoseph Huber if (C && C->getCalledFunction() == FreeCall.Declaration) 27036fc51c9fSJoseph Huber FreeCalls.push_back(C); 27046fc51c9fSJoseph Huber } 27056fc51c9fSJoseph Huber if (FreeCalls.size() != 1) 27066fc51c9fSJoseph Huber continue; 27076fc51c9fSJoseph Huber 27086fc51c9fSJoseph Huber ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0)); 27096fc51c9fSJoseph Huber 27106fc51c9fSJoseph Huber LLVM_DEBUG(dbgs() << TAG << "Replace globalization call in " 27116fc51c9fSJoseph Huber << CB->getCaller()->getName() << " with " 27126fc51c9fSJoseph Huber << AllocSize->getZExtValue() 27136fc51c9fSJoseph Huber << " bytes of shared memory\n"); 27146fc51c9fSJoseph Huber 27156fc51c9fSJoseph Huber // Create a new shared memory buffer of the same size as the allocation 27166fc51c9fSJoseph Huber // and replace all the uses of the original allocation with it. 
27176fc51c9fSJoseph Huber Module *M = CB->getModule(); 27186fc51c9fSJoseph Huber Type *Int8Ty = Type::getInt8Ty(M->getContext()); 27196fc51c9fSJoseph Huber Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue()); 27206fc51c9fSJoseph Huber auto *SharedMem = new GlobalVariable( 27216fc51c9fSJoseph Huber *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage, 27226fc51c9fSJoseph Huber UndefValue::get(Int8ArrTy), CB->getName(), nullptr, 27236fc51c9fSJoseph Huber GlobalValue::NotThreadLocal, 27246fc51c9fSJoseph Huber static_cast<unsigned>(AddressSpace::Shared)); 27256fc51c9fSJoseph Huber auto *NewBuffer = 27266fc51c9fSJoseph Huber ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo()); 27276fc51c9fSJoseph Huber 272830e36c9bSJoseph Huber auto Remark = [&](OptimizationRemark OR) { 272930e36c9bSJoseph Huber return OR << "Replaced globalized variable with " 273030e36c9bSJoseph Huber << ore::NV("SharedMemory", AllocSize->getZExtValue()) 273130e36c9bSJoseph Huber << ((AllocSize->getZExtValue() != 1) ? 
" bytes " : " byte ") 2732eef6601bSJoseph Huber << "of shared memory."; 273330e36c9bSJoseph Huber }; 27342c31d5ebSJoseph Huber A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark); 273530e36c9bSJoseph Huber 27366fc51c9fSJoseph Huber SharedMem->setAlignment(MaybeAlign(32)); 27376fc51c9fSJoseph Huber 27386fc51c9fSJoseph Huber A.changeValueAfterManifest(*CB, *NewBuffer); 27396fc51c9fSJoseph Huber A.deleteAfterManifest(*CB); 27406fc51c9fSJoseph Huber A.deleteAfterManifest(*FreeCalls.front()); 27416fc51c9fSJoseph Huber 27426fc51c9fSJoseph Huber NumBytesMovedToSharedMemory += AllocSize->getZExtValue(); 27436fc51c9fSJoseph Huber Changed = ChangeStatus::CHANGED; 27446fc51c9fSJoseph Huber } 27456fc51c9fSJoseph Huber 27466fc51c9fSJoseph Huber return Changed; 27476fc51c9fSJoseph Huber } 27486fc51c9fSJoseph Huber 27496fc51c9fSJoseph Huber ChangeStatus updateImpl(Attributor &A) override { 27506fc51c9fSJoseph Huber auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 27516fc51c9fSJoseph Huber auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; 27526fc51c9fSJoseph Huber Function *F = getAnchorScope(); 27536fc51c9fSJoseph Huber 27546fc51c9fSJoseph Huber auto NumMallocCalls = MallocCalls.size(); 27556fc51c9fSJoseph Huber 27566fc51c9fSJoseph Huber // Only consider malloc calls executed by a single thread with a constant. 
27576fc51c9fSJoseph Huber for (User *U : RFI.Declaration->users()) { 27586fc51c9fSJoseph Huber const auto &ED = A.getAAFor<AAExecutionDomain>( 27596fc51c9fSJoseph Huber *this, IRPosition::function(*F), DepClassTy::REQUIRED); 27606fc51c9fSJoseph Huber if (CallBase *CB = dyn_cast<CallBase>(U)) 27616fc51c9fSJoseph Huber if (!dyn_cast<ConstantInt>(CB->getArgOperand(0)) || 27626fc51c9fSJoseph Huber !ED.isExecutedByInitialThreadOnly(*CB)) 27636fc51c9fSJoseph Huber MallocCalls.erase(CB); 27646fc51c9fSJoseph Huber } 27656fc51c9fSJoseph Huber 2766f8c40ed8SGiorgis Georgakoudis findPotentialRemovedFreeCalls(A); 2767f8c40ed8SGiorgis Georgakoudis 27686fc51c9fSJoseph Huber if (NumMallocCalls != MallocCalls.size()) 27696fc51c9fSJoseph Huber return ChangeStatus::CHANGED; 27706fc51c9fSJoseph Huber 27716fc51c9fSJoseph Huber return ChangeStatus::UNCHANGED; 27726fc51c9fSJoseph Huber } 27736fc51c9fSJoseph Huber 27746fc51c9fSJoseph Huber /// Collection of all malloc calls in a function. 27756fc51c9fSJoseph Huber SmallPtrSet<CallBase *, 4> MallocCalls; 2776f8c40ed8SGiorgis Georgakoudis /// Collection of potentially removed free calls in a function. 2777f8c40ed8SGiorgis Georgakoudis SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls; 27786fc51c9fSJoseph Huber }; 27796fc51c9fSJoseph Huber 2780d9659bf6SJohannes Doerfert struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> { 2781d9659bf6SJohannes Doerfert using Base = StateWrapper<KernelInfoState, AbstractAttribute>; 2782d9659bf6SJohannes Doerfert AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {} 2783d9659bf6SJohannes Doerfert 2784d9659bf6SJohannes Doerfert /// Statistics are tracked as part of manifest for now. 
2785d9659bf6SJohannes Doerfert void trackStatistics() const override {} 2786d9659bf6SJohannes Doerfert 2787d9659bf6SJohannes Doerfert /// See AbstractAttribute::getAsStr() 2788d9659bf6SJohannes Doerfert const std::string getAsStr() const override { 2789d9659bf6SJohannes Doerfert if (!isValidState()) 2790d9659bf6SJohannes Doerfert return "<invalid>"; 2791514c033dSJohannes Doerfert return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD" 2792514c033dSJohannes Doerfert : "generic") + 2793514c033dSJohannes Doerfert std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]" 2794514c033dSJohannes Doerfert : "") + 2795d9659bf6SJohannes Doerfert std::string(" #PRs: ") + 2796d9659bf6SJohannes Doerfert std::to_string(ReachedKnownParallelRegions.size()) + 2797d9659bf6SJohannes Doerfert ", #Unknown PRs: " + 2798d9659bf6SJohannes Doerfert std::to_string(ReachedUnknownParallelRegions.size()); 2799d9659bf6SJohannes Doerfert } 2800d9659bf6SJohannes Doerfert 2801d9659bf6SJohannes Doerfert /// Create an abstract attribute biew for the position \p IRP. 
2802d9659bf6SJohannes Doerfert static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A); 2803d9659bf6SJohannes Doerfert 2804d9659bf6SJohannes Doerfert /// See AbstractAttribute::getName() 2805d9659bf6SJohannes Doerfert const std::string getName() const override { return "AAKernelInfo"; } 2806d9659bf6SJohannes Doerfert 2807d9659bf6SJohannes Doerfert /// See AbstractAttribute::getIdAddr() 2808d9659bf6SJohannes Doerfert const char *getIdAddr() const override { return &ID; } 2809d9659bf6SJohannes Doerfert 2810d9659bf6SJohannes Doerfert /// This function should return true if the type of the \p AA is AAKernelInfo 2811d9659bf6SJohannes Doerfert static bool classof(const AbstractAttribute *AA) { 2812d9659bf6SJohannes Doerfert return (AA->getIdAddr() == &ID); 2813d9659bf6SJohannes Doerfert } 2814d9659bf6SJohannes Doerfert 2815d9659bf6SJohannes Doerfert static const char ID; 2816d9659bf6SJohannes Doerfert }; 2817d9659bf6SJohannes Doerfert 2818d9659bf6SJohannes Doerfert /// The function kernel info abstract attribute, basically, what can we say 2819d9659bf6SJohannes Doerfert /// about a function with regards to the KernelInfoState. 2820d9659bf6SJohannes Doerfert struct AAKernelInfoFunction : AAKernelInfo { 2821d9659bf6SJohannes Doerfert AAKernelInfoFunction(const IRPosition &IRP, Attributor &A) 2822d9659bf6SJohannes Doerfert : AAKernelInfo(IRP, A) {} 2823d9659bf6SJohannes Doerfert 282429a3e3ddSGiorgis Georgakoudis SmallPtrSet<Instruction *, 4> GuardedInstructions; 282529a3e3ddSGiorgis Georgakoudis 282629a3e3ddSGiorgis Georgakoudis SmallPtrSetImpl<Instruction *> &getGuardedInstructions() { 282729a3e3ddSGiorgis Georgakoudis return GuardedInstructions; 282829a3e3ddSGiorgis Georgakoudis } 282929a3e3ddSGiorgis Georgakoudis 2830d9659bf6SJohannes Doerfert /// See AbstractAttribute::initialize(...). 
2831d9659bf6SJohannes Doerfert void initialize(Attributor &A) override { 2832d9659bf6SJohannes Doerfert // This is a high-level transform that might change the constant arguments 2833d9659bf6SJohannes Doerfert // of the init and dinit calls. We need to tell the Attributor about this 2834d9659bf6SJohannes Doerfert // to avoid other parts using the current constant value for simpliication. 2835d9659bf6SJohannes Doerfert auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2836d9659bf6SJohannes Doerfert 2837d9659bf6SJohannes Doerfert Function *Fn = getAnchorScope(); 2838d9659bf6SJohannes Doerfert if (!OMPInfoCache.Kernels.count(Fn)) 2839d9659bf6SJohannes Doerfert return; 2840d9659bf6SJohannes Doerfert 2841ca662297SShilei Tian // Add itself to the reaching kernel and set IsKernelEntry. 2842ca662297SShilei Tian ReachingKernelEntries.insert(Fn); 2843ca662297SShilei Tian IsKernelEntry = true; 2844ca662297SShilei Tian 2845d9659bf6SJohannes Doerfert OMPInformationCache::RuntimeFunctionInfo &InitRFI = 2846d9659bf6SJohannes Doerfert OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; 2847d9659bf6SJohannes Doerfert OMPInformationCache::RuntimeFunctionInfo &DeinitRFI = 2848d9659bf6SJohannes Doerfert OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit]; 2849d9659bf6SJohannes Doerfert 2850d9659bf6SJohannes Doerfert // For kernels we perform more initialization work, first we find the init 2851d9659bf6SJohannes Doerfert // and deinit calls. 
2852d9659bf6SJohannes Doerfert auto StoreCallBase = [](Use &U, 2853d9659bf6SJohannes Doerfert OMPInformationCache::RuntimeFunctionInfo &RFI, 2854d9659bf6SJohannes Doerfert CallBase *&Storage) { 2855d9659bf6SJohannes Doerfert CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI); 2856d9659bf6SJohannes Doerfert assert(CB && 2857d9659bf6SJohannes Doerfert "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!"); 2858d9659bf6SJohannes Doerfert assert(!Storage && 2859d9659bf6SJohannes Doerfert "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!"); 2860d9659bf6SJohannes Doerfert Storage = CB; 2861d9659bf6SJohannes Doerfert return false; 2862d9659bf6SJohannes Doerfert }; 2863d9659bf6SJohannes Doerfert InitRFI.foreachUse( 2864d9659bf6SJohannes Doerfert [&](Use &U, Function &) { 2865d9659bf6SJohannes Doerfert StoreCallBase(U, InitRFI, KernelInitCB); 2866d9659bf6SJohannes Doerfert return false; 2867d9659bf6SJohannes Doerfert }, 2868d9659bf6SJohannes Doerfert Fn); 2869d9659bf6SJohannes Doerfert DeinitRFI.foreachUse( 2870d9659bf6SJohannes Doerfert [&](Use &U, Function &) { 2871d9659bf6SJohannes Doerfert StoreCallBase(U, DeinitRFI, KernelDeinitCB); 2872d9659bf6SJohannes Doerfert return false; 2873d9659bf6SJohannes Doerfert }, 2874d9659bf6SJohannes Doerfert Fn); 2875d9659bf6SJohannes Doerfert 2876d9659bf6SJohannes Doerfert assert((KernelInitCB && KernelDeinitCB) && 2877d9659bf6SJohannes Doerfert "Kernel without __kmpc_target_init or __kmpc_target_deinit!"); 2878d9659bf6SJohannes Doerfert 2879514c033dSJohannes Doerfert // For kernels we might need to initialize/finalize the IsSPMD state and 2880514c033dSJohannes Doerfert // we need to register a simplification callback so that the Attributor 2881514c033dSJohannes Doerfert // knows the constant arguments to __kmpc_target_init and 2882d9659bf6SJohannes Doerfert // __kmpc_target_deinit might actually change. 
2883d9659bf6SJohannes Doerfert 2884d9659bf6SJohannes Doerfert Attributor::SimplifictionCallbackTy StateMachineSimplifyCB = 2885d9659bf6SJohannes Doerfert [&](const IRPosition &IRP, const AbstractAttribute *AA, 2886d9659bf6SJohannes Doerfert bool &UsedAssumedInformation) -> Optional<Value *> { 2887d9659bf6SJohannes Doerfert // IRP represents the "use generic state machine" argument of an 2888d9659bf6SJohannes Doerfert // __kmpc_target_init call. We will answer this one with the internal 2889d9659bf6SJohannes Doerfert // state. As long as we are not in an invalid state, we will create a 2890d9659bf6SJohannes Doerfert // custom state machine so the value should be a `i1 false`. If we are 2891d9659bf6SJohannes Doerfert // in an invalid state, we won't change the value that is in the IR. 2892d9659bf6SJohannes Doerfert if (!isValidState()) 2893d9659bf6SJohannes Doerfert return nullptr; 2894e0c5d83aSJohannes Doerfert // If we have disabled state machine rewrites, don't make a custom one. 2895e0c5d83aSJohannes Doerfert if (DisableOpenMPOptStateMachineRewrite) 2896e0c5d83aSJohannes Doerfert return nullptr; 2897d9659bf6SJohannes Doerfert if (AA) 2898d9659bf6SJohannes Doerfert A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); 2899d9659bf6SJohannes Doerfert UsedAssumedInformation = !isAtFixpoint(); 2900d9659bf6SJohannes Doerfert auto *FalseVal = 2901d9659bf6SJohannes Doerfert ConstantInt::getBool(IRP.getAnchorValue().getContext(), 0); 2902d9659bf6SJohannes Doerfert return FalseVal; 2903d9659bf6SJohannes Doerfert }; 2904d9659bf6SJohannes Doerfert 2905514c033dSJohannes Doerfert Attributor::SimplifictionCallbackTy IsSPMDModeSimplifyCB = 2906514c033dSJohannes Doerfert [&](const IRPosition &IRP, const AbstractAttribute *AA, 2907514c033dSJohannes Doerfert bool &UsedAssumedInformation) -> Optional<Value *> { 2908514c033dSJohannes Doerfert // IRP represents the "SPMDCompatibilityTracker" argument of an 2909514c033dSJohannes Doerfert // __kmpc_target_init or 2910514c033dSJohannes 
Doerfert // __kmpc_target_deinit call. We will answer this one with the internal 2911514c033dSJohannes Doerfert // state. 291297387fdfSJohannes Doerfert if (!SPMDCompatibilityTracker.isValidState()) 2913514c033dSJohannes Doerfert return nullptr; 2914514c033dSJohannes Doerfert if (!SPMDCompatibilityTracker.isAtFixpoint()) { 2915514c033dSJohannes Doerfert if (AA) 2916514c033dSJohannes Doerfert A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); 2917514c033dSJohannes Doerfert UsedAssumedInformation = true; 2918514c033dSJohannes Doerfert } else { 2919514c033dSJohannes Doerfert UsedAssumedInformation = false; 2920514c033dSJohannes Doerfert } 2921514c033dSJohannes Doerfert auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(), 2922514c033dSJohannes Doerfert SPMDCompatibilityTracker.isAssumed()); 2923514c033dSJohannes Doerfert return Val; 2924514c033dSJohannes Doerfert }; 2925514c033dSJohannes Doerfert 2926e8439ec8SGiorgis Georgakoudis Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB = 2927e8439ec8SGiorgis Georgakoudis [&](const IRPosition &IRP, const AbstractAttribute *AA, 2928e8439ec8SGiorgis Georgakoudis bool &UsedAssumedInformation) -> Optional<Value *> { 2929e8439ec8SGiorgis Georgakoudis // IRP represents the "RequiresFullRuntime" argument of an 2930e8439ec8SGiorgis Georgakoudis // __kmpc_target_init or __kmpc_target_deinit call. We will answer this 2931e8439ec8SGiorgis Georgakoudis // one with the internal state of the SPMDCompatibilityTracker, so if 2932e8439ec8SGiorgis Georgakoudis // generic then true, if SPMD then false. 
2933e8439ec8SGiorgis Georgakoudis if (!SPMDCompatibilityTracker.isValidState()) 2934e8439ec8SGiorgis Georgakoudis return nullptr; 2935e8439ec8SGiorgis Georgakoudis if (!SPMDCompatibilityTracker.isAtFixpoint()) { 2936e8439ec8SGiorgis Georgakoudis if (AA) 2937e8439ec8SGiorgis Georgakoudis A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); 2938e8439ec8SGiorgis Georgakoudis UsedAssumedInformation = true; 2939e8439ec8SGiorgis Georgakoudis } else { 2940e8439ec8SGiorgis Georgakoudis UsedAssumedInformation = false; 2941e8439ec8SGiorgis Georgakoudis } 2942e8439ec8SGiorgis Georgakoudis auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(), 2943e8439ec8SGiorgis Georgakoudis !SPMDCompatibilityTracker.isAssumed()); 2944e8439ec8SGiorgis Georgakoudis return Val; 2945e8439ec8SGiorgis Georgakoudis }; 2946e8439ec8SGiorgis Georgakoudis 2947514c033dSJohannes Doerfert constexpr const int InitIsSPMDArgNo = 1; 2948514c033dSJohannes Doerfert constexpr const int DeinitIsSPMDArgNo = 1; 2949d9659bf6SJohannes Doerfert constexpr const int InitUseStateMachineArgNo = 2; 2950e8439ec8SGiorgis Georgakoudis constexpr const int InitRequiresFullRuntimeArgNo = 3; 2951e8439ec8SGiorgis Georgakoudis constexpr const int DeinitRequiresFullRuntimeArgNo = 2; 2952d9659bf6SJohannes Doerfert A.registerSimplificationCallback( 2953d9659bf6SJohannes Doerfert IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo), 2954d9659bf6SJohannes Doerfert StateMachineSimplifyCB); 2955514c033dSJohannes Doerfert A.registerSimplificationCallback( 2956514c033dSJohannes Doerfert IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo), 2957514c033dSJohannes Doerfert IsSPMDModeSimplifyCB); 2958514c033dSJohannes Doerfert A.registerSimplificationCallback( 2959514c033dSJohannes Doerfert IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo), 2960514c033dSJohannes Doerfert IsSPMDModeSimplifyCB); 2961e8439ec8SGiorgis Georgakoudis A.registerSimplificationCallback( 2962e8439ec8SGiorgis 
Georgakoudis IRPosition::callsite_argument(*KernelInitCB, 2963e8439ec8SGiorgis Georgakoudis InitRequiresFullRuntimeArgNo), 2964e8439ec8SGiorgis Georgakoudis IsGenericModeSimplifyCB); 2965e8439ec8SGiorgis Georgakoudis A.registerSimplificationCallback( 2966e8439ec8SGiorgis Georgakoudis IRPosition::callsite_argument(*KernelDeinitCB, 2967e8439ec8SGiorgis Georgakoudis DeinitRequiresFullRuntimeArgNo), 2968e8439ec8SGiorgis Georgakoudis IsGenericModeSimplifyCB); 2969514c033dSJohannes Doerfert 2970514c033dSJohannes Doerfert // Check if we know we are in SPMD-mode already. 2971514c033dSJohannes Doerfert ConstantInt *IsSPMDArg = 2972514c033dSJohannes Doerfert dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo)); 2973514c033dSJohannes Doerfert if (IsSPMDArg && !IsSPMDArg->isZero()) 2974514c033dSJohannes Doerfert SPMDCompatibilityTracker.indicateOptimisticFixpoint(); 297560e643feSGiorgis Georgakoudis // This is a generic region but SPMDization is disabled so stop tracking. 297660e643feSGiorgis Georgakoudis else if (DisableOpenMPOptSPMDization) 297760e643feSGiorgis Georgakoudis SPMDCompatibilityTracker.indicatePessimisticFixpoint(); 2978d9659bf6SJohannes Doerfert } 2979d9659bf6SJohannes Doerfert 2980d9659bf6SJohannes Doerfert /// Modify the IR based on the KernelInfoState as the fixpoint iteration is 2981d9659bf6SJohannes Doerfert /// finished now. 2982d9659bf6SJohannes Doerfert ChangeStatus manifest(Attributor &A) override { 2983d9659bf6SJohannes Doerfert // If we are not looking at a kernel with __kmpc_target_init and 2984d9659bf6SJohannes Doerfert // __kmpc_target_deinit call we cannot actually manifest the information. 2985d9659bf6SJohannes Doerfert if (!KernelInitCB || !KernelDeinitCB) 2986d9659bf6SJohannes Doerfert return ChangeStatus::UNCHANGED; 2987d9659bf6SJohannes Doerfert 2988514c033dSJohannes Doerfert // Known SPMD-mode kernels need no manifest changes. 
2989514c033dSJohannes Doerfert if (SPMDCompatibilityTracker.isKnown()) 2990514c033dSJohannes Doerfert return ChangeStatus::UNCHANGED; 2991514c033dSJohannes Doerfert 2992514c033dSJohannes Doerfert // If we can we change the execution mode to SPMD-mode otherwise we build a 2993514c033dSJohannes Doerfert // custom state machine. 2994514c033dSJohannes Doerfert if (!changeToSPMDMode(A)) 2995d9659bf6SJohannes Doerfert buildCustomStateMachine(A); 2996d9659bf6SJohannes Doerfert 2997d9659bf6SJohannes Doerfert return ChangeStatus::CHANGED; 2998d9659bf6SJohannes Doerfert } 2999d9659bf6SJohannes Doerfert 3000514c033dSJohannes Doerfert bool changeToSPMDMode(Attributor &A) { 3001eef6601bSJoseph Huber auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 3002eef6601bSJoseph Huber 3003514c033dSJohannes Doerfert if (!SPMDCompatibilityTracker.isAssumed()) { 3004514c033dSJohannes Doerfert for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) { 3005514c033dSJohannes Doerfert if (!NonCompatibleI) 3006514c033dSJohannes Doerfert continue; 3007eef6601bSJoseph Huber 3008eef6601bSJoseph Huber // Skip diagnostics on calls to known OpenMP runtime functions for now. 3009eef6601bSJoseph Huber if (auto *CB = dyn_cast<CallBase>(NonCompatibleI)) 3010eef6601bSJoseph Huber if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction())) 3011eef6601bSJoseph Huber continue; 3012eef6601bSJoseph Huber 3013514c033dSJohannes Doerfert auto Remark = [&](OptimizationRemarkAnalysis ORA) { 3014eef6601bSJoseph Huber ORA << "Value has potential side effects preventing SPMD-mode " 3015eef6601bSJoseph Huber "execution"; 3016eef6601bSJoseph Huber if (isa<CallBase>(NonCompatibleI)) { 3017eef6601bSJoseph Huber ORA << ". 
Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to " 3018eef6601bSJoseph Huber "the called function to override"; 3019514c033dSJohannes Doerfert } 3020514c033dSJohannes Doerfert return ORA << "."; 3021514c033dSJohannes Doerfert }; 30222c31d5ebSJoseph Huber A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121", 30232c31d5ebSJoseph Huber Remark); 3024514c033dSJohannes Doerfert 3025514c033dSJohannes Doerfert LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: " 3026514c033dSJohannes Doerfert << *NonCompatibleI << "\n"); 3027514c033dSJohannes Doerfert } 3028514c033dSJohannes Doerfert 3029514c033dSJohannes Doerfert return false; 3030514c033dSJohannes Doerfert } 3031514c033dSJohannes Doerfert 303229a3e3ddSGiorgis Georgakoudis auto CreateGuardedRegion = [&](Instruction *RegionStartI, 303329a3e3ddSGiorgis Georgakoudis Instruction *RegionEndI) { 303429a3e3ddSGiorgis Georgakoudis LoopInfo *LI = nullptr; 303529a3e3ddSGiorgis Georgakoudis DominatorTree *DT = nullptr; 303629a3e3ddSGiorgis Georgakoudis MemorySSAUpdater *MSU = nullptr; 303729a3e3ddSGiorgis Georgakoudis using InsertPointTy = OpenMPIRBuilder::InsertPointTy; 303829a3e3ddSGiorgis Georgakoudis 303929a3e3ddSGiorgis Georgakoudis BasicBlock *ParentBB = RegionStartI->getParent(); 304029a3e3ddSGiorgis Georgakoudis Function *Fn = ParentBB->getParent(); 304129a3e3ddSGiorgis Georgakoudis Module &M = *Fn->getParent(); 304229a3e3ddSGiorgis Georgakoudis 304329a3e3ddSGiorgis Georgakoudis // Create all the blocks and logic. 
304429a3e3ddSGiorgis Georgakoudis // ParentBB: 304529a3e3ddSGiorgis Georgakoudis // goto RegionCheckTidBB 304629a3e3ddSGiorgis Georgakoudis // RegionCheckTidBB: 304729a3e3ddSGiorgis Georgakoudis // Tid = __kmpc_hardware_thread_id() 304829a3e3ddSGiorgis Georgakoudis // if (Tid != 0) 304929a3e3ddSGiorgis Georgakoudis // goto RegionBarrierBB 305029a3e3ddSGiorgis Georgakoudis // RegionStartBB: 305129a3e3ddSGiorgis Georgakoudis // <execute instructions guarded> 305229a3e3ddSGiorgis Georgakoudis // goto RegionEndBB 305329a3e3ddSGiorgis Georgakoudis // RegionEndBB: 305429a3e3ddSGiorgis Georgakoudis // <store escaping values to shared mem> 305529a3e3ddSGiorgis Georgakoudis // goto RegionBarrierBB 305629a3e3ddSGiorgis Georgakoudis // RegionBarrierBB: 305729a3e3ddSGiorgis Georgakoudis // __kmpc_simple_barrier_spmd() 305829a3e3ddSGiorgis Georgakoudis // // second barrier is omitted if lacking escaping values. 305929a3e3ddSGiorgis Georgakoudis // <load escaping values from shared mem> 306029a3e3ddSGiorgis Georgakoudis // __kmpc_simple_barrier_spmd() 306129a3e3ddSGiorgis Georgakoudis // goto RegionExitBB 306229a3e3ddSGiorgis Georgakoudis // RegionExitBB: 306329a3e3ddSGiorgis Georgakoudis // <execute rest of instructions> 306429a3e3ddSGiorgis Georgakoudis 306529a3e3ddSGiorgis Georgakoudis BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(), 306629a3e3ddSGiorgis Georgakoudis DT, LI, MSU, "region.guarded.end"); 306729a3e3ddSGiorgis Georgakoudis BasicBlock *RegionBarrierBB = 306829a3e3ddSGiorgis Georgakoudis SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI, 306929a3e3ddSGiorgis Georgakoudis MSU, "region.barrier"); 307029a3e3ddSGiorgis Georgakoudis BasicBlock *RegionExitBB = 307129a3e3ddSGiorgis Georgakoudis SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(), 307229a3e3ddSGiorgis Georgakoudis DT, LI, MSU, "region.exit"); 307329a3e3ddSGiorgis Georgakoudis BasicBlock *RegionStartBB = 307429a3e3ddSGiorgis Georgakoudis 
SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded"); 307529a3e3ddSGiorgis Georgakoudis 307629a3e3ddSGiorgis Georgakoudis assert(ParentBB->getUniqueSuccessor() == RegionStartBB && 307729a3e3ddSGiorgis Georgakoudis "Expected a different CFG"); 307829a3e3ddSGiorgis Georgakoudis 307929a3e3ddSGiorgis Georgakoudis BasicBlock *RegionCheckTidBB = SplitBlock( 308029a3e3ddSGiorgis Georgakoudis ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid"); 308129a3e3ddSGiorgis Georgakoudis 308229a3e3ddSGiorgis Georgakoudis // Register basic blocks with the Attributor. 308329a3e3ddSGiorgis Georgakoudis A.registerManifestAddedBasicBlock(*RegionEndBB); 308429a3e3ddSGiorgis Georgakoudis A.registerManifestAddedBasicBlock(*RegionBarrierBB); 308529a3e3ddSGiorgis Georgakoudis A.registerManifestAddedBasicBlock(*RegionExitBB); 308629a3e3ddSGiorgis Georgakoudis A.registerManifestAddedBasicBlock(*RegionStartBB); 308729a3e3ddSGiorgis Georgakoudis A.registerManifestAddedBasicBlock(*RegionCheckTidBB); 308829a3e3ddSGiorgis Georgakoudis 308929a3e3ddSGiorgis Georgakoudis bool HasBroadcastValues = false; 309029a3e3ddSGiorgis Georgakoudis // Find escaping outputs from the guarded region to outside users and 309129a3e3ddSGiorgis Georgakoudis // broadcast their values to them. 
309229a3e3ddSGiorgis Georgakoudis for (Instruction &I : *RegionStartBB) { 309329a3e3ddSGiorgis Georgakoudis SmallPtrSet<Instruction *, 4> OutsideUsers; 309429a3e3ddSGiorgis Georgakoudis for (User *Usr : I.users()) { 309529a3e3ddSGiorgis Georgakoudis Instruction &UsrI = *cast<Instruction>(Usr); 309629a3e3ddSGiorgis Georgakoudis if (UsrI.getParent() != RegionStartBB) 309729a3e3ddSGiorgis Georgakoudis OutsideUsers.insert(&UsrI); 309829a3e3ddSGiorgis Georgakoudis } 309929a3e3ddSGiorgis Georgakoudis 310029a3e3ddSGiorgis Georgakoudis if (OutsideUsers.empty()) 310129a3e3ddSGiorgis Georgakoudis continue; 310229a3e3ddSGiorgis Georgakoudis 310329a3e3ddSGiorgis Georgakoudis HasBroadcastValues = true; 310429a3e3ddSGiorgis Georgakoudis 310529a3e3ddSGiorgis Georgakoudis // Emit a global variable in shared memory to store the broadcasted 310629a3e3ddSGiorgis Georgakoudis // value. 310729a3e3ddSGiorgis Georgakoudis auto *SharedMem = new GlobalVariable( 310829a3e3ddSGiorgis Georgakoudis M, I.getType(), /* IsConstant */ false, 310929a3e3ddSGiorgis Georgakoudis GlobalValue::InternalLinkage, UndefValue::get(I.getType()), 311029a3e3ddSGiorgis Georgakoudis I.getName() + ".guarded.output.alloc", nullptr, 311129a3e3ddSGiorgis Georgakoudis GlobalValue::NotThreadLocal, 311229a3e3ddSGiorgis Georgakoudis static_cast<unsigned>(AddressSpace::Shared)); 311329a3e3ddSGiorgis Georgakoudis 311429a3e3ddSGiorgis Georgakoudis // Emit a store instruction to update the value. 311529a3e3ddSGiorgis Georgakoudis new StoreInst(&I, SharedMem, RegionEndBB->getTerminator()); 311629a3e3ddSGiorgis Georgakoudis 311729a3e3ddSGiorgis Georgakoudis LoadInst *LoadI = new LoadInst(I.getType(), SharedMem, 311829a3e3ddSGiorgis Georgakoudis I.getName() + ".guarded.output.load", 311929a3e3ddSGiorgis Georgakoudis RegionBarrierBB->getTerminator()); 312029a3e3ddSGiorgis Georgakoudis 312129a3e3ddSGiorgis Georgakoudis // Emit a load instruction and replace uses of the output value. 
312229a3e3ddSGiorgis Georgakoudis for (Instruction *UsrI : OutsideUsers) { 312329a3e3ddSGiorgis Georgakoudis assert(UsrI->getParent() == RegionExitBB && 312429a3e3ddSGiorgis Georgakoudis "Expected escaping users in exit region"); 312529a3e3ddSGiorgis Georgakoudis UsrI->replaceUsesOfWith(&I, LoadI); 312629a3e3ddSGiorgis Georgakoudis } 312729a3e3ddSGiorgis Georgakoudis } 312829a3e3ddSGiorgis Georgakoudis 312929a3e3ddSGiorgis Georgakoudis auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 313029a3e3ddSGiorgis Georgakoudis 313129a3e3ddSGiorgis Georgakoudis // Go to tid check BB in ParentBB. 313229a3e3ddSGiorgis Georgakoudis const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc(); 313329a3e3ddSGiorgis Georgakoudis ParentBB->getTerminator()->eraseFromParent(); 313429a3e3ddSGiorgis Georgakoudis OpenMPIRBuilder::LocationDescription Loc( 313529a3e3ddSGiorgis Georgakoudis InsertPointTy(ParentBB, ParentBB->end()), DL); 313629a3e3ddSGiorgis Georgakoudis OMPInfoCache.OMPBuilder.updateToLocation(Loc); 313729a3e3ddSGiorgis Georgakoudis auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc); 313829a3e3ddSGiorgis Georgakoudis Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr); 313929a3e3ddSGiorgis Georgakoudis BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL); 314029a3e3ddSGiorgis Georgakoudis 314129a3e3ddSGiorgis Georgakoudis // Add check for Tid in RegionCheckTidBB 314229a3e3ddSGiorgis Georgakoudis RegionCheckTidBB->getTerminator()->eraseFromParent(); 314329a3e3ddSGiorgis Georgakoudis OpenMPIRBuilder::LocationDescription LocRegionCheckTid( 314429a3e3ddSGiorgis Georgakoudis InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL); 314529a3e3ddSGiorgis Georgakoudis OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid); 314629a3e3ddSGiorgis Georgakoudis FunctionCallee HardwareTidFn = 314729a3e3ddSGiorgis Georgakoudis OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( 314829a3e3ddSGiorgis Georgakoudis 
M, OMPRTL___kmpc_get_hardware_thread_id_in_block); 314929a3e3ddSGiorgis Georgakoudis Value *Tid = 315029a3e3ddSGiorgis Georgakoudis OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {}); 315129a3e3ddSGiorgis Georgakoudis Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid); 315229a3e3ddSGiorgis Georgakoudis OMPInfoCache.OMPBuilder.Builder 315329a3e3ddSGiorgis Georgakoudis .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB) 315429a3e3ddSGiorgis Georgakoudis ->setDebugLoc(DL); 315529a3e3ddSGiorgis Georgakoudis 315629a3e3ddSGiorgis Georgakoudis // First barrier for synchronization, ensures main thread has updated 315729a3e3ddSGiorgis Georgakoudis // values. 315829a3e3ddSGiorgis Georgakoudis FunctionCallee BarrierFn = 315929a3e3ddSGiorgis Georgakoudis OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( 316029a3e3ddSGiorgis Georgakoudis M, OMPRTL___kmpc_barrier_simple_spmd); 316129a3e3ddSGiorgis Georgakoudis OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy( 316229a3e3ddSGiorgis Georgakoudis RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt())); 316329a3e3ddSGiorgis Georgakoudis OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid}) 316429a3e3ddSGiorgis Georgakoudis ->setDebugLoc(DL); 316529a3e3ddSGiorgis Georgakoudis 316629a3e3ddSGiorgis Georgakoudis // Second barrier ensures workers have read broadcast values. 
316729a3e3ddSGiorgis Georgakoudis if (HasBroadcastValues) 316829a3e3ddSGiorgis Georgakoudis CallInst::Create(BarrierFn, {Ident, Tid}, "", 316929a3e3ddSGiorgis Georgakoudis RegionBarrierBB->getTerminator()) 317029a3e3ddSGiorgis Georgakoudis ->setDebugLoc(DL); 317129a3e3ddSGiorgis Georgakoudis }; 317229a3e3ddSGiorgis Georgakoudis 317329a3e3ddSGiorgis Georgakoudis SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions; 317429a3e3ddSGiorgis Georgakoudis 317529a3e3ddSGiorgis Georgakoudis for (Instruction *GuardedI : SPMDCompatibilityTracker) { 317629a3e3ddSGiorgis Georgakoudis BasicBlock *BB = GuardedI->getParent(); 317729a3e3ddSGiorgis Georgakoudis auto *CalleeAA = A.lookupAAFor<AAKernelInfo>( 317829a3e3ddSGiorgis Georgakoudis IRPosition::function(*GuardedI->getFunction()), nullptr, 317929a3e3ddSGiorgis Georgakoudis DepClassTy::NONE); 318029a3e3ddSGiorgis Georgakoudis assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo"); 318129a3e3ddSGiorgis Georgakoudis auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA); 318229a3e3ddSGiorgis Georgakoudis // Continue if instruction is already guarded. 318329a3e3ddSGiorgis Georgakoudis if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI)) 318429a3e3ddSGiorgis Georgakoudis continue; 318529a3e3ddSGiorgis Georgakoudis 318629a3e3ddSGiorgis Georgakoudis Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr; 318729a3e3ddSGiorgis Georgakoudis for (Instruction &I : *BB) { 318829a3e3ddSGiorgis Georgakoudis // If instruction I needs to be guarded update the guarded region 318929a3e3ddSGiorgis Georgakoudis // bounds. 
319029a3e3ddSGiorgis Georgakoudis if (SPMDCompatibilityTracker.contains(&I)) { 319129a3e3ddSGiorgis Georgakoudis CalleeAAFunction.getGuardedInstructions().insert(&I); 319229a3e3ddSGiorgis Georgakoudis if (GuardedRegionStart) 319329a3e3ddSGiorgis Georgakoudis GuardedRegionEnd = &I; 319429a3e3ddSGiorgis Georgakoudis else 319529a3e3ddSGiorgis Georgakoudis GuardedRegionStart = GuardedRegionEnd = &I; 319629a3e3ddSGiorgis Georgakoudis 319729a3e3ddSGiorgis Georgakoudis continue; 319829a3e3ddSGiorgis Georgakoudis } 319929a3e3ddSGiorgis Georgakoudis 320029a3e3ddSGiorgis Georgakoudis // Instruction I does not need guarding, store 320129a3e3ddSGiorgis Georgakoudis // any region found and reset bounds. 320229a3e3ddSGiorgis Georgakoudis if (GuardedRegionStart) { 320329a3e3ddSGiorgis Georgakoudis GuardedRegions.push_back( 320429a3e3ddSGiorgis Georgakoudis std::make_pair(GuardedRegionStart, GuardedRegionEnd)); 320529a3e3ddSGiorgis Georgakoudis GuardedRegionStart = nullptr; 320629a3e3ddSGiorgis Georgakoudis GuardedRegionEnd = nullptr; 320729a3e3ddSGiorgis Georgakoudis } 320829a3e3ddSGiorgis Georgakoudis } 320929a3e3ddSGiorgis Georgakoudis } 321029a3e3ddSGiorgis Georgakoudis 321129a3e3ddSGiorgis Georgakoudis for (auto &GR : GuardedRegions) 321229a3e3ddSGiorgis Georgakoudis CreateGuardedRegion(GR.first, GR.second); 321329a3e3ddSGiorgis Georgakoudis 3214514c033dSJohannes Doerfert // Adjust the global exec mode flag that tells the runtime what mode this 3215514c033dSJohannes Doerfert // kernel is executed in. 
3216514c033dSJohannes Doerfert Function *Kernel = getAnchorScope(); 3217514c033dSJohannes Doerfert GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable( 3218514c033dSJohannes Doerfert (Kernel->getName() + "_exec_mode").str()); 3219514c033dSJohannes Doerfert assert(ExecMode && "Kernel without exec mode?"); 3220514c033dSJohannes Doerfert assert(ExecMode->getInitializer() && 3221514c033dSJohannes Doerfert ExecMode->getInitializer()->isOneValue() && 3222514c033dSJohannes Doerfert "Initially non-SPMD kernel has SPMD exec mode!"); 32237d576392SJoseph Huber 32247d576392SJoseph Huber // Set the global exec mode flag to indicate SPMD-Generic mode. 32257d576392SJoseph Huber constexpr int SPMDGeneric = 2; 32267d576392SJoseph Huber if (!ExecMode->getInitializer()->isZeroValue()) 3227514c033dSJohannes Doerfert ExecMode->setInitializer( 32287d576392SJoseph Huber ConstantInt::get(ExecMode->getInitializer()->getType(), SPMDGeneric)); 3229514c033dSJohannes Doerfert 3230514c033dSJohannes Doerfert // Next rewrite the init and deinit calls to indicate we use SPMD-mode now. 
3231514c033dSJohannes Doerfert const int InitIsSPMDArgNo = 1; 3232514c033dSJohannes Doerfert const int DeinitIsSPMDArgNo = 1; 3233514c033dSJohannes Doerfert const int InitUseStateMachineArgNo = 2; 3234e8439ec8SGiorgis Georgakoudis const int InitRequiresFullRuntimeArgNo = 3; 3235e8439ec8SGiorgis Georgakoudis const int DeinitRequiresFullRuntimeArgNo = 2; 3236514c033dSJohannes Doerfert 3237514c033dSJohannes Doerfert auto &Ctx = getAnchorValue().getContext(); 3238514c033dSJohannes Doerfert A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo), 3239514c033dSJohannes Doerfert *ConstantInt::getBool(Ctx, 1)); 3240514c033dSJohannes Doerfert A.changeUseAfterManifest( 3241514c033dSJohannes Doerfert KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), 3242514c033dSJohannes Doerfert *ConstantInt::getBool(Ctx, 0)); 3243514c033dSJohannes Doerfert A.changeUseAfterManifest( 3244514c033dSJohannes Doerfert KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo), 3245514c033dSJohannes Doerfert *ConstantInt::getBool(Ctx, 1)); 3246e8439ec8SGiorgis Georgakoudis A.changeUseAfterManifest( 3247e8439ec8SGiorgis Georgakoudis KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo), 3248e8439ec8SGiorgis Georgakoudis *ConstantInt::getBool(Ctx, 0)); 3249e8439ec8SGiorgis Georgakoudis A.changeUseAfterManifest( 3250e8439ec8SGiorgis Georgakoudis KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo), 3251e8439ec8SGiorgis Georgakoudis *ConstantInt::getBool(Ctx, 0)); 3252e8439ec8SGiorgis Georgakoudis 3253514c033dSJohannes Doerfert ++NumOpenMPTargetRegionKernelsSPMD; 3254514c033dSJohannes Doerfert 3255514c033dSJohannes Doerfert auto Remark = [&](OptimizationRemark OR) { 3256eef6601bSJoseph Huber return OR << "Transformed generic-mode kernel to SPMD-mode."; 3257514c033dSJohannes Doerfert }; 32582c31d5ebSJoseph Huber A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark); 3259514c033dSJohannes Doerfert return true; 3260514c033dSJohannes Doerfert }; 
3261514c033dSJohannes Doerfert 3262d9659bf6SJohannes Doerfert ChangeStatus buildCustomStateMachine(Attributor &A) { 3263cd0dd8ecSJoseph Huber // If we have disabled state machine rewrites, don't make a custom one 3264cd0dd8ecSJoseph Huber if (DisableOpenMPOptStateMachineRewrite) 3265cd0dd8ecSJoseph Huber return indicatePessimisticFixpoint(); 3266cd0dd8ecSJoseph Huber 3267d9659bf6SJohannes Doerfert assert(ReachedKnownParallelRegions.isValidState() && 3268d9659bf6SJohannes Doerfert "Custom state machine with invalid parallel region states?"); 3269d9659bf6SJohannes Doerfert 3270d9659bf6SJohannes Doerfert const int InitIsSPMDArgNo = 1; 3271d9659bf6SJohannes Doerfert const int InitUseStateMachineArgNo = 2; 3272d9659bf6SJohannes Doerfert 3273d9659bf6SJohannes Doerfert // Check if the current configuration is non-SPMD and generic state machine. 3274d9659bf6SJohannes Doerfert // If we already have SPMD mode or a custom state machine we do not need to 3275d9659bf6SJohannes Doerfert // go any further. If it is anything but a constant something is weird and 3276d9659bf6SJohannes Doerfert // we give up. 3277d9659bf6SJohannes Doerfert ConstantInt *UseStateMachine = dyn_cast<ConstantInt>( 3278d9659bf6SJohannes Doerfert KernelInitCB->getArgOperand(InitUseStateMachineArgNo)); 3279d9659bf6SJohannes Doerfert ConstantInt *IsSPMD = 3280d9659bf6SJohannes Doerfert dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo)); 3281d9659bf6SJohannes Doerfert 3282d9659bf6SJohannes Doerfert // If we are stuck with generic mode, try to create a custom device (=GPU) 3283d9659bf6SJohannes Doerfert // state machine which is specialized for the parallel regions that are 3284d9659bf6SJohannes Doerfert // reachable by the kernel. 
3285d9659bf6SJohannes Doerfert if (!UseStateMachine || UseStateMachine->isZero() || !IsSPMD || 3286d9659bf6SJohannes Doerfert !IsSPMD->isZero()) 3287d9659bf6SJohannes Doerfert return ChangeStatus::UNCHANGED; 3288d9659bf6SJohannes Doerfert 3289514c033dSJohannes Doerfert // If not SPMD mode, indicate we use a custom state machine now. 3290d9659bf6SJohannes Doerfert auto &Ctx = getAnchorValue().getContext(); 3291d9659bf6SJohannes Doerfert auto *FalseVal = ConstantInt::getBool(Ctx, 0); 3292d9659bf6SJohannes Doerfert A.changeUseAfterManifest( 3293d9659bf6SJohannes Doerfert KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal); 3294d9659bf6SJohannes Doerfert 3295d9659bf6SJohannes Doerfert // If we don't actually need a state machine we are done here. This can 3296d9659bf6SJohannes Doerfert // happen if there simply are no parallel regions. In the resulting kernel 3297d9659bf6SJohannes Doerfert // all worker threads will simply exit right away, leaving the main thread 3298d9659bf6SJohannes Doerfert // to do the work alone. 3299d9659bf6SJohannes Doerfert if (ReachedKnownParallelRegions.empty() && 3300d9659bf6SJohannes Doerfert ReachedUnknownParallelRegions.empty()) { 3301d9659bf6SJohannes Doerfert ++NumOpenMPTargetRegionKernelsWithoutStateMachine; 3302d9659bf6SJohannes Doerfert 3303d9659bf6SJohannes Doerfert auto Remark = [&](OptimizationRemark OR) { 3304eef6601bSJoseph Huber return OR << "Removing unused state machine from generic-mode kernel."; 3305d9659bf6SJohannes Doerfert }; 33062c31d5ebSJoseph Huber A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark); 3307d9659bf6SJohannes Doerfert 3308d9659bf6SJohannes Doerfert return ChangeStatus::CHANGED; 3309d9659bf6SJohannes Doerfert } 3310d9659bf6SJohannes Doerfert 3311d9659bf6SJohannes Doerfert // Keep track in the statistics of our new shiny custom state machine. 
3312d9659bf6SJohannes Doerfert if (ReachedUnknownParallelRegions.empty()) { 3313d9659bf6SJohannes Doerfert ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback; 3314d9659bf6SJohannes Doerfert 3315d9659bf6SJohannes Doerfert auto Remark = [&](OptimizationRemark OR) { 3316eef6601bSJoseph Huber return OR << "Rewriting generic-mode kernel with a customized state " 3317eef6601bSJoseph Huber "machine."; 3318d9659bf6SJohannes Doerfert }; 33192c31d5ebSJoseph Huber A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark); 3320d9659bf6SJohannes Doerfert } else { 3321d9659bf6SJohannes Doerfert ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback; 3322d9659bf6SJohannes Doerfert 3323eef6601bSJoseph Huber auto Remark = [&](OptimizationRemarkAnalysis OR) { 3324d9659bf6SJohannes Doerfert return OR << "Generic-mode kernel is executed with a customized state " 3325eef6601bSJoseph Huber "machine that requires a fallback."; 3326d9659bf6SJohannes Doerfert }; 33272c31d5ebSJoseph Huber A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark); 3328d9659bf6SJohannes Doerfert 3329d9659bf6SJohannes Doerfert // Tell the user why we ended up with a fallback. 3330d9659bf6SJohannes Doerfert for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) { 3331d9659bf6SJohannes Doerfert if (!UnknownParallelRegionCB) 3332d9659bf6SJohannes Doerfert continue; 3333d9659bf6SJohannes Doerfert auto Remark = [&](OptimizationRemarkAnalysis ORA) { 3334eef6601bSJoseph Huber return ORA << "Call may contain unknown parallel regions. 
Use " 3335eef6601bSJoseph Huber << "`__attribute__((assume(\"omp_no_parallelism\")))` to " 3336eef6601bSJoseph Huber "override."; 3337d9659bf6SJohannes Doerfert }; 33382c31d5ebSJoseph Huber A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB, 33392c31d5ebSJoseph Huber "OMP133", Remark); 3340d9659bf6SJohannes Doerfert } 3341d9659bf6SJohannes Doerfert } 3342d9659bf6SJohannes Doerfert 3343d9659bf6SJohannes Doerfert // Create all the blocks: 3344d9659bf6SJohannes Doerfert // 3345d9659bf6SJohannes Doerfert // InitCB = __kmpc_target_init(...) 3346d9659bf6SJohannes Doerfert // bool IsWorker = InitCB >= 0; 3347d9659bf6SJohannes Doerfert // if (IsWorker) { 3348d9659bf6SJohannes Doerfert // SMBeginBB: __kmpc_barrier_simple_spmd(...); 3349d9659bf6SJohannes Doerfert // void *WorkFn; 3350d9659bf6SJohannes Doerfert // bool Active = __kmpc_kernel_parallel(&WorkFn); 3351d9659bf6SJohannes Doerfert // if (!WorkFn) return; 3352d9659bf6SJohannes Doerfert // SMIsActiveCheckBB: if (Active) { 3353d9659bf6SJohannes Doerfert // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>) 3354d9659bf6SJohannes Doerfert // ParFn0(...); 3355d9659bf6SJohannes Doerfert // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>) 3356d9659bf6SJohannes Doerfert // ParFn1(...); 3357d9659bf6SJohannes Doerfert // ... 3358d9659bf6SJohannes Doerfert // SMIfCascadeCurrentBB: else 3359d9659bf6SJohannes Doerfert // ((WorkFnTy*)WorkFn)(...); 3360d9659bf6SJohannes Doerfert // SMEndParallelBB: __kmpc_kernel_end_parallel(...); 3361d9659bf6SJohannes Doerfert // } 3362d9659bf6SJohannes Doerfert // SMDoneBB: __kmpc_barrier_simple_spmd(...); 3363d9659bf6SJohannes Doerfert // goto SMBeginBB; 3364d9659bf6SJohannes Doerfert // } 3365d9659bf6SJohannes Doerfert // UserCodeEntryBB: // user code 3366d9659bf6SJohannes Doerfert // __kmpc_target_deinit(...) 
3367d9659bf6SJohannes Doerfert // 3368d9659bf6SJohannes Doerfert Function *Kernel = getAssociatedFunction(); 3369d9659bf6SJohannes Doerfert assert(Kernel && "Expected an associated function!"); 3370d9659bf6SJohannes Doerfert 3371d9659bf6SJohannes Doerfert BasicBlock *InitBB = KernelInitCB->getParent(); 3372d9659bf6SJohannes Doerfert BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock( 3373d9659bf6SJohannes Doerfert KernelInitCB->getNextNode(), "thread.user_code.check"); 3374d9659bf6SJohannes Doerfert BasicBlock *StateMachineBeginBB = BasicBlock::Create( 3375d9659bf6SJohannes Doerfert Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB); 3376d9659bf6SJohannes Doerfert BasicBlock *StateMachineFinishedBB = BasicBlock::Create( 3377d9659bf6SJohannes Doerfert Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB); 3378d9659bf6SJohannes Doerfert BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create( 3379d9659bf6SJohannes Doerfert Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB); 3380d9659bf6SJohannes Doerfert BasicBlock *StateMachineIfCascadeCurrentBB = 3381d9659bf6SJohannes Doerfert BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check", 3382d9659bf6SJohannes Doerfert Kernel, UserCodeEntryBB); 3383d9659bf6SJohannes Doerfert BasicBlock *StateMachineEndParallelBB = 3384d9659bf6SJohannes Doerfert BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end", 3385d9659bf6SJohannes Doerfert Kernel, UserCodeEntryBB); 3386d9659bf6SJohannes Doerfert BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create( 3387d9659bf6SJohannes Doerfert Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB); 33883f71b425SGiorgis Georgakoudis A.registerManifestAddedBasicBlock(*InitBB); 33893f71b425SGiorgis Georgakoudis A.registerManifestAddedBasicBlock(*UserCodeEntryBB); 33903f71b425SGiorgis Georgakoudis A.registerManifestAddedBasicBlock(*StateMachineBeginBB); 33913f71b425SGiorgis Georgakoudis 
A.registerManifestAddedBasicBlock(*StateMachineFinishedBB); 33923f71b425SGiorgis Georgakoudis A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB); 33933f71b425SGiorgis Georgakoudis A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB); 33943f71b425SGiorgis Georgakoudis A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB); 33953f71b425SGiorgis Georgakoudis A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB); 3396d9659bf6SJohannes Doerfert 3397d9659bf6SJohannes Doerfert const DebugLoc &DLoc = KernelInitCB->getDebugLoc(); 3398d9659bf6SJohannes Doerfert ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc); 3399d9659bf6SJohannes Doerfert 3400d9659bf6SJohannes Doerfert InitBB->getTerminator()->eraseFromParent(); 3401d9659bf6SJohannes Doerfert Instruction *IsWorker = 3402d9659bf6SJohannes Doerfert ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB, 3403d9659bf6SJohannes Doerfert ConstantInt::get(KernelInitCB->getType(), -1), 3404d9659bf6SJohannes Doerfert "thread.is_worker", InitBB); 3405d9659bf6SJohannes Doerfert IsWorker->setDebugLoc(DLoc); 3406d9659bf6SJohannes Doerfert BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB); 3407d9659bf6SJohannes Doerfert 3408d9659bf6SJohannes Doerfert // Create local storage for the work function pointer. 
3409d9659bf6SJohannes Doerfert Type *VoidPtrTy = Type::getInt8PtrTy(Ctx); 3410d9659bf6SJohannes Doerfert AllocaInst *WorkFnAI = new AllocaInst(VoidPtrTy, 0, "worker.work_fn.addr", 3411d9659bf6SJohannes Doerfert &Kernel->getEntryBlock().front()); 3412d9659bf6SJohannes Doerfert WorkFnAI->setDebugLoc(DLoc); 3413d9659bf6SJohannes Doerfert 3414d9659bf6SJohannes Doerfert auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 3415d9659bf6SJohannes Doerfert OMPInfoCache.OMPBuilder.updateToLocation( 3416d9659bf6SJohannes Doerfert OpenMPIRBuilder::LocationDescription( 3417d9659bf6SJohannes Doerfert IRBuilder<>::InsertPoint(StateMachineBeginBB, 3418d9659bf6SJohannes Doerfert StateMachineBeginBB->end()), 3419d9659bf6SJohannes Doerfert DLoc)); 3420d9659bf6SJohannes Doerfert 3421d9659bf6SJohannes Doerfert Value *Ident = KernelInitCB->getArgOperand(0); 3422d9659bf6SJohannes Doerfert Value *GTid = KernelInitCB; 3423d9659bf6SJohannes Doerfert 3424d9659bf6SJohannes Doerfert Module &M = *Kernel->getParent(); 3425d9659bf6SJohannes Doerfert FunctionCallee BarrierFn = 3426d9659bf6SJohannes Doerfert OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( 3427d9659bf6SJohannes Doerfert M, OMPRTL___kmpc_barrier_simple_spmd); 3428d9659bf6SJohannes Doerfert CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB) 3429d9659bf6SJohannes Doerfert ->setDebugLoc(DLoc); 3430d9659bf6SJohannes Doerfert 3431d9659bf6SJohannes Doerfert FunctionCallee KernelParallelFn = 3432d9659bf6SJohannes Doerfert OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( 3433d9659bf6SJohannes Doerfert M, OMPRTL___kmpc_kernel_parallel); 3434d9659bf6SJohannes Doerfert Instruction *IsActiveWorker = CallInst::Create( 3435d9659bf6SJohannes Doerfert KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB); 3436d9659bf6SJohannes Doerfert IsActiveWorker->setDebugLoc(DLoc); 3437d9659bf6SJohannes Doerfert Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn", 
3438d9659bf6SJohannes Doerfert StateMachineBeginBB); 3439d9659bf6SJohannes Doerfert WorkFn->setDebugLoc(DLoc); 3440d9659bf6SJohannes Doerfert 3441d9659bf6SJohannes Doerfert FunctionType *ParallelRegionFnTy = FunctionType::get( 3442d9659bf6SJohannes Doerfert Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)}, 3443d9659bf6SJohannes Doerfert false); 3444d9659bf6SJohannes Doerfert Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast( 3445d9659bf6SJohannes Doerfert WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast", 3446d9659bf6SJohannes Doerfert StateMachineBeginBB); 3447d9659bf6SJohannes Doerfert 3448d9659bf6SJohannes Doerfert Instruction *IsDone = 3449d9659bf6SJohannes Doerfert ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, 3450d9659bf6SJohannes Doerfert Constant::getNullValue(VoidPtrTy), "worker.is_done", 3451d9659bf6SJohannes Doerfert StateMachineBeginBB); 3452d9659bf6SJohannes Doerfert IsDone->setDebugLoc(DLoc); 3453d9659bf6SJohannes Doerfert BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB, 3454d9659bf6SJohannes Doerfert IsDone, StateMachineBeginBB) 3455d9659bf6SJohannes Doerfert ->setDebugLoc(DLoc); 3456d9659bf6SJohannes Doerfert 3457d9659bf6SJohannes Doerfert BranchInst::Create(StateMachineIfCascadeCurrentBB, 3458d9659bf6SJohannes Doerfert StateMachineDoneBarrierBB, IsActiveWorker, 3459d9659bf6SJohannes Doerfert StateMachineIsActiveCheckBB) 3460d9659bf6SJohannes Doerfert ->setDebugLoc(DLoc); 3461d9659bf6SJohannes Doerfert 3462d9659bf6SJohannes Doerfert Value *ZeroArg = 3463d9659bf6SJohannes Doerfert Constant::getNullValue(ParallelRegionFnTy->getParamType(0)); 3464d9659bf6SJohannes Doerfert 3465d9659bf6SJohannes Doerfert // Now that we have most of the CFG skeleton it is time for the if-cascade 3466d9659bf6SJohannes Doerfert // that checks the function pointer we got from the runtime against the 3467d9659bf6SJohannes Doerfert // parallel regions we expect, if there are 
any. 3468d9659bf6SJohannes Doerfert for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) { 3469d9659bf6SJohannes Doerfert auto *ParallelRegion = ReachedKnownParallelRegions[i]; 3470d9659bf6SJohannes Doerfert BasicBlock *PRExecuteBB = BasicBlock::Create( 3471d9659bf6SJohannes Doerfert Ctx, "worker_state_machine.parallel_region.execute", Kernel, 3472d9659bf6SJohannes Doerfert StateMachineEndParallelBB); 3473d9659bf6SJohannes Doerfert CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB) 3474d9659bf6SJohannes Doerfert ->setDebugLoc(DLoc); 3475d9659bf6SJohannes Doerfert BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB) 3476d9659bf6SJohannes Doerfert ->setDebugLoc(DLoc); 3477d9659bf6SJohannes Doerfert 3478d9659bf6SJohannes Doerfert BasicBlock *PRNextBB = 3479d9659bf6SJohannes Doerfert BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check", 3480d9659bf6SJohannes Doerfert Kernel, StateMachineEndParallelBB); 3481d9659bf6SJohannes Doerfert 3482d9659bf6SJohannes Doerfert // Check if we need to compare the pointer at all or if we can just 3483d9659bf6SJohannes Doerfert // call the parallel region function. 
3484d9659bf6SJohannes Doerfert Value *IsPR; 3485d9659bf6SJohannes Doerfert if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) { 3486d9659bf6SJohannes Doerfert Instruction *CmpI = ICmpInst::Create( 3487d9659bf6SJohannes Doerfert ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion, 3488d9659bf6SJohannes Doerfert "worker.check_parallel_region", StateMachineIfCascadeCurrentBB); 3489d9659bf6SJohannes Doerfert CmpI->setDebugLoc(DLoc); 3490d9659bf6SJohannes Doerfert IsPR = CmpI; 3491d9659bf6SJohannes Doerfert } else { 3492d9659bf6SJohannes Doerfert IsPR = ConstantInt::getTrue(Ctx); 3493d9659bf6SJohannes Doerfert } 3494d9659bf6SJohannes Doerfert 3495d9659bf6SJohannes Doerfert BranchInst::Create(PRExecuteBB, PRNextBB, IsPR, 3496d9659bf6SJohannes Doerfert StateMachineIfCascadeCurrentBB) 3497d9659bf6SJohannes Doerfert ->setDebugLoc(DLoc); 3498d9659bf6SJohannes Doerfert StateMachineIfCascadeCurrentBB = PRNextBB; 3499d9659bf6SJohannes Doerfert } 3500d9659bf6SJohannes Doerfert 3501d9659bf6SJohannes Doerfert // At the end of the if-cascade we place the indirect function pointer call 3502d9659bf6SJohannes Doerfert // in case we might need it, that is if there can be parallel regions we 3503d9659bf6SJohannes Doerfert // have not handled in the if-cascade above. 
3504d9659bf6SJohannes Doerfert if (!ReachedUnknownParallelRegions.empty()) { 3505d9659bf6SJohannes Doerfert StateMachineIfCascadeCurrentBB->setName( 3506d9659bf6SJohannes Doerfert "worker_state_machine.parallel_region.fallback.execute"); 3507d9659bf6SJohannes Doerfert CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "", 3508d9659bf6SJohannes Doerfert StateMachineIfCascadeCurrentBB) 3509d9659bf6SJohannes Doerfert ->setDebugLoc(DLoc); 3510d9659bf6SJohannes Doerfert } 3511d9659bf6SJohannes Doerfert BranchInst::Create(StateMachineEndParallelBB, 3512d9659bf6SJohannes Doerfert StateMachineIfCascadeCurrentBB) 3513d9659bf6SJohannes Doerfert ->setDebugLoc(DLoc); 3514d9659bf6SJohannes Doerfert 3515d9659bf6SJohannes Doerfert CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( 3516d9659bf6SJohannes Doerfert M, OMPRTL___kmpc_kernel_end_parallel), 3517d9659bf6SJohannes Doerfert {}, "", StateMachineEndParallelBB) 3518d9659bf6SJohannes Doerfert ->setDebugLoc(DLoc); 3519d9659bf6SJohannes Doerfert BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB) 3520d9659bf6SJohannes Doerfert ->setDebugLoc(DLoc); 3521d9659bf6SJohannes Doerfert 3522d9659bf6SJohannes Doerfert CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB) 3523d9659bf6SJohannes Doerfert ->setDebugLoc(DLoc); 3524d9659bf6SJohannes Doerfert BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB) 3525d9659bf6SJohannes Doerfert ->setDebugLoc(DLoc); 3526d9659bf6SJohannes Doerfert 3527d9659bf6SJohannes Doerfert return ChangeStatus::CHANGED; 3528d9659bf6SJohannes Doerfert } 3529d9659bf6SJohannes Doerfert 3530d9659bf6SJohannes Doerfert /// Fixpoint iteration update function. Will be called every time a dependence 3531d9659bf6SJohannes Doerfert /// changed its state (and in the beginning). 
3532d9659bf6SJohannes Doerfert ChangeStatus updateImpl(Attributor &A) override { 3533d9659bf6SJohannes Doerfert KernelInfoState StateBefore = getState(); 3534d9659bf6SJohannes Doerfert 3535514c033dSJohannes Doerfert // Callback to check a read/write instruction. 3536514c033dSJohannes Doerfert auto CheckRWInst = [&](Instruction &I) { 3537514c033dSJohannes Doerfert // We handle calls later. 3538514c033dSJohannes Doerfert if (isa<CallBase>(I)) 3539514c033dSJohannes Doerfert return true; 3540514c033dSJohannes Doerfert // We only care about write effects. 3541514c033dSJohannes Doerfert if (!I.mayWriteToMemory()) 3542514c033dSJohannes Doerfert return true; 3543514c033dSJohannes Doerfert if (auto *SI = dyn_cast<StoreInst>(&I)) { 3544514c033dSJohannes Doerfert SmallVector<const Value *> Objects; 3545514c033dSJohannes Doerfert getUnderlyingObjects(SI->getPointerOperand(), Objects); 3546514c033dSJohannes Doerfert if (llvm::all_of(Objects, 3547514c033dSJohannes Doerfert [](const Value *Obj) { return isa<AllocaInst>(Obj); })) 3548514c033dSJohannes Doerfert return true; 354929a3e3ddSGiorgis Georgakoudis // Check for AAHeapToStack moved objects which must not be guarded. 355029a3e3ddSGiorgis Georgakoudis auto &HS = A.getAAFor<AAHeapToStack>( 355129a3e3ddSGiorgis Georgakoudis *this, IRPosition::function(*I.getFunction()), 355229a3e3ddSGiorgis Georgakoudis DepClassTy::REQUIRED); 355329a3e3ddSGiorgis Georgakoudis if (llvm::all_of(Objects, [&HS](const Value *Obj) { 355429a3e3ddSGiorgis Georgakoudis auto *CB = dyn_cast<CallBase>(Obj); 355529a3e3ddSGiorgis Georgakoudis if (!CB) 355629a3e3ddSGiorgis Georgakoudis return false; 355729a3e3ddSGiorgis Georgakoudis return HS.isAssumedHeapToStack(*CB); 355829a3e3ddSGiorgis Georgakoudis })) { 355929a3e3ddSGiorgis Georgakoudis return true; 3560514c033dSJohannes Doerfert } 356129a3e3ddSGiorgis Georgakoudis } 356229a3e3ddSGiorgis Georgakoudis 356329a3e3ddSGiorgis Georgakoudis // Insert instruction that needs guarding. 
3564514c033dSJohannes Doerfert SPMDCompatibilityTracker.insert(&I); 3565514c033dSJohannes Doerfert return true; 3566514c033dSJohannes Doerfert }; 3567792aac98SJohannes Doerfert 3568792aac98SJohannes Doerfert bool UsedAssumedInformationInCheckRWInst = false; 356997387fdfSJohannes Doerfert if (!SPMDCompatibilityTracker.isAtFixpoint()) 3570792aac98SJohannes Doerfert if (!A.checkForAllReadWriteInstructions( 3571792aac98SJohannes Doerfert CheckRWInst, *this, UsedAssumedInformationInCheckRWInst)) 3572514c033dSJohannes Doerfert SPMDCompatibilityTracker.indicatePessimisticFixpoint(); 3573514c033dSJohannes Doerfert 3574e97e0a4fSShilei Tian if (!IsKernelEntry) { 3575ca662297SShilei Tian updateReachingKernelEntries(A); 3576e97e0a4fSShilei Tian updateParallelLevels(A); 357729a3e3ddSGiorgis Georgakoudis 357829a3e3ddSGiorgis Georgakoudis if (!ParallelLevels.isValidState()) 357929a3e3ddSGiorgis Georgakoudis SPMDCompatibilityTracker.indicatePessimisticFixpoint(); 3580e97e0a4fSShilei Tian } 3581ca662297SShilei Tian 3582d9659bf6SJohannes Doerfert // Callback to check a call instruction. 
358397387fdfSJohannes Doerfert bool AllSPMDStatesWereFixed = true; 3584d9659bf6SJohannes Doerfert auto CheckCallInst = [&](Instruction &I) { 3585d9659bf6SJohannes Doerfert auto &CB = cast<CallBase>(I); 3586d9659bf6SJohannes Doerfert auto &CBAA = A.getAAFor<AAKernelInfo>( 3587d9659bf6SJohannes Doerfert *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); 3588d9659bf6SJohannes Doerfert getState() ^= CBAA.getState(); 358997387fdfSJohannes Doerfert AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint(); 3590d9659bf6SJohannes Doerfert return true; 3591d9659bf6SJohannes Doerfert }; 3592d9659bf6SJohannes Doerfert 3593792aac98SJohannes Doerfert bool UsedAssumedInformationInCheckCallInst = false; 3594792aac98SJohannes Doerfert if (!A.checkForAllCallLikeInstructions( 3595792aac98SJohannes Doerfert CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) 3596d9659bf6SJohannes Doerfert return indicatePessimisticFixpoint(); 3597d9659bf6SJohannes Doerfert 359897387fdfSJohannes Doerfert // If we haven't used any assumed information for the SPMD state we can fix 359997387fdfSJohannes Doerfert // it. 360097387fdfSJohannes Doerfert if (!UsedAssumedInformationInCheckRWInst && 360197387fdfSJohannes Doerfert !UsedAssumedInformationInCheckCallInst && AllSPMDStatesWereFixed) 360297387fdfSJohannes Doerfert SPMDCompatibilityTracker.indicateOptimisticFixpoint(); 360397387fdfSJohannes Doerfert 3604d9659bf6SJohannes Doerfert return StateBefore == getState() ? ChangeStatus::UNCHANGED 3605d9659bf6SJohannes Doerfert : ChangeStatus::CHANGED; 3606d9659bf6SJohannes Doerfert } 3607ca662297SShilei Tian 3608ca662297SShilei Tian private: 3609ca662297SShilei Tian /// Update info regarding reaching kernels. 
3610ca662297SShilei Tian void updateReachingKernelEntries(Attributor &A) { 3611ca662297SShilei Tian auto PredCallSite = [&](AbstractCallSite ACS) { 3612ca662297SShilei Tian Function *Caller = ACS.getInstruction()->getFunction(); 3613ca662297SShilei Tian 3614ca662297SShilei Tian assert(Caller && "Caller is nullptr"); 3615ca662297SShilei Tian 3616d3454ee8SShilei Tian auto &CAA = A.getOrCreateAAFor<AAKernelInfo>( 3617d3454ee8SShilei Tian IRPosition::function(*Caller), this, DepClassTy::REQUIRED); 3618ca662297SShilei Tian if (CAA.ReachingKernelEntries.isValidState()) { 3619ca662297SShilei Tian ReachingKernelEntries ^= CAA.ReachingKernelEntries; 3620ca662297SShilei Tian return true; 3621ca662297SShilei Tian } 3622ca662297SShilei Tian 3623ca662297SShilei Tian // We lost track of the caller of the associated function, any kernel 3624ca662297SShilei Tian // could reach now. 3625ca662297SShilei Tian ReachingKernelEntries.indicatePessimisticFixpoint(); 3626ca662297SShilei Tian 3627ca662297SShilei Tian return true; 3628ca662297SShilei Tian }; 3629ca662297SShilei Tian 3630ca662297SShilei Tian bool AllCallSitesKnown; 3631ca662297SShilei Tian if (!A.checkForAllCallSites(PredCallSite, *this, 3632ca662297SShilei Tian true /* RequireAllCallSites */, 3633ca662297SShilei Tian AllCallSitesKnown)) 3634ca662297SShilei Tian ReachingKernelEntries.indicatePessimisticFixpoint(); 3635ca662297SShilei Tian } 3636e97e0a4fSShilei Tian 3637e97e0a4fSShilei Tian /// Update info regarding parallel levels. 
3638e97e0a4fSShilei Tian void updateParallelLevels(Attributor &A) { 3639e97e0a4fSShilei Tian auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 3640e97e0a4fSShilei Tian OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI = 3641e97e0a4fSShilei Tian OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51]; 3642e97e0a4fSShilei Tian 3643e97e0a4fSShilei Tian auto PredCallSite = [&](AbstractCallSite ACS) { 3644e97e0a4fSShilei Tian Function *Caller = ACS.getInstruction()->getFunction(); 3645e97e0a4fSShilei Tian 3646e97e0a4fSShilei Tian assert(Caller && "Caller is nullptr"); 3647e97e0a4fSShilei Tian 3648e97e0a4fSShilei Tian auto &CAA = 3649e97e0a4fSShilei Tian A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller)); 3650e97e0a4fSShilei Tian if (CAA.ParallelLevels.isValidState()) { 3651e97e0a4fSShilei Tian // Any function that is called by `__kmpc_parallel_51` will not be 3652e97e0a4fSShilei Tian // folded as the parallel level in the function is updated. In order to 3653e97e0a4fSShilei Tian // get it right, all the analysis would depend on the implentation. That 3654e97e0a4fSShilei Tian // said, if in the future any change to the implementation, the analysis 3655e97e0a4fSShilei Tian // could be wrong. As a consequence, we are just conservative here. 3656e97e0a4fSShilei Tian if (Caller == Parallel51RFI.Declaration) { 3657e97e0a4fSShilei Tian ParallelLevels.indicatePessimisticFixpoint(); 3658e97e0a4fSShilei Tian return true; 3659e97e0a4fSShilei Tian } 3660e97e0a4fSShilei Tian 3661e97e0a4fSShilei Tian ParallelLevels ^= CAA.ParallelLevels; 3662e97e0a4fSShilei Tian 3663e97e0a4fSShilei Tian return true; 3664e97e0a4fSShilei Tian } 3665e97e0a4fSShilei Tian 3666e97e0a4fSShilei Tian // We lost track of the caller of the associated function, any kernel 3667e97e0a4fSShilei Tian // could reach now. 
3668e97e0a4fSShilei Tian ParallelLevels.indicatePessimisticFixpoint(); 3669e97e0a4fSShilei Tian 3670e97e0a4fSShilei Tian return true; 3671e97e0a4fSShilei Tian }; 3672e97e0a4fSShilei Tian 3673e97e0a4fSShilei Tian bool AllCallSitesKnown = true; 3674e97e0a4fSShilei Tian if (!A.checkForAllCallSites(PredCallSite, *this, 3675e97e0a4fSShilei Tian true /* RequireAllCallSites */, 3676e97e0a4fSShilei Tian AllCallSitesKnown)) 3677e97e0a4fSShilei Tian ParallelLevels.indicatePessimisticFixpoint(); 3678e97e0a4fSShilei Tian } 3679d9659bf6SJohannes Doerfert }; 3680d9659bf6SJohannes Doerfert 3681d9659bf6SJohannes Doerfert /// The call site kernel info abstract attribute, basically, what can we say 3682d9659bf6SJohannes Doerfert /// about a call site with regards to the KernelInfoState. For now this simply 3683d9659bf6SJohannes Doerfert /// forwards the information from the callee. 3684d9659bf6SJohannes Doerfert struct AAKernelInfoCallSite : AAKernelInfo { 3685d9659bf6SJohannes Doerfert AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A) 3686d9659bf6SJohannes Doerfert : AAKernelInfo(IRP, A) {} 3687d9659bf6SJohannes Doerfert 3688d9659bf6SJohannes Doerfert /// See AbstractAttribute::initialize(...). 3689d9659bf6SJohannes Doerfert void initialize(Attributor &A) override { 3690d9659bf6SJohannes Doerfert AAKernelInfo::initialize(A); 3691d9659bf6SJohannes Doerfert 3692d9659bf6SJohannes Doerfert CallBase &CB = cast<CallBase>(getAssociatedValue()); 3693d9659bf6SJohannes Doerfert Function *Callee = getAssociatedFunction(); 3694d9659bf6SJohannes Doerfert 3695d9659bf6SJohannes Doerfert // Helper to lookup an assumption string. 3696d9659bf6SJohannes Doerfert auto HasAssumption = [](Function *Fn, StringRef AssumptionStr) { 3697d9659bf6SJohannes Doerfert return Fn && hasAssumption(*Fn, AssumptionStr); 3698d9659bf6SJohannes Doerfert }; 3699d9659bf6SJohannes Doerfert 3700514c033dSJohannes Doerfert // Check for SPMD-mode assumptions. 
3701514c033dSJohannes Doerfert if (HasAssumption(Callee, "ompx_spmd_amenable")) 3702514c033dSJohannes Doerfert SPMDCompatibilityTracker.indicateOptimisticFixpoint(); 3703514c033dSJohannes Doerfert 3704d9659bf6SJohannes Doerfert // First weed out calls we do not care about, that is readonly/readnone 3705d9659bf6SJohannes Doerfert // calls, intrinsics, and "no_openmp" calls. Neither of these can reach a 3706d9659bf6SJohannes Doerfert // parallel region or anything else we are looking for. 3707d9659bf6SJohannes Doerfert if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) { 3708d9659bf6SJohannes Doerfert indicateOptimisticFixpoint(); 3709d9659bf6SJohannes Doerfert return; 3710d9659bf6SJohannes Doerfert } 3711d9659bf6SJohannes Doerfert 3712d9659bf6SJohannes Doerfert // Next we check if we know the callee. If it is a known OpenMP function 3713d9659bf6SJohannes Doerfert // we will handle them explicitly in the switch below. If it is not, we 3714d9659bf6SJohannes Doerfert // will use an AAKernelInfo object on the callee to gather information and 3715d9659bf6SJohannes Doerfert // merge that into the current state. The latter happens in the updateImpl. 3716d9659bf6SJohannes Doerfert auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 3717d9659bf6SJohannes Doerfert const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee); 3718d9659bf6SJohannes Doerfert if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { 3719d9659bf6SJohannes Doerfert // Unknown caller or declarations are not analyzable, we give up. 3720d9659bf6SJohannes Doerfert if (!Callee || !A.isFunctionIPOAmendable(*Callee)) { 3721d9659bf6SJohannes Doerfert 3722d9659bf6SJohannes Doerfert // Unknown callees might contain parallel regions, except if they have 3723d9659bf6SJohannes Doerfert // an appropriate assumption attached. 
3724d9659bf6SJohannes Doerfert if (!(HasAssumption(Callee, "omp_no_openmp") || 3725d9659bf6SJohannes Doerfert HasAssumption(Callee, "omp_no_parallelism"))) 3726d9659bf6SJohannes Doerfert ReachedUnknownParallelRegions.insert(&CB); 3727d9659bf6SJohannes Doerfert 3728514c033dSJohannes Doerfert // If SPMDCompatibilityTracker is not fixed, we need to give up on the 3729514c033dSJohannes Doerfert // idea we can run something unknown in SPMD-mode. 373029a3e3ddSGiorgis Georgakoudis if (!SPMDCompatibilityTracker.isAtFixpoint()) { 373129a3e3ddSGiorgis Georgakoudis SPMDCompatibilityTracker.indicatePessimisticFixpoint(); 3732514c033dSJohannes Doerfert SPMDCompatibilityTracker.insert(&CB); 373329a3e3ddSGiorgis Georgakoudis } 3734514c033dSJohannes Doerfert 3735d9659bf6SJohannes Doerfert // We have updated the state for this unknown call properly, there won't 3736d9659bf6SJohannes Doerfert // be any change so we indicate a fixpoint. 3737d9659bf6SJohannes Doerfert indicateOptimisticFixpoint(); 3738d9659bf6SJohannes Doerfert } 3739d9659bf6SJohannes Doerfert // If the callee is known and can be used in IPO, we will update the state 3740d9659bf6SJohannes Doerfert // based on the callee state in updateImpl. 3741d9659bf6SJohannes Doerfert return; 3742d9659bf6SJohannes Doerfert } 3743d9659bf6SJohannes Doerfert 3744d9659bf6SJohannes Doerfert const unsigned int WrapperFunctionArgNo = 6; 3745d9659bf6SJohannes Doerfert RuntimeFunction RF = It->getSecond(); 3746d9659bf6SJohannes Doerfert switch (RF) { 3747514c033dSJohannes Doerfert // All the functions we know are compatible with SPMD mode. 
3748514c033dSJohannes Doerfert case OMPRTL___kmpc_is_spmd_exec_mode: 3749514c033dSJohannes Doerfert case OMPRTL___kmpc_for_static_fini: 3750514c033dSJohannes Doerfert case OMPRTL___kmpc_global_thread_num: 37515ab6aeddSJose M Monsalve Diaz case OMPRTL___kmpc_get_hardware_num_threads_in_block: 37525ab6aeddSJose M Monsalve Diaz case OMPRTL___kmpc_get_hardware_num_blocks: 3753514c033dSJohannes Doerfert case OMPRTL___kmpc_single: 3754514c033dSJohannes Doerfert case OMPRTL___kmpc_end_single: 3755514c033dSJohannes Doerfert case OMPRTL___kmpc_master: 3756514c033dSJohannes Doerfert case OMPRTL___kmpc_end_master: 3757514c033dSJohannes Doerfert case OMPRTL___kmpc_barrier: 3758514c033dSJohannes Doerfert break; 3759514c033dSJohannes Doerfert case OMPRTL___kmpc_for_static_init_4: 3760514c033dSJohannes Doerfert case OMPRTL___kmpc_for_static_init_4u: 3761514c033dSJohannes Doerfert case OMPRTL___kmpc_for_static_init_8: 3762514c033dSJohannes Doerfert case OMPRTL___kmpc_for_static_init_8u: { 3763514c033dSJohannes Doerfert // Check the schedule and allow static schedule in SPMD mode. 3764514c033dSJohannes Doerfert unsigned ScheduleArgOpNo = 2; 3765514c033dSJohannes Doerfert auto *ScheduleTypeCI = 3766514c033dSJohannes Doerfert dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo)); 3767514c033dSJohannes Doerfert unsigned ScheduleTypeVal = 3768514c033dSJohannes Doerfert ScheduleTypeCI ? 
ScheduleTypeCI->getZExtValue() : 0; 3769514c033dSJohannes Doerfert switch (OMPScheduleType(ScheduleTypeVal)) { 3770514c033dSJohannes Doerfert case OMPScheduleType::Static: 3771514c033dSJohannes Doerfert case OMPScheduleType::StaticChunked: 3772514c033dSJohannes Doerfert case OMPScheduleType::Distribute: 3773514c033dSJohannes Doerfert case OMPScheduleType::DistributeChunked: 3774514c033dSJohannes Doerfert break; 3775514c033dSJohannes Doerfert default: 377629a3e3ddSGiorgis Georgakoudis SPMDCompatibilityTracker.indicatePessimisticFixpoint(); 3777514c033dSJohannes Doerfert SPMDCompatibilityTracker.insert(&CB); 3778514c033dSJohannes Doerfert break; 3779514c033dSJohannes Doerfert }; 3780514c033dSJohannes Doerfert } break; 3781d9659bf6SJohannes Doerfert case OMPRTL___kmpc_target_init: 3782d9659bf6SJohannes Doerfert KernelInitCB = &CB; 3783d9659bf6SJohannes Doerfert break; 3784d9659bf6SJohannes Doerfert case OMPRTL___kmpc_target_deinit: 3785d9659bf6SJohannes Doerfert KernelDeinitCB = &CB; 3786d9659bf6SJohannes Doerfert break; 3787d9659bf6SJohannes Doerfert case OMPRTL___kmpc_parallel_51: 3788d9659bf6SJohannes Doerfert if (auto *ParallelRegion = dyn_cast<Function>( 3789d9659bf6SJohannes Doerfert CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) { 3790d9659bf6SJohannes Doerfert ReachedKnownParallelRegions.insert(ParallelRegion); 3791d9659bf6SJohannes Doerfert break; 3792d9659bf6SJohannes Doerfert } 3793d9659bf6SJohannes Doerfert // The condition above should usually get the parallel region function 3794d9659bf6SJohannes Doerfert // pointer and record it. In the off chance it doesn't we assume the 3795d9659bf6SJohannes Doerfert // worst. 3796d9659bf6SJohannes Doerfert ReachedUnknownParallelRegions.insert(&CB); 3797d9659bf6SJohannes Doerfert break; 3798d9659bf6SJohannes Doerfert case OMPRTL___kmpc_omp_task: 3799d9659bf6SJohannes Doerfert // We do not look into tasks right now, just give up. 
3800514c033dSJohannes Doerfert SPMDCompatibilityTracker.insert(&CB); 3801d9659bf6SJohannes Doerfert ReachedUnknownParallelRegions.insert(&CB); 380229a3e3ddSGiorgis Georgakoudis indicatePessimisticFixpoint(); 380329a3e3ddSGiorgis Georgakoudis return; 3804f8c40ed8SGiorgis Georgakoudis case OMPRTL___kmpc_alloc_shared: 3805f8c40ed8SGiorgis Georgakoudis case OMPRTL___kmpc_free_shared: 3806f8c40ed8SGiorgis Georgakoudis // Return without setting a fixpoint, to be resolved in updateImpl. 3807f8c40ed8SGiorgis Georgakoudis return; 3808d9659bf6SJohannes Doerfert default: 3809514c033dSJohannes Doerfert // Unknown OpenMP runtime calls cannot be executed in SPMD-mode, 3810514c033dSJohannes Doerfert // generally. 3811514c033dSJohannes Doerfert SPMDCompatibilityTracker.insert(&CB); 381229a3e3ddSGiorgis Georgakoudis indicatePessimisticFixpoint(); 381329a3e3ddSGiorgis Georgakoudis return; 3814d9659bf6SJohannes Doerfert } 3815d9659bf6SJohannes Doerfert // All other OpenMP runtime calls will not reach parallel regions so they 3816d9659bf6SJohannes Doerfert // can be safely ignored for now. Since it is a known OpenMP runtime call we 3817d9659bf6SJohannes Doerfert // have now modeled all effects and there is no need for any update. 3818d9659bf6SJohannes Doerfert indicateOptimisticFixpoint(); 3819d9659bf6SJohannes Doerfert } 3820d9659bf6SJohannes Doerfert 3821d9659bf6SJohannes Doerfert ChangeStatus updateImpl(Attributor &A) override { 3822d9659bf6SJohannes Doerfert // TODO: Once we have call site specific value information we can provide 3823d9659bf6SJohannes Doerfert // call site specific liveness information and then it makes 3824d9659bf6SJohannes Doerfert // sense to specialize attributes for call sites arguments instead of 3825d9659bf6SJohannes Doerfert // redirecting requests to the callee argument. 
3826d9659bf6SJohannes Doerfert Function *F = getAssociatedFunction(); 3827f8c40ed8SGiorgis Georgakoudis 3828f8c40ed8SGiorgis Georgakoudis auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 3829f8c40ed8SGiorgis Georgakoudis const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F); 3830f8c40ed8SGiorgis Georgakoudis 3831f8c40ed8SGiorgis Georgakoudis // If F is not a runtime function, propagate the AAKernelInfo of the callee. 3832f8c40ed8SGiorgis Georgakoudis if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { 3833d9659bf6SJohannes Doerfert const IRPosition &FnPos = IRPosition::function(*F); 3834d9659bf6SJohannes Doerfert auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED); 3835d9659bf6SJohannes Doerfert if (getState() == FnAA.getState()) 3836d9659bf6SJohannes Doerfert return ChangeStatus::UNCHANGED; 3837d9659bf6SJohannes Doerfert getState() = FnAA.getState(); 3838d9659bf6SJohannes Doerfert return ChangeStatus::CHANGED; 3839d9659bf6SJohannes Doerfert } 3840f8c40ed8SGiorgis Georgakoudis 3841f8c40ed8SGiorgis Georgakoudis // F is a runtime function that allocates or frees memory, check 3842f8c40ed8SGiorgis Georgakoudis // AAHeapToStack and AAHeapToShared. 
3843f8c40ed8SGiorgis Georgakoudis KernelInfoState StateBefore = getState(); 3844f8c40ed8SGiorgis Georgakoudis assert((It->getSecond() == OMPRTL___kmpc_alloc_shared || 3845f8c40ed8SGiorgis Georgakoudis It->getSecond() == OMPRTL___kmpc_free_shared) && 3846f8c40ed8SGiorgis Georgakoudis "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call"); 3847f8c40ed8SGiorgis Georgakoudis 3848f8c40ed8SGiorgis Georgakoudis CallBase &CB = cast<CallBase>(getAssociatedValue()); 3849f8c40ed8SGiorgis Georgakoudis 3850f8c40ed8SGiorgis Georgakoudis auto &HeapToStackAA = A.getAAFor<AAHeapToStack>( 3851f8c40ed8SGiorgis Georgakoudis *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); 3852f8c40ed8SGiorgis Georgakoudis auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>( 3853f8c40ed8SGiorgis Georgakoudis *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); 3854f8c40ed8SGiorgis Georgakoudis 3855f8c40ed8SGiorgis Georgakoudis RuntimeFunction RF = It->getSecond(); 3856f8c40ed8SGiorgis Georgakoudis 3857f8c40ed8SGiorgis Georgakoudis switch (RF) { 3858f8c40ed8SGiorgis Georgakoudis // If neither HeapToStack nor HeapToShared assume the call is removed, 3859f8c40ed8SGiorgis Georgakoudis // assume SPMD incompatibility. 
3860f8c40ed8SGiorgis Georgakoudis case OMPRTL___kmpc_alloc_shared: 3861f8c40ed8SGiorgis Georgakoudis if (!HeapToStackAA.isAssumedHeapToStack(CB) && 3862f8c40ed8SGiorgis Georgakoudis !HeapToSharedAA.isAssumedHeapToShared(CB)) 3863f8c40ed8SGiorgis Georgakoudis SPMDCompatibilityTracker.insert(&CB); 3864f8c40ed8SGiorgis Georgakoudis break; 3865f8c40ed8SGiorgis Georgakoudis case OMPRTL___kmpc_free_shared: 3866f8c40ed8SGiorgis Georgakoudis if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) && 3867f8c40ed8SGiorgis Georgakoudis !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB)) 3868f8c40ed8SGiorgis Georgakoudis SPMDCompatibilityTracker.insert(&CB); 3869f8c40ed8SGiorgis Georgakoudis break; 3870f8c40ed8SGiorgis Georgakoudis default: 3871f8c40ed8SGiorgis Georgakoudis SPMDCompatibilityTracker.insert(&CB); 3872f8c40ed8SGiorgis Georgakoudis } 3873f8c40ed8SGiorgis Georgakoudis 3874f8c40ed8SGiorgis Georgakoudis return StateBefore == getState() ? ChangeStatus::UNCHANGED 3875f8c40ed8SGiorgis Georgakoudis : ChangeStatus::CHANGED; 3876f8c40ed8SGiorgis Georgakoudis } 3877d9659bf6SJohannes Doerfert }; 3878d9659bf6SJohannes Doerfert 3879ca662297SShilei Tian struct AAFoldRuntimeCall 3880ca662297SShilei Tian : public StateWrapper<BooleanState, AbstractAttribute> { 3881ca662297SShilei Tian using Base = StateWrapper<BooleanState, AbstractAttribute>; 3882ca662297SShilei Tian 3883ca662297SShilei Tian AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {} 3884ca662297SShilei Tian 3885ca662297SShilei Tian /// Statistics are tracked as part of manifest for now. 3886ca662297SShilei Tian void trackStatistics() const override {} 3887ca662297SShilei Tian 3888ca662297SShilei Tian /// Create an abstract attribute biew for the position \p IRP. 
3889ca662297SShilei Tian static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP, 3890ca662297SShilei Tian Attributor &A); 3891ca662297SShilei Tian 3892ca662297SShilei Tian /// See AbstractAttribute::getName() 3893ca662297SShilei Tian const std::string getName() const override { return "AAFoldRuntimeCall"; } 3894ca662297SShilei Tian 3895ca662297SShilei Tian /// See AbstractAttribute::getIdAddr() 3896ca662297SShilei Tian const char *getIdAddr() const override { return &ID; } 3897ca662297SShilei Tian 3898ca662297SShilei Tian /// This function should return true if the type of the \p AA is 3899ca662297SShilei Tian /// AAFoldRuntimeCall 3900ca662297SShilei Tian static bool classof(const AbstractAttribute *AA) { 3901ca662297SShilei Tian return (AA->getIdAddr() == &ID); 3902ca662297SShilei Tian } 3903ca662297SShilei Tian 3904ca662297SShilei Tian static const char ID; 3905ca662297SShilei Tian }; 3906ca662297SShilei Tian 3907ca662297SShilei Tian struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall { 3908ca662297SShilei Tian AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A) 3909ca662297SShilei Tian : AAFoldRuntimeCall(IRP, A) {} 3910ca662297SShilei Tian 3911ca662297SShilei Tian /// See AbstractAttribute::getAsStr() 3912ca662297SShilei Tian const std::string getAsStr() const override { 3913ca662297SShilei Tian if (!isValidState()) 3914ca662297SShilei Tian return "<invalid>"; 3915ca662297SShilei Tian 3916ca662297SShilei Tian std::string Str("simplified value: "); 3917ca662297SShilei Tian 3918ca662297SShilei Tian if (!SimplifiedValue.hasValue()) 3919ca662297SShilei Tian return Str + std::string("none"); 3920ca662297SShilei Tian 3921ca662297SShilei Tian if (!SimplifiedValue.getValue()) 3922ca662297SShilei Tian return Str + std::string("nullptr"); 3923ca662297SShilei Tian 3924ca662297SShilei Tian if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue())) 3925ca662297SShilei Tian return Str + 
std::to_string(CI->getSExtValue()); 3926ca662297SShilei Tian 3927ca662297SShilei Tian return Str + std::string("unknown"); 3928ca662297SShilei Tian } 3929ca662297SShilei Tian 3930ca662297SShilei Tian void initialize(Attributor &A) override { 3931cd0dd8ecSJoseph Huber if (DisableOpenMPOptFolding) 3932cd0dd8ecSJoseph Huber indicatePessimisticFixpoint(); 3933cd0dd8ecSJoseph Huber 3934ca662297SShilei Tian Function *Callee = getAssociatedFunction(); 3935ca662297SShilei Tian 3936ca662297SShilei Tian auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 3937ca662297SShilei Tian const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee); 3938ca662297SShilei Tian assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() && 3939ca662297SShilei Tian "Expected a known OpenMP runtime function"); 3940ca662297SShilei Tian 3941ca662297SShilei Tian RFKind = It->getSecond(); 3942ca662297SShilei Tian 3943ca662297SShilei Tian CallBase &CB = cast<CallBase>(getAssociatedValue()); 3944ca662297SShilei Tian A.registerSimplificationCallback( 3945ca662297SShilei Tian IRPosition::callsite_returned(CB), 3946ca662297SShilei Tian [&](const IRPosition &IRP, const AbstractAttribute *AA, 3947ca662297SShilei Tian bool &UsedAssumedInformation) -> Optional<Value *> { 3948ca662297SShilei Tian assert((isValidState() || (SimplifiedValue.hasValue() && 3949ca662297SShilei Tian SimplifiedValue.getValue() == nullptr)) && 3950ca662297SShilei Tian "Unexpected invalid state!"); 3951ca662297SShilei Tian 3952ca662297SShilei Tian if (!isAtFixpoint()) { 3953ca662297SShilei Tian UsedAssumedInformation = true; 3954ca662297SShilei Tian if (AA) 3955ca662297SShilei Tian A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); 3956ca662297SShilei Tian } 3957ca662297SShilei Tian return SimplifiedValue; 3958ca662297SShilei Tian }); 3959ca662297SShilei Tian } 3960ca662297SShilei Tian 3961ca662297SShilei Tian ChangeStatus updateImpl(Attributor &A) override { 3962ca662297SShilei Tian ChangeStatus Changed = 
ChangeStatus::UNCHANGED; 3963ca662297SShilei Tian switch (RFKind) { 3964ca662297SShilei Tian case OMPRTL___kmpc_is_spmd_exec_mode: 3965c23da666SShilei Tian Changed |= foldIsSPMDExecMode(A); 3966ca662297SShilei Tian break; 3967196fe994SJoseph Huber case OMPRTL___kmpc_is_generic_main_thread_id: 3968196fe994SJoseph Huber Changed |= foldIsGenericMainThread(A); 3969196fe994SJoseph Huber break; 3970e97e0a4fSShilei Tian case OMPRTL___kmpc_parallel_level: 3971e97e0a4fSShilei Tian Changed |= foldParallelLevel(A); 3972e97e0a4fSShilei Tian break; 39735ab6aeddSJose M Monsalve Diaz case OMPRTL___kmpc_get_hardware_num_threads_in_block: 39745ab6aeddSJose M Monsalve Diaz Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit"); 39755ab6aeddSJose M Monsalve Diaz break; 39765ab6aeddSJose M Monsalve Diaz case OMPRTL___kmpc_get_hardware_num_blocks: 39775ab6aeddSJose M Monsalve Diaz Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams"); 39785ab6aeddSJose M Monsalve Diaz break; 3979ca662297SShilei Tian default: 3980ca662297SShilei Tian llvm_unreachable("Unhandled OpenMP runtime function!"); 3981ca662297SShilei Tian } 3982ca662297SShilei Tian 3983ca662297SShilei Tian return Changed; 3984ca662297SShilei Tian } 3985ca662297SShilei Tian 3986ca662297SShilei Tian ChangeStatus manifest(Attributor &A) override { 3987ca662297SShilei Tian ChangeStatus Changed = ChangeStatus::UNCHANGED; 3988ca662297SShilei Tian 3989ca662297SShilei Tian if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) { 3990ca662297SShilei Tian Instruction &CB = *getCtxI(); 3991ca662297SShilei Tian A.changeValueAfterManifest(CB, **SimplifiedValue); 3992ca662297SShilei Tian A.deleteAfterManifest(CB); 3993196fe994SJoseph Huber 3994196fe994SJoseph Huber LLVM_DEBUG(dbgs() << TAG << "Folding runtime call: " << CB << " with " 3995196fe994SJoseph Huber << **SimplifiedValue << "\n"); 3996196fe994SJoseph Huber 3997ca662297SShilei Tian Changed = ChangeStatus::CHANGED; 3998ca662297SShilei Tian } 
3999ca662297SShilei Tian 4000ca662297SShilei Tian return Changed; 4001ca662297SShilei Tian } 4002ca662297SShilei Tian 4003ca662297SShilei Tian ChangeStatus indicatePessimisticFixpoint() override { 4004ca662297SShilei Tian SimplifiedValue = nullptr; 4005ca662297SShilei Tian return AAFoldRuntimeCall::indicatePessimisticFixpoint(); 4006ca662297SShilei Tian } 4007ca662297SShilei Tian 4008ca662297SShilei Tian private: 4009ca662297SShilei Tian /// Fold __kmpc_is_spmd_exec_mode into a constant if possible. 4010ca662297SShilei Tian ChangeStatus foldIsSPMDExecMode(Attributor &A) { 4011ca662297SShilei Tian Optional<Value *> SimplifiedValueBefore = SimplifiedValue; 4012ca662297SShilei Tian 4013ca662297SShilei Tian unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0; 4014ca662297SShilei Tian unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0; 4015ca662297SShilei Tian auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( 4016ca662297SShilei Tian *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 4017ca662297SShilei Tian 4018ca662297SShilei Tian if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState()) 4019ca662297SShilei Tian return indicatePessimisticFixpoint(); 4020ca662297SShilei Tian 4021ca662297SShilei Tian for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) { 4022ca662297SShilei Tian auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K), 4023ca662297SShilei Tian DepClassTy::REQUIRED); 4024ca662297SShilei Tian 4025ca662297SShilei Tian if (!AA.isValidState()) { 4026ca662297SShilei Tian SimplifiedValue = nullptr; 4027ca662297SShilei Tian return indicatePessimisticFixpoint(); 4028ca662297SShilei Tian } 4029ca662297SShilei Tian 4030ca662297SShilei Tian if (AA.SPMDCompatibilityTracker.isAssumed()) { 4031ca662297SShilei Tian if (AA.SPMDCompatibilityTracker.isAtFixpoint()) 4032ca662297SShilei Tian ++KnownSPMDCount; 4033ca662297SShilei Tian else 4034ca662297SShilei Tian ++AssumedSPMDCount; 4035ca662297SShilei Tian } else { 
4036ca662297SShilei Tian if (AA.SPMDCompatibilityTracker.isAtFixpoint()) 4037ca662297SShilei Tian ++KnownNonSPMDCount; 4038ca662297SShilei Tian else 4039ca662297SShilei Tian ++AssumedNonSPMDCount; 4040ca662297SShilei Tian } 4041ca662297SShilei Tian } 4042ca662297SShilei Tian 4043ae69f468SShilei Tian if ((AssumedSPMDCount + KnownSPMDCount) && 4044ae69f468SShilei Tian (AssumedNonSPMDCount + KnownNonSPMDCount)) 4045ca662297SShilei Tian return indicatePessimisticFixpoint(); 4046ca662297SShilei Tian 4047ca662297SShilei Tian auto &Ctx = getAnchorValue().getContext(); 4048ca662297SShilei Tian if (KnownSPMDCount || AssumedSPMDCount) { 4049ca662297SShilei Tian assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 && 4050ca662297SShilei Tian "Expected only SPMD kernels!"); 4051ca662297SShilei Tian // All reaching kernels are in SPMD mode. Update all function calls to 4052ca662297SShilei Tian // __kmpc_is_spmd_exec_mode to 1. 4053ca662297SShilei Tian SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true); 4054d3454ee8SShilei Tian } else if (KnownNonSPMDCount || AssumedNonSPMDCount) { 4055ca662297SShilei Tian assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 && 4056ca662297SShilei Tian "Expected only non-SPMD kernels!"); 4057ca662297SShilei Tian // All reaching kernels are in non-SPMD mode. Update all function 4058ca662297SShilei Tian // calls to __kmpc_is_spmd_exec_mode to 0. 4059ca662297SShilei Tian SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false); 4060d3454ee8SShilei Tian } else { 4061d3454ee8SShilei Tian // We have empty reaching kernels, therefore we cannot tell if the 4062d3454ee8SShilei Tian // associated call site can be folded. At this moment, SimplifiedValue 4063d3454ee8SShilei Tian // must be none. 4064d3454ee8SShilei Tian assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none"); 4065ca662297SShilei Tian } 4066ca662297SShilei Tian 4067ca662297SShilei Tian return SimplifiedValue == SimplifiedValueBefore ? 
ChangeStatus::UNCHANGED 4068ca662297SShilei Tian : ChangeStatus::CHANGED; 4069ca662297SShilei Tian } 4070ca662297SShilei Tian 4071196fe994SJoseph Huber /// Fold __kmpc_is_generic_main_thread_id into a constant if possible. 4072196fe994SJoseph Huber ChangeStatus foldIsGenericMainThread(Attributor &A) { 4073196fe994SJoseph Huber Optional<Value *> SimplifiedValueBefore = SimplifiedValue; 4074196fe994SJoseph Huber 4075196fe994SJoseph Huber CallBase &CB = cast<CallBase>(getAssociatedValue()); 4076196fe994SJoseph Huber Function *F = CB.getFunction(); 4077196fe994SJoseph Huber const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>( 4078196fe994SJoseph Huber *this, IRPosition::function(*F), DepClassTy::REQUIRED); 4079196fe994SJoseph Huber 4080196fe994SJoseph Huber if (!ExecutionDomainAA.isValidState()) 4081196fe994SJoseph Huber return indicatePessimisticFixpoint(); 4082196fe994SJoseph Huber 4083196fe994SJoseph Huber auto &Ctx = getAnchorValue().getContext(); 4084196fe994SJoseph Huber if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB)) 4085196fe994SJoseph Huber SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true); 4086196fe994SJoseph Huber else 4087196fe994SJoseph Huber return indicatePessimisticFixpoint(); 4088196fe994SJoseph Huber 4089196fe994SJoseph Huber return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED 4090196fe994SJoseph Huber : ChangeStatus::CHANGED; 4091196fe994SJoseph Huber } 4092196fe994SJoseph Huber 4093e97e0a4fSShilei Tian /// Fold __kmpc_parallel_level into a constant if possible. 
4094e97e0a4fSShilei Tian ChangeStatus foldParallelLevel(Attributor &A) { 4095e97e0a4fSShilei Tian Optional<Value *> SimplifiedValueBefore = SimplifiedValue; 4096e97e0a4fSShilei Tian 4097e97e0a4fSShilei Tian auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( 4098e97e0a4fSShilei Tian *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 4099e97e0a4fSShilei Tian 4100e97e0a4fSShilei Tian if (!CallerKernelInfoAA.ParallelLevels.isValidState()) 4101e97e0a4fSShilei Tian return indicatePessimisticFixpoint(); 4102e97e0a4fSShilei Tian 4103e97e0a4fSShilei Tian if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState()) 4104e97e0a4fSShilei Tian return indicatePessimisticFixpoint(); 4105e97e0a4fSShilei Tian 4106e97e0a4fSShilei Tian if (CallerKernelInfoAA.ReachingKernelEntries.empty()) { 4107e97e0a4fSShilei Tian assert(!SimplifiedValue.hasValue() && 4108e97e0a4fSShilei Tian "SimplifiedValue should keep none at this point"); 4109e97e0a4fSShilei Tian return ChangeStatus::UNCHANGED; 4110e97e0a4fSShilei Tian } 4111e97e0a4fSShilei Tian 4112e97e0a4fSShilei Tian unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0; 4113e97e0a4fSShilei Tian unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0; 4114e97e0a4fSShilei Tian for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) { 4115e97e0a4fSShilei Tian auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K), 4116e97e0a4fSShilei Tian DepClassTy::REQUIRED); 4117e97e0a4fSShilei Tian if (!AA.SPMDCompatibilityTracker.isValidState()) 4118e97e0a4fSShilei Tian return indicatePessimisticFixpoint(); 4119e97e0a4fSShilei Tian 4120e97e0a4fSShilei Tian if (AA.SPMDCompatibilityTracker.isAssumed()) { 4121e97e0a4fSShilei Tian if (AA.SPMDCompatibilityTracker.isAtFixpoint()) 4122e97e0a4fSShilei Tian ++KnownSPMDCount; 4123e97e0a4fSShilei Tian else 4124e97e0a4fSShilei Tian ++AssumedSPMDCount; 4125e97e0a4fSShilei Tian } else { 4126e97e0a4fSShilei Tian if (AA.SPMDCompatibilityTracker.isAtFixpoint()) 4127e97e0a4fSShilei Tian 
++KnownNonSPMDCount; 4128e97e0a4fSShilei Tian else 4129e97e0a4fSShilei Tian ++AssumedNonSPMDCount; 4130e97e0a4fSShilei Tian } 4131e97e0a4fSShilei Tian } 4132e97e0a4fSShilei Tian 4133e97e0a4fSShilei Tian if ((AssumedSPMDCount + KnownSPMDCount) && 4134e97e0a4fSShilei Tian (AssumedNonSPMDCount + KnownNonSPMDCount)) 4135e97e0a4fSShilei Tian return indicatePessimisticFixpoint(); 4136e97e0a4fSShilei Tian 4137e97e0a4fSShilei Tian auto &Ctx = getAnchorValue().getContext(); 4138e97e0a4fSShilei Tian // If the caller can only be reached by SPMD kernel entries, the parallel 4139e97e0a4fSShilei Tian // level is 1. Similarly, if the caller can only be reached by non-SPMD 4140e97e0a4fSShilei Tian // kernel entries, it is 0. 4141e97e0a4fSShilei Tian if (AssumedSPMDCount || KnownSPMDCount) { 4142e97e0a4fSShilei Tian assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 && 4143e97e0a4fSShilei Tian "Expected only SPMD kernels!"); 4144e97e0a4fSShilei Tian SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1); 4145e97e0a4fSShilei Tian } else { 4146e97e0a4fSShilei Tian assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 && 4147e97e0a4fSShilei Tian "Expected only non-SPMD kernels!"); 4148e97e0a4fSShilei Tian SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0); 4149e97e0a4fSShilei Tian } 41505ab6aeddSJose M Monsalve Diaz return SimplifiedValue == SimplifiedValueBefore ? 
ChangeStatus::UNCHANGED 41515ab6aeddSJose M Monsalve Diaz : ChangeStatus::CHANGED; 41525ab6aeddSJose M Monsalve Diaz } 4153e97e0a4fSShilei Tian 41545ab6aeddSJose M Monsalve Diaz ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) { 41555ab6aeddSJose M Monsalve Diaz // Specialize only if all the calls agree with the attribute constant value 41565ab6aeddSJose M Monsalve Diaz int32_t CurrentAttrValue = -1; 41575ab6aeddSJose M Monsalve Diaz Optional<Value *> SimplifiedValueBefore = SimplifiedValue; 41585ab6aeddSJose M Monsalve Diaz 41595ab6aeddSJose M Monsalve Diaz auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( 41605ab6aeddSJose M Monsalve Diaz *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 41615ab6aeddSJose M Monsalve Diaz 41625ab6aeddSJose M Monsalve Diaz if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState()) 41635ab6aeddSJose M Monsalve Diaz return indicatePessimisticFixpoint(); 41645ab6aeddSJose M Monsalve Diaz 41655ab6aeddSJose M Monsalve Diaz // Iterate over the kernels that reach this function 41665ab6aeddSJose M Monsalve Diaz for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) { 41675ab6aeddSJose M Monsalve Diaz int32_t NextAttrVal = -1; 41685ab6aeddSJose M Monsalve Diaz if (K->hasFnAttribute(Attr)) 41695ab6aeddSJose M Monsalve Diaz NextAttrVal = 41705ab6aeddSJose M Monsalve Diaz std::stoi(K->getFnAttribute(Attr).getValueAsString().str()); 41715ab6aeddSJose M Monsalve Diaz 41725ab6aeddSJose M Monsalve Diaz if (NextAttrVal == -1 || 41735ab6aeddSJose M Monsalve Diaz (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal)) 41745ab6aeddSJose M Monsalve Diaz return indicatePessimisticFixpoint(); 41755ab6aeddSJose M Monsalve Diaz CurrentAttrValue = NextAttrVal; 41765ab6aeddSJose M Monsalve Diaz } 41775ab6aeddSJose M Monsalve Diaz 41785ab6aeddSJose M Monsalve Diaz if (CurrentAttrValue != -1) { 41795ab6aeddSJose M Monsalve Diaz auto &Ctx = getAnchorValue().getContext(); 41805ab6aeddSJose M Monsalve Diaz 
SimplifiedValue = 41815ab6aeddSJose M Monsalve Diaz ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue); 41825ab6aeddSJose M Monsalve Diaz } 4183e97e0a4fSShilei Tian return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED 4184e97e0a4fSShilei Tian : ChangeStatus::CHANGED; 4185e97e0a4fSShilei Tian } 4186e97e0a4fSShilei Tian 4187ca662297SShilei Tian /// An optional value the associated value is assumed to fold to. That is, we 4188ca662297SShilei Tian /// assume the associated value (which is a call) can be replaced by this 4189ca662297SShilei Tian /// simplified value. 4190ca662297SShilei Tian Optional<Value *> SimplifiedValue; 4191ca662297SShilei Tian 4192ca662297SShilei Tian /// The runtime function kind of the callee of the associated call site. 4193ca662297SShilei Tian RuntimeFunction RFKind; 4194ca662297SShilei Tian }; 4195ca662297SShilei Tian 41969548b74aSJohannes Doerfert } // namespace 41979548b74aSJohannes Doerfert 41985ab6aeddSJose M Monsalve Diaz /// Register folding callsite 41995ab6aeddSJose M Monsalve Diaz void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) { 42005ab6aeddSJose M Monsalve Diaz auto &RFI = OMPInfoCache.RFIs[RF]; 42015ab6aeddSJose M Monsalve Diaz RFI.foreachUse(SCC, [&](Use &U, Function &F) { 42025ab6aeddSJose M Monsalve Diaz CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI); 42035ab6aeddSJose M Monsalve Diaz if (!CI) 42045ab6aeddSJose M Monsalve Diaz return false; 42055ab6aeddSJose M Monsalve Diaz A.getOrCreateAAFor<AAFoldRuntimeCall>( 42065ab6aeddSJose M Monsalve Diaz IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr, 42075ab6aeddSJose M Monsalve Diaz DepClassTy::NONE, /* ForceUpdate */ false, 42085ab6aeddSJose M Monsalve Diaz /* UpdateAfterInit */ false); 42095ab6aeddSJose M Monsalve Diaz return false; 42105ab6aeddSJose M Monsalve Diaz }); 42115ab6aeddSJose M Monsalve Diaz } 42125ab6aeddSJose M Monsalve Diaz 4213d9659bf6SJohannes Doerfert void OpenMPOpt::registerAAs(bool IsModulePass) { 
4214d9659bf6SJohannes Doerfert if (SCC.empty()) 4215d9659bf6SJohannes Doerfert 4216d9659bf6SJohannes Doerfert return; 4217d9659bf6SJohannes Doerfert if (IsModulePass) { 4218d9659bf6SJohannes Doerfert // Ensure we create the AAKernelInfo AAs first and without triggering an 4219d9659bf6SJohannes Doerfert // update. This will make sure we register all value simplification 4220d9659bf6SJohannes Doerfert // callbacks before any other AA has the chance to create an AAValueSimplify 4221d9659bf6SJohannes Doerfert // or similar. 4222d9659bf6SJohannes Doerfert for (Function *Kernel : OMPInfoCache.Kernels) 4223d9659bf6SJohannes Doerfert A.getOrCreateAAFor<AAKernelInfo>( 4224d9659bf6SJohannes Doerfert IRPosition::function(*Kernel), /* QueryingAA */ nullptr, 4225d9659bf6SJohannes Doerfert DepClassTy::NONE, /* ForceUpdate */ false, 4226d9659bf6SJohannes Doerfert /* UpdateAfterInit */ false); 4227ca662297SShilei Tian 4228196fe994SJoseph Huber 42295ab6aeddSJose M Monsalve Diaz registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id); 42305ab6aeddSJose M Monsalve Diaz registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode); 42315ab6aeddSJose M Monsalve Diaz registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level); 42325ab6aeddSJose M Monsalve Diaz registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block); 42335ab6aeddSJose M Monsalve Diaz registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks); 4234d9659bf6SJohannes Doerfert } 4235d9659bf6SJohannes Doerfert 4236d9659bf6SJohannes Doerfert // Create CallSite AA for all Getters. 
4237d9659bf6SJohannes Doerfert for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) { 4238d9659bf6SJohannes Doerfert auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)]; 4239d9659bf6SJohannes Doerfert 4240d9659bf6SJohannes Doerfert auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter]; 4241d9659bf6SJohannes Doerfert 4242d9659bf6SJohannes Doerfert auto CreateAA = [&](Use &U, Function &Caller) { 4243d9659bf6SJohannes Doerfert CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI); 4244d9659bf6SJohannes Doerfert if (!CI) 4245d9659bf6SJohannes Doerfert return false; 4246d9659bf6SJohannes Doerfert 4247d9659bf6SJohannes Doerfert auto &CB = cast<CallBase>(*CI); 4248d9659bf6SJohannes Doerfert 4249d9659bf6SJohannes Doerfert IRPosition CBPos = IRPosition::callsite_function(CB); 4250d9659bf6SJohannes Doerfert A.getOrCreateAAFor<AAICVTracker>(CBPos); 4251d9659bf6SJohannes Doerfert return false; 4252d9659bf6SJohannes Doerfert }; 4253d9659bf6SJohannes Doerfert 4254d9659bf6SJohannes Doerfert GetterRFI.foreachUse(SCC, CreateAA); 4255d9659bf6SJohannes Doerfert } 4256d9659bf6SJohannes Doerfert auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; 4257d9659bf6SJohannes Doerfert auto CreateAA = [&](Use &U, Function &F) { 4258d9659bf6SJohannes Doerfert A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F)); 4259d9659bf6SJohannes Doerfert return false; 4260d9659bf6SJohannes Doerfert }; 4261cd0dd8ecSJoseph Huber if (!DisableOpenMPOptDeglobalization) 4262d9659bf6SJohannes Doerfert GlobalizationRFI.foreachUse(SCC, CreateAA); 4263d9659bf6SJohannes Doerfert 4264d9659bf6SJohannes Doerfert // Create an ExecutionDomain AA for every function and a HeapToStack AA for 4265d9659bf6SJohannes Doerfert // every function if there is a device kernel. 
426670b75f62SJohannes Doerfert if (!isOpenMPDevice(M)) 426770b75f62SJohannes Doerfert return; 426870b75f62SJohannes Doerfert 4269d9659bf6SJohannes Doerfert for (auto *F : SCC) { 427070b75f62SJohannes Doerfert if (F->isDeclaration()) 427170b75f62SJohannes Doerfert continue; 427270b75f62SJohannes Doerfert 4273d9659bf6SJohannes Doerfert A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F)); 4274cd0dd8ecSJoseph Huber if (!DisableOpenMPOptDeglobalization) 4275d9659bf6SJohannes Doerfert A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F)); 427670b75f62SJohannes Doerfert 427770b75f62SJohannes Doerfert for (auto &I : instructions(*F)) { 427870b75f62SJohannes Doerfert if (auto *LI = dyn_cast<LoadInst>(&I)) { 427970b75f62SJohannes Doerfert bool UsedAssumedInformation = false; 428070b75f62SJohannes Doerfert A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr, 428170b75f62SJohannes Doerfert UsedAssumedInformation); 428270b75f62SJohannes Doerfert } 428370b75f62SJohannes Doerfert } 4284d9659bf6SJohannes Doerfert } 4285d9659bf6SJohannes Doerfert } 4286d9659bf6SJohannes Doerfert 4287b8235d2bSsstefan1 const char AAICVTracker::ID = 0; 4288d9659bf6SJohannes Doerfert const char AAKernelInfo::ID = 0; 428918283125SJoseph Huber const char AAExecutionDomain::ID = 0; 42906fc51c9fSJoseph Huber const char AAHeapToShared::ID = 0; 4291ca662297SShilei Tian const char AAFoldRuntimeCall::ID = 0; 4292b8235d2bSsstefan1 4293b8235d2bSsstefan1 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP, 4294b8235d2bSsstefan1 Attributor &A) { 4295b8235d2bSsstefan1 AAICVTracker *AA = nullptr; 4296b8235d2bSsstefan1 switch (IRP.getPositionKind()) { 4297b8235d2bSsstefan1 case IRPosition::IRP_INVALID: 4298b8235d2bSsstefan1 case IRPosition::IRP_FLOAT: 4299b8235d2bSsstefan1 case IRPosition::IRP_ARGUMENT: 4300b8235d2bSsstefan1 case IRPosition::IRP_CALL_SITE_ARGUMENT: 43011de70a72SJohannes Doerfert llvm_unreachable("ICVTracker can only be created for function position!"); 
43025dfd7cc4Ssstefan1 case IRPosition::IRP_RETURNED: 43035dfd7cc4Ssstefan1 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A); 43045dfd7cc4Ssstefan1 break; 43055dfd7cc4Ssstefan1 case IRPosition::IRP_CALL_SITE_RETURNED: 43065dfd7cc4Ssstefan1 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A); 43075dfd7cc4Ssstefan1 break; 43085dfd7cc4Ssstefan1 case IRPosition::IRP_CALL_SITE: 43095dfd7cc4Ssstefan1 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A); 43105dfd7cc4Ssstefan1 break; 4311b8235d2bSsstefan1 case IRPosition::IRP_FUNCTION: 4312b8235d2bSsstefan1 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A); 4313b8235d2bSsstefan1 break; 4314b8235d2bSsstefan1 } 4315b8235d2bSsstefan1 4316b8235d2bSsstefan1 return *AA; 4317b8235d2bSsstefan1 } 4318b8235d2bSsstefan1 431918283125SJoseph Huber AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP, 432018283125SJoseph Huber Attributor &A) { 432118283125SJoseph Huber AAExecutionDomainFunction *AA = nullptr; 432218283125SJoseph Huber switch (IRP.getPositionKind()) { 432318283125SJoseph Huber case IRPosition::IRP_INVALID: 432418283125SJoseph Huber case IRPosition::IRP_FLOAT: 432518283125SJoseph Huber case IRPosition::IRP_ARGUMENT: 432618283125SJoseph Huber case IRPosition::IRP_CALL_SITE_ARGUMENT: 432718283125SJoseph Huber case IRPosition::IRP_RETURNED: 432818283125SJoseph Huber case IRPosition::IRP_CALL_SITE_RETURNED: 432918283125SJoseph Huber case IRPosition::IRP_CALL_SITE: 433018283125SJoseph Huber llvm_unreachable( 433118283125SJoseph Huber "AAExecutionDomain can only be created for function position!"); 433218283125SJoseph Huber case IRPosition::IRP_FUNCTION: 433318283125SJoseph Huber AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A); 433418283125SJoseph Huber break; 433518283125SJoseph Huber } 433618283125SJoseph Huber 433718283125SJoseph Huber return *AA; 433818283125SJoseph Huber } 433918283125SJoseph Huber 43406fc51c9fSJoseph Huber AAHeapToShared 
&AAHeapToShared::createForPosition(const IRPosition &IRP, 43416fc51c9fSJoseph Huber Attributor &A) { 43426fc51c9fSJoseph Huber AAHeapToSharedFunction *AA = nullptr; 43436fc51c9fSJoseph Huber switch (IRP.getPositionKind()) { 43446fc51c9fSJoseph Huber case IRPosition::IRP_INVALID: 43456fc51c9fSJoseph Huber case IRPosition::IRP_FLOAT: 43466fc51c9fSJoseph Huber case IRPosition::IRP_ARGUMENT: 43476fc51c9fSJoseph Huber case IRPosition::IRP_CALL_SITE_ARGUMENT: 43486fc51c9fSJoseph Huber case IRPosition::IRP_RETURNED: 43496fc51c9fSJoseph Huber case IRPosition::IRP_CALL_SITE_RETURNED: 43506fc51c9fSJoseph Huber case IRPosition::IRP_CALL_SITE: 43516fc51c9fSJoseph Huber llvm_unreachable( 43526fc51c9fSJoseph Huber "AAHeapToShared can only be created for function position!"); 43536fc51c9fSJoseph Huber case IRPosition::IRP_FUNCTION: 43546fc51c9fSJoseph Huber AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A); 43556fc51c9fSJoseph Huber break; 43566fc51c9fSJoseph Huber } 43576fc51c9fSJoseph Huber 43586fc51c9fSJoseph Huber return *AA; 43596fc51c9fSJoseph Huber } 43606fc51c9fSJoseph Huber 4361d9659bf6SJohannes Doerfert AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP, 4362d9659bf6SJohannes Doerfert Attributor &A) { 4363d9659bf6SJohannes Doerfert AAKernelInfo *AA = nullptr; 4364d9659bf6SJohannes Doerfert switch (IRP.getPositionKind()) { 4365d9659bf6SJohannes Doerfert case IRPosition::IRP_INVALID: 4366d9659bf6SJohannes Doerfert case IRPosition::IRP_FLOAT: 4367d9659bf6SJohannes Doerfert case IRPosition::IRP_ARGUMENT: 4368d9659bf6SJohannes Doerfert case IRPosition::IRP_RETURNED: 4369d9659bf6SJohannes Doerfert case IRPosition::IRP_CALL_SITE_RETURNED: 4370d9659bf6SJohannes Doerfert case IRPosition::IRP_CALL_SITE_ARGUMENT: 4371d9659bf6SJohannes Doerfert llvm_unreachable("KernelInfo can only be created for function position!"); 4372d9659bf6SJohannes Doerfert case IRPosition::IRP_CALL_SITE: 4373d9659bf6SJohannes Doerfert AA = new (A.Allocator) AAKernelInfoCallSite(IRP, 
A); 4374d9659bf6SJohannes Doerfert break; 4375d9659bf6SJohannes Doerfert case IRPosition::IRP_FUNCTION: 4376d9659bf6SJohannes Doerfert AA = new (A.Allocator) AAKernelInfoFunction(IRP, A); 4377d9659bf6SJohannes Doerfert break; 4378d9659bf6SJohannes Doerfert } 4379d9659bf6SJohannes Doerfert 4380d9659bf6SJohannes Doerfert return *AA; 4381d9659bf6SJohannes Doerfert } 4382d9659bf6SJohannes Doerfert 4383ca662297SShilei Tian AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP, 4384ca662297SShilei Tian Attributor &A) { 4385ca662297SShilei Tian AAFoldRuntimeCall *AA = nullptr; 4386ca662297SShilei Tian switch (IRP.getPositionKind()) { 4387ca662297SShilei Tian case IRPosition::IRP_INVALID: 4388ca662297SShilei Tian case IRPosition::IRP_FLOAT: 4389ca662297SShilei Tian case IRPosition::IRP_ARGUMENT: 4390ca662297SShilei Tian case IRPosition::IRP_RETURNED: 4391ca662297SShilei Tian case IRPosition::IRP_FUNCTION: 4392ca662297SShilei Tian case IRPosition::IRP_CALL_SITE: 4393ca662297SShilei Tian case IRPosition::IRP_CALL_SITE_ARGUMENT: 4394ca662297SShilei Tian llvm_unreachable("KernelInfo can only be created for call site position!"); 4395ca662297SShilei Tian case IRPosition::IRP_CALL_SITE_RETURNED: 4396ca662297SShilei Tian AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A); 4397ca662297SShilei Tian break; 4398ca662297SShilei Tian } 4399ca662297SShilei Tian 4400ca662297SShilei Tian return *AA; 4401ca662297SShilei Tian } 4402ca662297SShilei Tian 4403b2ad63d3SJoseph Huber PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { 44045ccb7424SJoseph Huber if (!containsOpenMP(M)) 4405b2ad63d3SJoseph Huber return PreservedAnalyses::all(); 4406b2ad63d3SJoseph Huber if (DisableOpenMPOptimizations) 4407b2ad63d3SJoseph Huber return PreservedAnalyses::all(); 4408b2ad63d3SJoseph Huber 44090edb8777SJoseph Huber FunctionAnalysisManager &FAM = 44100edb8777SJoseph Huber AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 
44115ccb7424SJoseph Huber KernelSet Kernels = getDeviceKernels(M); 44125ccb7424SJoseph Huber 441357ad2e10SJoseph Huber auto IsCalled = [&](Function &F) { 441457ad2e10SJoseph Huber if (Kernels.contains(&F)) 441557ad2e10SJoseph Huber return true; 441657ad2e10SJoseph Huber for (const User *U : F.users()) 441757ad2e10SJoseph Huber if (!isa<BlockAddress>(U)) 441857ad2e10SJoseph Huber return true; 441957ad2e10SJoseph Huber return false; 442057ad2e10SJoseph Huber }; 442157ad2e10SJoseph Huber 44220edb8777SJoseph Huber auto EmitRemark = [&](Function &F) { 44230edb8777SJoseph Huber auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); 44240edb8777SJoseph Huber ORE.emit([&]() { 44252c31d5ebSJoseph Huber OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F); 4426ecabc668SJoseph Huber return ORA << "Could not internalize function. " 4427adbaa39dSJoseph Huber << "Some optimizations may not be possible. [OMP140]"; 44280edb8777SJoseph Huber }); 44290edb8777SJoseph Huber }; 44300edb8777SJoseph Huber 443157ad2e10SJoseph Huber // Create internal copies of each function if this is a kernel Module. This 443257ad2e10SJoseph Huber // allows iterprocedural passes to see every call edge. 
4433adbaa39dSJoseph Huber DenseMap<Function *, Function *> InternalizedMap; 4434adbaa39dSJoseph Huber if (isOpenMPDevice(M)) { 4435adbaa39dSJoseph Huber SmallPtrSet<Function *, 16> InternalizeFns; 443603d7e61cSJoseph Huber for (Function &F : M) 44374a668604SJoseph Huber if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) && 44384a668604SJoseph Huber !DisableInternalization) { 4439adbaa39dSJoseph Huber if (Attributor::isInternalizable(F)) { 4440adbaa39dSJoseph Huber InternalizeFns.insert(&F); 4441ecabc668SJoseph Huber } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) { 44420edb8777SJoseph Huber EmitRemark(F); 44430edb8777SJoseph Huber } 44440edb8777SJoseph Huber } 444503d7e61cSJoseph Huber 4446adbaa39dSJoseph Huber Attributor::internalizeFunctions(InternalizeFns, InternalizedMap); 4447adbaa39dSJoseph Huber } 4448adbaa39dSJoseph Huber 444957ad2e10SJoseph Huber // Look at every function in the Module unless it was internalized. 4450b2ad63d3SJoseph Huber SmallVector<Function *, 16> SCC; 445103d7e61cSJoseph Huber for (Function &F : M) 4452adbaa39dSJoseph Huber if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) 445303d7e61cSJoseph Huber SCC.push_back(&F); 4454b2ad63d3SJoseph Huber 4455b2ad63d3SJoseph Huber if (SCC.empty()) 4456b2ad63d3SJoseph Huber return PreservedAnalyses::all(); 4457b2ad63d3SJoseph Huber 4458b2ad63d3SJoseph Huber AnalysisGetter AG(FAM); 4459b2ad63d3SJoseph Huber 4460b2ad63d3SJoseph Huber auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & { 4461b2ad63d3SJoseph Huber return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 4462b2ad63d3SJoseph Huber }; 4463b2ad63d3SJoseph Huber 4464b2ad63d3SJoseph Huber BumpPtrAllocator Allocator; 4465b2ad63d3SJoseph Huber CallGraphUpdater CGUpdater; 4466b2ad63d3SJoseph Huber 4467b2ad63d3SJoseph Huber SetVector<Function *> Functions(SCC.begin(), SCC.end()); 44685ccb7424SJoseph Huber OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels); 
4469b2ad63d3SJoseph Huber 447013b2fba2SJoseph Huber unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; 44714a6bd8e3SJoseph Huber Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false, 447213b2fba2SJoseph Huber MaxFixpointIterations, OREGetter, DEBUG_TYPE); 4473b2ad63d3SJoseph Huber 4474b2ad63d3SJoseph Huber OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 4475b2ad63d3SJoseph Huber bool Changed = OMPOpt.run(true); 4476b2ad63d3SJoseph Huber if (Changed) 4477b2ad63d3SJoseph Huber return PreservedAnalyses::none(); 4478b2ad63d3SJoseph Huber 4479b2ad63d3SJoseph Huber return PreservedAnalyses::all(); 4480b2ad63d3SJoseph Huber } 4481b2ad63d3SJoseph Huber 4482b2ad63d3SJoseph Huber PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, 44839548b74aSJohannes Doerfert CGSCCAnalysisManager &AM, 4484b2ad63d3SJoseph Huber LazyCallGraph &CG, 4485b2ad63d3SJoseph Huber CGSCCUpdateResult &UR) { 44865ccb7424SJoseph Huber if (!containsOpenMP(*C.begin()->getFunction().getParent())) 44879548b74aSJohannes Doerfert return PreservedAnalyses::all(); 44889548b74aSJohannes Doerfert if (DisableOpenMPOptimizations) 44899548b74aSJohannes Doerfert return PreservedAnalyses::all(); 44909548b74aSJohannes Doerfert 4491ee17263aSJohannes Doerfert SmallVector<Function *, 16> SCC; 4492351d234dSRoman Lebedev // If there are kernels in the module, we have to run on all SCC's. 
4493351d234dSRoman Lebedev for (LazyCallGraph::Node &N : C) { 4494351d234dSRoman Lebedev Function *Fn = &N.getFunction(); 4495351d234dSRoman Lebedev SCC.push_back(Fn); 4496351d234dSRoman Lebedev } 4497351d234dSRoman Lebedev 44985ccb7424SJoseph Huber if (SCC.empty()) 44999548b74aSJohannes Doerfert return PreservedAnalyses::all(); 45009548b74aSJohannes Doerfert 45015ccb7424SJoseph Huber Module &M = *C.begin()->getFunction().getParent(); 45025ccb7424SJoseph Huber 45035ccb7424SJoseph Huber KernelSet Kernels = getDeviceKernels(M); 45045ccb7424SJoseph Huber 45054d4ea9acSHuber, Joseph FunctionAnalysisManager &FAM = 45064d4ea9acSHuber, Joseph AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); 45077cfd267cSsstefan1 45087cfd267cSsstefan1 AnalysisGetter AG(FAM); 45097cfd267cSsstefan1 45107cfd267cSsstefan1 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & { 45114d4ea9acSHuber, Joseph return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 45124d4ea9acSHuber, Joseph }; 45134d4ea9acSHuber, Joseph 4514b2ad63d3SJoseph Huber BumpPtrAllocator Allocator; 45159548b74aSJohannes Doerfert CallGraphUpdater CGUpdater; 45169548b74aSJohannes Doerfert CGUpdater.initialize(CG, C, AM, UR); 45177cfd267cSsstefan1 45187cfd267cSsstefan1 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 45197cfd267cSsstefan1 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator, 45205ccb7424SJoseph Huber /*CGSCC*/ Functions, Kernels); 45217cfd267cSsstefan1 452213b2fba2SJoseph Huber unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 
128 : 32; 45234a6bd8e3SJoseph Huber Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true, 452413b2fba2SJoseph Huber MaxFixpointIterations, OREGetter, DEBUG_TYPE); 4525b8235d2bSsstefan1 4526b8235d2bSsstefan1 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 4527b2ad63d3SJoseph Huber bool Changed = OMPOpt.run(false); 4528694ded37SGiorgis Georgakoudis if (Changed) 4529694ded37SGiorgis Georgakoudis return PreservedAnalyses::none(); 4530694ded37SGiorgis Georgakoudis 45319548b74aSJohannes Doerfert return PreservedAnalyses::all(); 45329548b74aSJohannes Doerfert } 45338b57ed09SJoseph Huber 45349548b74aSJohannes Doerfert namespace { 45359548b74aSJohannes Doerfert 4536b2ad63d3SJoseph Huber struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass { 45379548b74aSJohannes Doerfert CallGraphUpdater CGUpdater; 45389548b74aSJohannes Doerfert static char ID; 45399548b74aSJohannes Doerfert 4540b2ad63d3SJoseph Huber OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) { 4541b2ad63d3SJoseph Huber initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry()); 45429548b74aSJohannes Doerfert } 45439548b74aSJohannes Doerfert 45449548b74aSJohannes Doerfert void getAnalysisUsage(AnalysisUsage &AU) const override { 45459548b74aSJohannes Doerfert CallGraphSCCPass::getAnalysisUsage(AU); 45469548b74aSJohannes Doerfert } 45479548b74aSJohannes Doerfert 45489548b74aSJohannes Doerfert bool runOnSCC(CallGraphSCC &CGSCC) override { 45495ccb7424SJoseph Huber if (!containsOpenMP(CGSCC.getCallGraph().getModule())) 45509548b74aSJohannes Doerfert return false; 45519548b74aSJohannes Doerfert if (DisableOpenMPOptimizations || skipSCC(CGSCC)) 45529548b74aSJohannes Doerfert return false; 45539548b74aSJohannes Doerfert 4554ee17263aSJohannes Doerfert SmallVector<Function *, 16> SCC; 4555351d234dSRoman Lebedev // If there are kernels in the module, we have to run on all SCC's. 
4556351d234dSRoman Lebedev for (CallGraphNode *CGN : CGSCC) { 4557351d234dSRoman Lebedev Function *Fn = CGN->getFunction(); 4558351d234dSRoman Lebedev if (!Fn || Fn->isDeclaration()) 4559351d234dSRoman Lebedev continue; 4560ee17263aSJohannes Doerfert SCC.push_back(Fn); 4561351d234dSRoman Lebedev } 4562351d234dSRoman Lebedev 45635ccb7424SJoseph Huber if (SCC.empty()) 45649548b74aSJohannes Doerfert return false; 45659548b74aSJohannes Doerfert 45665ccb7424SJoseph Huber Module &M = CGSCC.getCallGraph().getModule(); 45675ccb7424SJoseph Huber KernelSet Kernels = getDeviceKernels(M); 45685ccb7424SJoseph Huber 45699548b74aSJohannes Doerfert CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); 45709548b74aSJohannes Doerfert CGUpdater.initialize(CG, CGSCC); 45719548b74aSJohannes Doerfert 45724d4ea9acSHuber, Joseph // Maintain a map of functions to avoid rebuilding the ORE 45734d4ea9acSHuber, Joseph DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap; 45744d4ea9acSHuber, Joseph auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & { 45754d4ea9acSHuber, Joseph std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F]; 45764d4ea9acSHuber, Joseph if (!ORE) 45774d4ea9acSHuber, Joseph ORE = std::make_unique<OptimizationRemarkEmitter>(F); 45784d4ea9acSHuber, Joseph return *ORE; 45794d4ea9acSHuber, Joseph }; 45804d4ea9acSHuber, Joseph 45817cfd267cSsstefan1 AnalysisGetter AG; 45827cfd267cSsstefan1 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 45837cfd267cSsstefan1 BumpPtrAllocator Allocator; 45845ccb7424SJoseph Huber OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, 45855ccb7424SJoseph Huber Allocator, 45865ccb7424SJoseph Huber /*CGSCC*/ Functions, Kernels); 45877cfd267cSsstefan1 458813b2fba2SJoseph Huber unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 
128 : 32; 458930e36c9bSJoseph Huber Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true, 459013b2fba2SJoseph Huber MaxFixpointIterations, OREGetter, DEBUG_TYPE); 4591b8235d2bSsstefan1 4592b8235d2bSsstefan1 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 4593b2ad63d3SJoseph Huber return OMPOpt.run(false); 45949548b74aSJohannes Doerfert } 45959548b74aSJohannes Doerfert 45969548b74aSJohannes Doerfert bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); } 45979548b74aSJohannes Doerfert }; 45989548b74aSJohannes Doerfert 45999548b74aSJohannes Doerfert } // end anonymous namespace 46009548b74aSJohannes Doerfert 46015ccb7424SJoseph Huber KernelSet llvm::omp::getDeviceKernels(Module &M) { 46025ccb7424SJoseph Huber // TODO: Create a more cross-platform way of determining device kernels. 4603e8039ad4SJohannes Doerfert NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations"); 46045ccb7424SJoseph Huber KernelSet Kernels; 46055ccb7424SJoseph Huber 4606e8039ad4SJohannes Doerfert if (!MD) 46075ccb7424SJoseph Huber return Kernels; 4608e8039ad4SJohannes Doerfert 4609e8039ad4SJohannes Doerfert for (auto *Op : MD->operands()) { 4610e8039ad4SJohannes Doerfert if (Op->getNumOperands() < 2) 4611e8039ad4SJohannes Doerfert continue; 4612e8039ad4SJohannes Doerfert MDString *KindID = dyn_cast<MDString>(Op->getOperand(1)); 4613e8039ad4SJohannes Doerfert if (!KindID || KindID->getString() != "kernel") 4614e8039ad4SJohannes Doerfert continue; 4615e8039ad4SJohannes Doerfert 4616e8039ad4SJohannes Doerfert Function *KernelFn = 4617e8039ad4SJohannes Doerfert mdconst::dyn_extract_or_null<Function>(Op->getOperand(0)); 4618e8039ad4SJohannes Doerfert if (!KernelFn) 4619e8039ad4SJohannes Doerfert continue; 4620e8039ad4SJohannes Doerfert 4621e8039ad4SJohannes Doerfert ++NumOpenMPTargetRegionKernels; 4622e8039ad4SJohannes Doerfert 4623e8039ad4SJohannes Doerfert Kernels.insert(KernelFn); 4624e8039ad4SJohannes Doerfert } 46255ccb7424SJoseph Huber 
46265ccb7424SJoseph Huber return Kernels; 4627e8039ad4SJohannes Doerfert } 4628e8039ad4SJohannes Doerfert 46295ccb7424SJoseph Huber bool llvm::omp::containsOpenMP(Module &M) { 46305ccb7424SJoseph Huber Metadata *MD = M.getModuleFlag("openmp"); 46315ccb7424SJoseph Huber if (!MD) 46325ccb7424SJoseph Huber return false; 4633dce6bc18SJohannes Doerfert 4634e8039ad4SJohannes Doerfert return true; 4635e8039ad4SJohannes Doerfert } 4636e8039ad4SJohannes Doerfert 46375ccb7424SJoseph Huber bool llvm::omp::isOpenMPDevice(Module &M) { 46385ccb7424SJoseph Huber Metadata *MD = M.getModuleFlag("openmp-device"); 46395ccb7424SJoseph Huber if (!MD) 46405ccb7424SJoseph Huber return false; 46415ccb7424SJoseph Huber 46425ccb7424SJoseph Huber return true; 46439548b74aSJohannes Doerfert } 46449548b74aSJohannes Doerfert 4645b2ad63d3SJoseph Huber char OpenMPOptCGSCCLegacyPass::ID = 0; 46469548b74aSJohannes Doerfert 4647b2ad63d3SJoseph Huber INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", 46489548b74aSJohannes Doerfert "OpenMP specific optimizations", false, false) 46499548b74aSJohannes Doerfert INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) 4650b2ad63d3SJoseph Huber INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", 46519548b74aSJohannes Doerfert "OpenMP specific optimizations", false, false) 46529548b74aSJohannes Doerfert 4653b2ad63d3SJoseph Huber Pass *llvm::createOpenMPOptCGSCCLegacyPass() { 4654b2ad63d3SJoseph Huber return new OpenMPOptCGSCCLegacyPass(); 4655b2ad63d3SJoseph Huber } 4656