1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // OpenMP specific optimizations: 10 // 11 // - Deduplication of runtime calls, e.g., omp_get_thread_num. 12 // - Replacing globalized device memory with stack memory. 13 // - Replacing globalized device memory with shared memory. 14 // - Parallel region merging. 15 // - Transforming generic-mode device kernels to SPMD mode. 16 // - Specializing the state machine for generic-mode device kernels. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "llvm/Transforms/IPO/OpenMPOpt.h" 21 22 #include "llvm/ADT/EnumeratedArray.h" 23 #include "llvm/ADT/PostOrderIterator.h" 24 #include "llvm/ADT/Statistic.h" 25 #include "llvm/Analysis/CallGraph.h" 26 #include "llvm/Analysis/CallGraphSCCPass.h" 27 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 28 #include "llvm/Analysis/ValueTracking.h" 29 #include "llvm/Frontend/OpenMP/OMPConstants.h" 30 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" 31 #include "llvm/IR/Assumptions.h" 32 #include "llvm/IR/DiagnosticInfo.h" 33 #include "llvm/IR/GlobalValue.h" 34 #include "llvm/IR/Instruction.h" 35 #include "llvm/IR/IntrinsicInst.h" 36 #include "llvm/InitializePasses.h" 37 #include "llvm/Support/CommandLine.h" 38 #include "llvm/Transforms/IPO.h" 39 #include "llvm/Transforms/IPO/Attributor.h" 40 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 41 #include "llvm/Transforms/Utils/CallGraphUpdater.h" 42 #include "llvm/Transforms/Utils/CodeExtractor.h" 43 44 using namespace llvm; 45 using namespace omp; 46 47 #define DEBUG_TYPE "openmp-opt" 48 49 static cl::opt<bool> DisableOpenMPOptimizations( 50 "openmp-opt-disable", cl::ZeroOrMore, 51 cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, 52 cl::init(false)); 53 54 static cl::opt<bool> EnableParallelRegionMerging( 55 "openmp-opt-enable-merging", cl::ZeroOrMore, 56 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, 57 cl::init(false)); 58 59 static cl::opt<bool> 60 DisableInternalization("openmp-opt-disable-internalization", cl::ZeroOrMore, 61 cl::desc("Disable function internalization."), 62 cl::Hidden, cl::init(false)); 63 64 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false), 65 cl::Hidden); 66 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels", 67 cl::init(false), cl::Hidden); 68 69 static cl::opt<bool> HideMemoryTransferLatency( 70 "openmp-hide-memory-transfer-latency", 71 cl::desc("[WIP] Tries to hide the latency of host to device memory" 72 " transfers"), 73 cl::Hidden, cl::init(false)); 74 75 STATISTIC(NumOpenMPRuntimeCallsDeduplicated, 76 "Number of OpenMP runtime calls deduplicated"); 77 STATISTIC(NumOpenMPParallelRegionsDeleted, 78 "Number of OpenMP parallel regions deleted"); 79 STATISTIC(NumOpenMPRuntimeFunctionsIdentified, 80 "Number of OpenMP runtime functions identified"); 81 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified, 82 "Number of OpenMP runtime function uses identified"); 83 STATISTIC(NumOpenMPTargetRegionKernels, 84 "Number of OpenMP target region entry points (=kernels) identified"); 85 STATISTIC(NumOpenMPTargetRegionKernelsSPMD, 86 "Number of OpenMP target region entry points 
(=kernels) executed in " 87 "SPMD-mode instead of generic-mode"); 88 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine, 89 "Number of OpenMP target region entry points (=kernels) executed in " 90 "generic-mode without a state machines"); 91 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback, 92 "Number of OpenMP target region entry points (=kernels) executed in " 93 "generic-mode with customized state machines with fallback"); 94 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback, 95 "Number of OpenMP target region entry points (=kernels) executed in " 96 "generic-mode with customized state machines without fallback"); 97 STATISTIC( 98 NumOpenMPParallelRegionsReplacedInGPUStateMachine, 99 "Number of OpenMP parallel regions replaced with ID in GPU state machines"); 100 STATISTIC(NumOpenMPParallelRegionsMerged, 101 "Number of OpenMP parallel regions merged"); 102 STATISTIC(NumBytesMovedToSharedMemory, 103 "Amount of memory pushed to shared memory"); 104 105 #if !defined(NDEBUG) 106 static constexpr auto TAG = "[" DEBUG_TYPE "]"; 107 #endif 108 109 namespace { 110 111 enum class AddressSpace : unsigned { 112 Generic = 0, 113 Global = 1, 114 Shared = 3, 115 Constant = 4, 116 Local = 5, 117 }; 118 119 struct AAHeapToShared; 120 121 struct AAICVTracker; 122 123 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for 124 /// Attributor runs. 125 struct OMPInformationCache : public InformationCache { 126 OMPInformationCache(Module &M, AnalysisGetter &AG, 127 BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC, 128 SmallPtrSetImpl<Kernel> &Kernels) 129 : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M), 130 Kernels(Kernels) { 131 132 OMPBuilder.initialize(); 133 initializeRuntimeFunctions(); 134 initializeInternalControlVars(); 135 } 136 137 /// Generic information that describes an internal control variable. 138 struct InternalControlVarInfo { 139 /// The kind, as described by InternalControlVar enum. 140 InternalControlVar Kind; 141 142 /// The name of the ICV. 143 StringRef Name; 144 145 /// Environment variable associated with this ICV. 146 StringRef EnvVarName; 147 148 /// Initial value kind. 149 ICVInitValue InitKind; 150 151 /// Initial value. 152 ConstantInt *InitValue; 153 154 /// Setter RTL function associated with this ICV. 155 RuntimeFunction Setter; 156 157 /// Getter RTL function associated with this ICV. 158 RuntimeFunction Getter; 159 160 /// RTL Function corresponding to the override clause of this ICV 161 RuntimeFunction Clause; 162 }; 163 164 /// Generic information that describes a runtime function 165 struct RuntimeFunctionInfo { 166 167 /// The kind, as described by the RuntimeFunction enum. 168 RuntimeFunction Kind; 169 170 /// The name of the function. 171 StringRef Name; 172 173 /// Flag to indicate a variadic function. 174 bool IsVarArg; 175 176 /// The return type of the function. 177 Type *ReturnType; 178 179 /// The argument types of the function. 180 SmallVector<Type *, 8> ArgumentTypes; 181 182 /// The declaration if available. 183 Function *Declaration = nullptr; 184 185 /// Uses of this runtime function per function containing the use. 186 using UseVector = SmallVector<Use *, 16>; 187 188 /// Clear UsesMap for runtime function. 189 void clearUsesMap() { UsesMap.clear(); } 190 191 /// Boolean conversion that is true if the runtime function was found. 192 operator bool() const { return Declaration; } 193 194 /// Return the vector of uses in function \p F. 
195 UseVector &getOrCreateUseVector(Function *F) { 196 std::shared_ptr<UseVector> &UV = UsesMap[F]; 197 if (!UV) 198 UV = std::make_shared<UseVector>(); 199 return *UV; 200 } 201 202 /// Return the vector of uses in function \p F or `nullptr` if there are 203 /// none. 204 const UseVector *getUseVector(Function &F) const { 205 auto I = UsesMap.find(&F); 206 if (I != UsesMap.end()) 207 return I->second.get(); 208 return nullptr; 209 } 210 211 /// Return how many functions contain uses of this runtime function. 212 size_t getNumFunctionsWithUses() const { return UsesMap.size(); } 213 214 /// Return the number of arguments (or the minimal number for variadic 215 /// functions). 216 size_t getNumArgs() const { return ArgumentTypes.size(); } 217 218 /// Run the callback \p CB on each use and forget the use if the result is 219 /// true. The callback will be fed the function in which the use was 220 /// encountered as second argument. 221 void foreachUse(SmallVectorImpl<Function *> &SCC, 222 function_ref<bool(Use &, Function &)> CB) { 223 for (Function *F : SCC) 224 foreachUse(CB, F); 225 } 226 227 /// Run the callback \p CB on each use within the function \p F and forget 228 /// the use if the result is true. 229 void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) { 230 SmallVector<unsigned, 8> ToBeDeleted; 231 ToBeDeleted.clear(); 232 233 unsigned Idx = 0; 234 UseVector &UV = getOrCreateUseVector(F); 235 236 for (Use *U : UV) { 237 if (CB(*U, *F)) 238 ToBeDeleted.push_back(Idx); 239 ++Idx; 240 } 241 242 // Remove the to-be-deleted indices in reverse order as prior 243 // modifications will not modify the smaller indices. 244 while (!ToBeDeleted.empty()) { 245 unsigned Idx = ToBeDeleted.pop_back_val(); 246 UV[Idx] = UV.back(); 247 UV.pop_back(); 248 } 249 } 250 251 private: 252 /// Map from functions to all uses of this runtime function contained in 253 /// them. 254 DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap; 255 256 public: 257 /// Iterators for the uses of this runtime function. 258 decltype(UsesMap)::iterator begin() { return UsesMap.begin(); } 259 decltype(UsesMap)::iterator end() { return UsesMap.end(); } 260 }; 261 262 /// An OpenMP-IR-Builder instance 263 OpenMPIRBuilder OMPBuilder; 264 265 /// Map from runtime function kind to the runtime function description. 266 EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction, 267 RuntimeFunction::OMPRTL___last> 268 RFIs; 269 270 /// Map from function declarations/definitions to their runtime enum type. 271 DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap; 272 273 /// Map from ICV kind to the ICV description. 274 EnumeratedArray<InternalControlVarInfo, InternalControlVar, 275 InternalControlVar::ICV___last> 276 ICVs; 277 278 /// Helper to initialize all internal control variable information for those 279 /// defined in OMPKinds.def. 
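  /// As an illustration (the authoritative entries live in OMPKinds.def, not
  /// here): an ICV described there roughly as
  ///   ICV_DATA_ENV(ICV_nthreads, "nthreads", "OMP_NUM_THREADS",
  ///                ICV_IMPLEMENTATION_DEFINED)
  /// is expanded by the macros below into an ICVs[ICV_nthreads] entry with
  /// its name, environment variable, and initial-value kind filled in.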
280 void initializeInternalControlVars() { 281 #define ICV_RT_SET(_Name, RTL) \ 282 { \ 283 auto &ICV = ICVs[_Name]; \ 284 ICV.Setter = RTL; \ 285 } 286 #define ICV_RT_GET(Name, RTL) \ 287 { \ 288 auto &ICV = ICVs[Name]; \ 289 ICV.Getter = RTL; \ 290 } 291 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \ 292 { \ 293 auto &ICV = ICVs[Enum]; \ 294 ICV.Name = _Name; \ 295 ICV.Kind = Enum; \ 296 ICV.InitKind = Init; \ 297 ICV.EnvVarName = _EnvVarName; \ 298 switch (ICV.InitKind) { \ 299 case ICV_IMPLEMENTATION_DEFINED: \ 300 ICV.InitValue = nullptr; \ 301 break; \ 302 case ICV_ZERO: \ 303 ICV.InitValue = ConstantInt::get( \ 304 Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \ 305 break; \ 306 case ICV_FALSE: \ 307 ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \ 308 break; \ 309 case ICV_LAST: \ 310 break; \ 311 } \ 312 } 313 #include "llvm/Frontend/OpenMP/OMPKinds.def" 314 } 315 316 /// Returns true if the function declaration \p F matches the runtime 317 /// function types, that is, return type \p RTFRetType, and argument types 318 /// \p RTFArgTypes. 319 static bool declMatchesRTFTypes(Function *F, Type *RTFRetType, 320 SmallVector<Type *, 8> &RTFArgTypes) { 321 // TODO: We should output information to the user (under debug output 322 // and via remarks). 323 324 if (!F) 325 return false; 326 if (F->getReturnType() != RTFRetType) 327 return false; 328 if (F->arg_size() != RTFArgTypes.size()) 329 return false; 330 331 auto RTFTyIt = RTFArgTypes.begin(); 332 for (Argument &Arg : F->args()) { 333 if (Arg.getType() != *RTFTyIt) 334 return false; 335 336 ++RTFTyIt; 337 } 338 339 return true; 340 } 341 342 // Helper to collect all uses of the declaration in the UsesMap. 343 unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) { 344 unsigned NumUses = 0; 345 if (!RFI.Declaration) 346 return NumUses; 347 OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration); 348 349 if (CollectStats) { 350 NumOpenMPRuntimeFunctionsIdentified += 1; 351 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses(); 352 } 353 354 // TODO: We directly convert uses into proper calls and unknown uses. 355 for (Use &U : RFI.Declaration->uses()) { 356 if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) { 357 if (ModuleSlice.count(UserI->getFunction())) { 358 RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U); 359 ++NumUses; 360 } 361 } else { 362 RFI.getOrCreateUseVector(nullptr).push_back(&U); 363 ++NumUses; 364 } 365 } 366 return NumUses; 367 } 368 369 // Helper function to recollect uses of a runtime function. 370 void recollectUsesForFunction(RuntimeFunction RTF) { 371 auto &RFI = RFIs[RTF]; 372 RFI.clearUsesMap(); 373 collectUses(RFI, /*CollectStats*/ false); 374 } 375 376 // Helper function to recollect uses of all runtime functions. 377 void recollectUses() { 378 for (int Idx = 0; Idx < RFIs.size(); ++Idx) 379 recollectUsesForFunction(static_cast<RuntimeFunction>(Idx)); 380 } 381 382 /// Helper to initialize all runtime function information for those defined 383 /// in OpenMPKinds.def. 384 void initializeRuntimeFunctions() { 385 Module &M = *((*ModuleSlice.begin())->getParent()); 386 387 // Helper macros for handling __VA_ARGS__ in OMP_RTL 388 #define OMP_TYPE(VarName, ...) \ 389 Type *VarName = OMPBuilder.VarName; \ 390 (void)VarName; 391 392 #define OMP_ARRAY_TYPE(VarName, ...) 
\ 393 ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \ 394 (void)VarName##Ty; \ 395 PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \ 396 (void)VarName##PtrTy; 397 398 #define OMP_FUNCTION_TYPE(VarName, ...) \ 399 FunctionType *VarName = OMPBuilder.VarName; \ 400 (void)VarName; \ 401 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \ 402 (void)VarName##Ptr; 403 404 #define OMP_STRUCT_TYPE(VarName, ...) \ 405 StructType *VarName = OMPBuilder.VarName; \ 406 (void)VarName; \ 407 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \ 408 (void)VarName##Ptr; 409 410 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \ 411 { \ 412 SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \ 413 Function *F = M.getFunction(_Name); \ 414 RTLFunctions.insert(F); \ 415 if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \ 416 RuntimeFunctionIDMap[F] = _Enum; \ 417 F->removeFnAttr(Attribute::NoInline); \ 418 auto &RFI = RFIs[_Enum]; \ 419 RFI.Kind = _Enum; \ 420 RFI.Name = _Name; \ 421 RFI.IsVarArg = _IsVarArg; \ 422 RFI.ReturnType = OMPBuilder._ReturnType; \ 423 RFI.ArgumentTypes = std::move(ArgsTypes); \ 424 RFI.Declaration = F; \ 425 unsigned NumUses = collectUses(RFI); \ 426 (void)NumUses; \ 427 LLVM_DEBUG({ \ 428 dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \ 429 << " found\n"; \ 430 if (RFI.Declaration) \ 431 dbgs() << TAG << "-> got " << NumUses << " uses in " \ 432 << RFI.getNumFunctionsWithUses() \ 433 << " different functions.\n"; \ 434 }); \ 435 } \ 436 } 437 #include "llvm/Frontend/OpenMP/OMPKinds.def" 438 439 // TODO: We should attach the attributes defined in OMPKinds.def. 440 } 441 442 /// Collection of known kernels (\see Kernel) in the module. 443 SmallPtrSetImpl<Kernel> &Kernels; 444 445 /// Collection of known OpenMP runtime functions.. 446 DenseSet<const Function *> RTLFunctions; 447 }; 448 449 template <typename Ty, bool InsertInvalidates = true> 450 struct BooleanStateWithPtrSetVector : public BooleanState { 451 452 bool contains(Ty *Elem) const { return Set.contains(Elem); } 453 bool insert(Ty *Elem) { 454 if (InsertInvalidates) 455 BooleanState::indicatePessimisticFixpoint(); 456 return Set.insert(Elem); 457 } 458 459 Ty *operator[](int Idx) const { return Set[Idx]; } 460 bool operator==(const BooleanStateWithPtrSetVector &RHS) const { 461 return BooleanState::operator==(RHS) && Set == RHS.Set; 462 } 463 bool operator!=(const BooleanStateWithPtrSetVector &RHS) const { 464 return !(*this == RHS); 465 } 466 467 bool empty() const { return Set.empty(); } 468 size_t size() const { return Set.size(); } 469 470 /// "Clamp" this state with \p RHS. 471 BooleanStateWithPtrSetVector & 472 operator^=(const BooleanStateWithPtrSetVector &RHS) { 473 BooleanState::operator^=(RHS); 474 Set.insert(RHS.Set.begin(), RHS.Set.end()); 475 return *this; 476 } 477 478 private: 479 /// A set to keep track of elements. 480 SetVector<Ty *> Set; 481 482 public: 483 typename decltype(Set)::iterator begin() { return Set.begin(); } 484 typename decltype(Set)::iterator end() { return Set.end(); } 485 typename decltype(Set)::const_iterator begin() const { return Set.begin(); } 486 typename decltype(Set)::const_iterator end() const { return Set.end(); } 487 }; 488 489 struct KernelInfoState : AbstractState { 490 /// Flag to track if we reached a fixpoint. 491 bool IsAtFixpoint = false; 492 493 /// The parallel regions (identified by the outlined parallel functions) that 494 /// can be reached from the associated function. 
  BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
      ReachedKnownParallelRegions;

  /// State to track what parallel region we might reach.
  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;

  /// State to track if we are in SPMD-mode, assumed or known, and why we
  /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
  /// also be false.
  BooleanStateWithPtrSetVector<Instruction> SPMDCompatibilityTracker;

  /// The __kmpc_target_init call in this kernel, if any. If we find more than
  /// one we abort as the kernel is malformed.
  CallBase *KernelInitCB = nullptr;

  /// The __kmpc_target_deinit call in this kernel, if any. If we find more
  /// than one we abort as the kernel is malformed.
  CallBase *KernelDeinitCB = nullptr;

  /// Flag to indicate if the associated function is a kernel entry.
  bool IsKernelEntry = false;

  /// State to track what kernel entries can reach the associated function.
  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;

  /// Abstract State interface
  ///{

  KernelInfoState() {}
  KernelInfoState(bool BestState) {
    if (!BestState)
      indicatePessimisticFixpoint();
  }

  /// See AbstractState::isValidState(...)
  bool isValidState() const override { return true; }

  /// See AbstractState::isAtFixpoint(...)
  bool isAtFixpoint() const override { return IsAtFixpoint; }

  /// See AbstractState::indicatePessimisticFixpoint(...)
  ChangeStatus indicatePessimisticFixpoint() override {
    IsAtFixpoint = true;
    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
    return ChangeStatus::CHANGED;
  }

  /// See AbstractState::indicateOptimisticFixpoint(...)
  ChangeStatus indicateOptimisticFixpoint() override {
    IsAtFixpoint = true;
    return ChangeStatus::UNCHANGED;
  }

  /// Return the assumed state.
  KernelInfoState &getAssumed() { return *this; }
  const KernelInfoState &getAssumed() const { return *this; }

  bool operator==(const KernelInfoState &RHS) const {
    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
      return false;
    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
      return false;
    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
      return false;
    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
      return false;
    return true;
  }

  /// Return empty set as the best state of potential values.
  static KernelInfoState getBestState() { return KernelInfoState(true); }

  static KernelInfoState getBestState(KernelInfoState &KIS) {
    return getBestState();
  }

  /// Return full set as the worst state of potential values.
  static KernelInfoState getWorstState() { return KernelInfoState(false); }

  /// "Clamp" this state with \p KIS.
  KernelInfoState operator^=(const KernelInfoState &KIS) {
    // Do not merge two different _init and _deinit call sites.
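    // If they differ we cannot tell which call describes this kernel, so we
    // conservatively fall back to the pessimistic fixpoint.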
578 if (KIS.KernelInitCB) { 579 if (KernelInitCB && KernelInitCB != KIS.KernelInitCB) 580 indicatePessimisticFixpoint(); 581 KernelInitCB = KIS.KernelInitCB; 582 } 583 if (KIS.KernelDeinitCB) { 584 if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB) 585 indicatePessimisticFixpoint(); 586 KernelDeinitCB = KIS.KernelDeinitCB; 587 } 588 SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker; 589 ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions; 590 ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions; 591 return *this; 592 } 593 594 KernelInfoState operator&=(const KernelInfoState &KIS) { 595 return (*this ^= KIS); 596 } 597 598 ///} 599 }; 600 601 /// Used to map the values physically (in the IR) stored in an offload 602 /// array, to a vector in memory. 603 struct OffloadArray { 604 /// Physical array (in the IR). 605 AllocaInst *Array = nullptr; 606 /// Mapped values. 607 SmallVector<Value *, 8> StoredValues; 608 /// Last stores made in the offload array. 609 SmallVector<StoreInst *, 8> LastAccesses; 610 611 OffloadArray() = default; 612 613 /// Initializes the OffloadArray with the values stored in \p Array before 614 /// instruction \p Before is reached. Returns false if the initialization 615 /// fails. 616 /// This MUST be used immediately after the construction of the object. 617 bool initialize(AllocaInst &Array, Instruction &Before) { 618 if (!Array.getAllocatedType()->isArrayTy()) 619 return false; 620 621 if (!getValues(Array, Before)) 622 return false; 623 624 this->Array = &Array; 625 return true; 626 } 627 628 static const unsigned DeviceIDArgNum = 1; 629 static const unsigned BasePtrsArgNum = 3; 630 static const unsigned PtrsArgNum = 4; 631 static const unsigned SizesArgNum = 5; 632 633 private: 634 /// Traverses the BasicBlock where \p Array is, collecting the stores made to 635 /// \p Array, leaving StoredValues with the values stored before the 636 /// instruction \p Before is reached. 637 bool getValues(AllocaInst &Array, Instruction &Before) { 638 // Initialize container. 639 const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements(); 640 StoredValues.assign(NumValues, nullptr); 641 LastAccesses.assign(NumValues, nullptr); 642 643 // TODO: This assumes the instruction \p Before is in the same 644 // BasicBlock as Array. Make it general, for any control flow graph. 645 BasicBlock *BB = Array.getParent(); 646 if (BB != Before.getParent()) 647 return false; 648 649 const DataLayout &DL = Array.getModule()->getDataLayout(); 650 const unsigned int PointerSize = DL.getPointerSize(); 651 652 for (Instruction &I : *BB) { 653 if (&I == &Before) 654 break; 655 656 if (!isa<StoreInst>(&I)) 657 continue; 658 659 auto *S = cast<StoreInst>(&I); 660 int64_t Offset = -1; 661 auto *Dst = 662 GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL); 663 if (Dst == &Array) { 664 int64_t Idx = Offset / PointerSize; 665 StoredValues[Idx] = getUnderlyingObject(S->getValueOperand()); 666 LastAccesses[Idx] = S; 667 } 668 } 669 670 return isFilled(); 671 } 672 673 /// Returns true if all values in StoredValues and 674 /// LastAccesses are not nullptrs. 
675 bool isFilled() { 676 const unsigned NumValues = StoredValues.size(); 677 for (unsigned I = 0; I < NumValues; ++I) { 678 if (!StoredValues[I] || !LastAccesses[I]) 679 return false; 680 } 681 682 return true; 683 } 684 }; 685 686 struct OpenMPOpt { 687 688 using OptimizationRemarkGetter = 689 function_ref<OptimizationRemarkEmitter &(Function *)>; 690 691 OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater, 692 OptimizationRemarkGetter OREGetter, 693 OMPInformationCache &OMPInfoCache, Attributor &A) 694 : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater), 695 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {} 696 697 /// Check if any remarks are enabled for openmp-opt 698 bool remarksEnabled() { 699 auto &Ctx = M.getContext(); 700 return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE); 701 } 702 703 /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice. 704 bool run(bool IsModulePass) { 705 if (SCC.empty()) 706 return false; 707 708 bool Changed = false; 709 710 LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size() 711 << " functions in a slice with " 712 << OMPInfoCache.ModuleSlice.size() << " functions\n"); 713 714 if (IsModulePass) { 715 Changed |= runAttributor(IsModulePass); 716 717 // Recollect uses, in case Attributor deleted any. 718 OMPInfoCache.recollectUses(); 719 720 if (remarksEnabled()) 721 analysisGlobalization(); 722 } else { 723 if (PrintICVValues) 724 printICVs(); 725 if (PrintOpenMPKernels) 726 printKernels(); 727 728 Changed |= runAttributor(IsModulePass); 729 730 // Recollect uses, in case Attributor deleted any. 731 OMPInfoCache.recollectUses(); 732 733 Changed |= deleteParallelRegions(); 734 Changed |= rewriteDeviceCodeStateMachine(); 735 736 if (HideMemoryTransferLatency) 737 Changed |= hideMemTransfersLatency(); 738 Changed |= deduplicateRuntimeCalls(); 739 if (EnableParallelRegionMerging) { 740 if (mergeParallelRegions()) { 741 deduplicateRuntimeCalls(); 742 Changed = true; 743 } 744 } 745 } 746 747 return Changed; 748 } 749 750 /// Print initial ICV values for testing. 751 /// FIXME: This should be done from the Attributor once it is added. 752 void printICVs() const { 753 InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel, 754 ICV_proc_bind}; 755 756 for (Function *F : OMPInfoCache.ModuleSlice) { 757 for (auto ICV : ICVs) { 758 auto ICVInfo = OMPInfoCache.ICVs[ICV]; 759 auto Remark = [&](OptimizationRemarkAnalysis ORA) { 760 return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name) 761 << " Value: " 762 << (ICVInfo.InitValue 763 ? toString(ICVInfo.InitValue->getValue(), 10, true) 764 : "IMPLEMENTATION_DEFINED"); 765 }; 766 767 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark); 768 } 769 } 770 } 771 772 /// Print OpenMP GPU kernels for testing. 773 void printKernels() const { 774 for (Function *F : SCC) { 775 if (!OMPInfoCache.Kernels.count(F)) 776 continue; 777 778 auto Remark = [&](OptimizationRemarkAnalysis ORA) { 779 return ORA << "OpenMP GPU kernel " 780 << ore::NV("OpenMPGPUKernel", F->getName()) << "\n"; 781 }; 782 783 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark); 784 } 785 } 786 787 /// Return the call if \p U is a callee use in a regular call. If \p RFI is 788 /// given it has to be the callee or a nullptr is returned. 
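  /// "Regular" here means a direct call without operand bundles; a use as a
  /// call argument rather than as the callee yields nullptr.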
789 static CallInst *getCallIfRegularCall( 790 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) { 791 CallInst *CI = dyn_cast<CallInst>(U.getUser()); 792 if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() && 793 (!RFI || 794 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration))) 795 return CI; 796 return nullptr; 797 } 798 799 /// Return the call if \p V is a regular call. If \p RFI is given it has to be 800 /// the callee or a nullptr is returned. 801 static CallInst *getCallIfRegularCall( 802 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) { 803 CallInst *CI = dyn_cast<CallInst>(&V); 804 if (CI && !CI->hasOperandBundles() && 805 (!RFI || 806 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration))) 807 return CI; 808 return nullptr; 809 } 810 811 private: 812 /// Merge parallel regions when it is safe. 813 bool mergeParallelRegions() { 814 const unsigned CallbackCalleeOperand = 2; 815 const unsigned CallbackFirstArgOperand = 3; 816 using InsertPointTy = OpenMPIRBuilder::InsertPointTy; 817 818 // Check if there are any __kmpc_fork_call calls to merge. 819 OMPInformationCache::RuntimeFunctionInfo &RFI = 820 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call]; 821 822 if (!RFI.Declaration) 823 return false; 824 825 // Unmergable calls that prevent merging a parallel region. 826 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = { 827 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind], 828 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads], 829 }; 830 831 bool Changed = false; 832 LoopInfo *LI = nullptr; 833 DominatorTree *DT = nullptr; 834 835 SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap; 836 837 BasicBlock *StartBB = nullptr, *EndBB = nullptr; 838 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, 839 BasicBlock &ContinuationIP) { 840 BasicBlock *CGStartBB = CodeGenIP.getBlock(); 841 BasicBlock *CGEndBB = 842 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI); 843 assert(StartBB != nullptr && "StartBB should not be null"); 844 CGStartBB->getTerminator()->setSuccessor(0, StartBB); 845 assert(EndBB != nullptr && "EndBB should not be null"); 846 EndBB->getTerminator()->setSuccessor(0, CGEndBB); 847 }; 848 849 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &, 850 Value &Inner, Value *&ReplacementValue) -> InsertPointTy { 851 ReplacementValue = &Inner; 852 return CodeGenIP; 853 }; 854 855 auto FiniCB = [&](InsertPointTy CodeGenIP) {}; 856 857 /// Create a sequential execution region within a merged parallel region, 858 /// encapsulated in a master construct with a barrier for synchronization. 859 auto CreateSequentialRegion = [&](Function *OuterFn, 860 BasicBlock *OuterPredBB, 861 Instruction *SeqStartI, 862 Instruction *SeqEndI) { 863 // Isolate the instructions of the sequential region to a separate 864 // block. 
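      // After the three splits below the blocks are laid out as
      //   ParentBB -> SeqStartBB ("seq.par.merged") -> SeqEndBB -> SeqAfterBB
      // with the sequential instructions from SeqStartI to SeqEndI ending up
      // in SeqStartBB.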
865 BasicBlock *ParentBB = SeqStartI->getParent(); 866 BasicBlock *SeqEndBB = 867 SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI); 868 BasicBlock *SeqAfterBB = 869 SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI); 870 BasicBlock *SeqStartBB = 871 SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged"); 872 873 assert(ParentBB->getUniqueSuccessor() == SeqStartBB && 874 "Expected a different CFG"); 875 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc(); 876 ParentBB->getTerminator()->eraseFromParent(); 877 878 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, 879 BasicBlock &ContinuationIP) { 880 BasicBlock *CGStartBB = CodeGenIP.getBlock(); 881 BasicBlock *CGEndBB = 882 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI); 883 assert(SeqStartBB != nullptr && "SeqStartBB should not be null"); 884 CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB); 885 assert(SeqEndBB != nullptr && "SeqEndBB should not be null"); 886 SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB); 887 }; 888 auto FiniCB = [&](InsertPointTy CodeGenIP) {}; 889 890 // Find outputs from the sequential region to outside users and 891 // broadcast their values to them. 892 for (Instruction &I : *SeqStartBB) { 893 SmallPtrSet<Instruction *, 4> OutsideUsers; 894 for (User *Usr : I.users()) { 895 Instruction &UsrI = *cast<Instruction>(Usr); 896 // Ignore outputs to LT intrinsics, code extraction for the merged 897 // parallel region will fix them. 898 if (UsrI.isLifetimeStartOrEnd()) 899 continue; 900 901 if (UsrI.getParent() != SeqStartBB) 902 OutsideUsers.insert(&UsrI); 903 } 904 905 if (OutsideUsers.empty()) 906 continue; 907 908 // Emit an alloca in the outer region to store the broadcasted 909 // value. 910 const DataLayout &DL = M.getDataLayout(); 911 AllocaInst *AllocaI = new AllocaInst( 912 I.getType(), DL.getAllocaAddrSpace(), nullptr, 913 I.getName() + ".seq.output.alloc", &OuterFn->front().front()); 914 915 // Emit a store instruction in the sequential BB to update the 916 // value. 917 new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()); 918 919 // Emit a load instruction and replace the use of the output value 920 // with it. 921 for (Instruction *UsrI : OutsideUsers) { 922 LoadInst *LoadI = new LoadInst( 923 I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI); 924 UsrI->replaceUsesOfWith(&I, LoadI); 925 } 926 } 927 928 OpenMPIRBuilder::LocationDescription Loc( 929 InsertPointTy(ParentBB, ParentBB->end()), DL); 930 InsertPointTy SeqAfterIP = 931 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB); 932 933 OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel); 934 935 BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock()); 936 937 LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn 938 << "\n"); 939 }; 940 941 // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all 942 // contained in BB and only separated by instructions that can be 943 // redundantly executed in parallel. The block BB is split before the first 944 // call (in MergableCIs) and after the last so the entire region we merge 945 // into a single parallel region is contained in a single basic block 946 // without any other instructions. We use the OpenMPIRBuilder to outline 947 // that block and call the resulting function via __kmpc_fork_call. 
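    // Roughly, a block of the form
    //   __kmpc_fork_call(..., @outlined1, ...)
    //   <sequential code>
    //   __kmpc_fork_call(..., @outlined2, ...)
    // becomes a single merged parallel region that calls @outlined1 and
    // @outlined2 directly and guards the sequential code in-between with a
    // master construct followed by a barrier.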
948 auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) { 949 // TODO: Change the interface to allow single CIs expanded, e.g, to 950 // include an outer loop. 951 assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs"); 952 953 auto Remark = [&](OptimizationRemark OR) { 954 OR << "Parallel region merged with parallel region" 955 << (MergableCIs.size() > 2 ? "s" : "") << " at "; 956 for (auto *CI : llvm::drop_begin(MergableCIs)) { 957 OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc()); 958 if (CI != MergableCIs.back()) 959 OR << ", "; 960 } 961 return OR << "."; 962 }; 963 964 emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark); 965 966 Function *OriginalFn = BB->getParent(); 967 LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size() 968 << " parallel regions in " << OriginalFn->getName() 969 << "\n"); 970 971 // Isolate the calls to merge in a separate block. 972 EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI); 973 BasicBlock *AfterBB = 974 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI); 975 StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr, 976 "omp.par.merged"); 977 978 assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG"); 979 const DebugLoc DL = BB->getTerminator()->getDebugLoc(); 980 BB->getTerminator()->eraseFromParent(); 981 982 // Create sequential regions for sequential instructions that are 983 // in-between mergable parallel regions. 984 for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1; 985 It != End; ++It) { 986 Instruction *ForkCI = *It; 987 Instruction *NextForkCI = *(It + 1); 988 989 // Continue if there are not in-between instructions. 990 if (ForkCI->getNextNode() == NextForkCI) 991 continue; 992 993 CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(), 994 NextForkCI->getPrevNode()); 995 } 996 997 OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()), 998 DL); 999 IRBuilder<>::InsertPoint AllocaIP( 1000 &OriginalFn->getEntryBlock(), 1001 OriginalFn->getEntryBlock().getFirstInsertionPt()); 1002 // Create the merged parallel region with default proc binding, to 1003 // avoid overriding binding settings, and without explicit cancellation. 1004 InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel( 1005 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr, 1006 OMP_PROC_BIND_default, /* IsCancellable */ false); 1007 BranchInst::Create(AfterBB, AfterIP.getBlock()); 1008 1009 // Perform the actual outlining. 1010 OMPInfoCache.OMPBuilder.finalize(OriginalFn, 1011 /* AllowExtractorSinking */ true); 1012 1013 Function *OutlinedFn = MergableCIs.front()->getCaller(); 1014 1015 // Replace the __kmpc_fork_call calls with direct calls to the outlined 1016 // callbacks. 1017 SmallVector<Value *, 8> Args; 1018 for (auto *CI : MergableCIs) { 1019 Value *Callee = 1020 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts(); 1021 FunctionType *FT = 1022 cast<FunctionType>(Callee->getType()->getPointerElementType()); 1023 Args.clear(); 1024 Args.push_back(OutlinedFn->getArg(0)); 1025 Args.push_back(OutlinedFn->getArg(1)); 1026 for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands(); 1027 U < E; ++U) 1028 Args.push_back(CI->getArgOperand(U)); 1029 1030 CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI); 1031 if (CI->getDebugLoc()) 1032 NewCI->setDebugLoc(CI->getDebugLoc()); 1033 1034 // Forward parameter attributes from the callback to the callee. 
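        // The outlined callback takes the global and bound thread IDs as its
        // first two parameters, so forwarded arguments start at parameter
        // index 2; the subtraction below accounts for that shift relative to
        // the __kmpc_fork_call operand positions.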
1035 for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands(); 1036 U < E; ++U) 1037 for (const Attribute &A : CI->getAttributes().getParamAttributes(U)) 1038 NewCI->addParamAttr( 1039 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A); 1040 1041 // Emit an explicit barrier to replace the implicit fork-join barrier. 1042 if (CI != MergableCIs.back()) { 1043 // TODO: Remove barrier if the merged parallel region includes the 1044 // 'nowait' clause. 1045 OMPInfoCache.OMPBuilder.createBarrier( 1046 InsertPointTy(NewCI->getParent(), 1047 NewCI->getNextNode()->getIterator()), 1048 OMPD_parallel); 1049 } 1050 1051 CI->eraseFromParent(); 1052 } 1053 1054 assert(OutlinedFn != OriginalFn && "Outlining failed"); 1055 CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn); 1056 CGUpdater.reanalyzeFunction(*OriginalFn); 1057 1058 NumOpenMPParallelRegionsMerged += MergableCIs.size(); 1059 1060 return true; 1061 }; 1062 1063 // Helper function that identifes sequences of 1064 // __kmpc_fork_call uses in a basic block. 1065 auto DetectPRsCB = [&](Use &U, Function &F) { 1066 CallInst *CI = getCallIfRegularCall(U, &RFI); 1067 BB2PRMap[CI->getParent()].insert(CI); 1068 1069 return false; 1070 }; 1071 1072 BB2PRMap.clear(); 1073 RFI.foreachUse(SCC, DetectPRsCB); 1074 SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector; 1075 // Find mergable parallel regions within a basic block that are 1076 // safe to merge, that is any in-between instructions can safely 1077 // execute in parallel after merging. 1078 // TODO: support merging across basic-blocks. 1079 for (auto &It : BB2PRMap) { 1080 auto &CIs = It.getSecond(); 1081 if (CIs.size() < 2) 1082 continue; 1083 1084 BasicBlock *BB = It.getFirst(); 1085 SmallVector<CallInst *, 4> MergableCIs; 1086 1087 /// Returns true if the instruction is mergable, false otherwise. 1088 /// A terminator instruction is unmergable by definition since merging 1089 /// works within a BB. Instructions before the mergable region are 1090 /// mergable if they are not calls to OpenMP runtime functions that may 1091 /// set different execution parameters for subsequent parallel regions. 1092 /// Instructions in-between parallel regions are mergable if they are not 1093 /// calls to any non-intrinsic function since that may call a non-mergable 1094 /// OpenMP runtime function. 1095 auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) { 1096 // We do not merge across BBs, hence return false (unmergable) if the 1097 // instruction is a terminator. 1098 if (I.isTerminator()) 1099 return false; 1100 1101 if (!isa<CallInst>(&I)) 1102 return true; 1103 1104 CallInst *CI = cast<CallInst>(&I); 1105 if (IsBeforeMergableRegion) { 1106 Function *CalledFunction = CI->getCalledFunction(); 1107 if (!CalledFunction) 1108 return false; 1109 // Return false (unmergable) if the call before the parallel 1110 // region calls an explicit affinity (proc_bind) or number of 1111 // threads (num_threads) compiler-generated function. Those settings 1112 // may be incompatible with following parallel regions. 1113 // TODO: ICV tracking to detect compatibility. 1114 for (const auto &RFI : UnmergableCallsInfo) { 1115 if (CalledFunction == RFI.Declaration) 1116 return false; 1117 } 1118 } else { 1119 // Return false (unmergable) if there is a call instruction 1120 // in-between parallel regions when it is not an intrinsic. It 1121 // may call an unmergable OpenMP runtime function in its callpath. 1122 // TODO: Keep track of possible OpenMP calls in the callpath. 
1123 if (!isa<IntrinsicInst>(CI)) 1124 return false; 1125 } 1126 1127 return true; 1128 }; 1129 // Find maximal number of parallel region CIs that are safe to merge. 1130 for (auto It = BB->begin(), End = BB->end(); It != End;) { 1131 Instruction &I = *It; 1132 ++It; 1133 1134 if (CIs.count(&I)) { 1135 MergableCIs.push_back(cast<CallInst>(&I)); 1136 continue; 1137 } 1138 1139 // Continue expanding if the instruction is mergable. 1140 if (IsMergable(I, MergableCIs.empty())) 1141 continue; 1142 1143 // Forward the instruction iterator to skip the next parallel region 1144 // since there is an unmergable instruction which can affect it. 1145 for (; It != End; ++It) { 1146 Instruction &SkipI = *It; 1147 if (CIs.count(&SkipI)) { 1148 LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI 1149 << " due to " << I << "\n"); 1150 ++It; 1151 break; 1152 } 1153 } 1154 1155 // Store mergable regions found. 1156 if (MergableCIs.size() > 1) { 1157 MergableCIsVector.push_back(MergableCIs); 1158 LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size() 1159 << " parallel regions in block " << BB->getName() 1160 << " of function " << BB->getParent()->getName() 1161 << "\n";); 1162 } 1163 1164 MergableCIs.clear(); 1165 } 1166 1167 if (!MergableCIsVector.empty()) { 1168 Changed = true; 1169 1170 for (auto &MergableCIs : MergableCIsVector) 1171 Merge(MergableCIs, BB); 1172 MergableCIsVector.clear(); 1173 } 1174 } 1175 1176 if (Changed) { 1177 /// Re-collect use for fork calls, emitted barrier calls, and 1178 /// any emitted master/end_master calls. 1179 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call); 1180 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier); 1181 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master); 1182 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master); 1183 } 1184 1185 return Changed; 1186 } 1187 1188 /// Try to delete parallel regions if possible. 1189 bool deleteParallelRegions() { 1190 const unsigned CallbackCalleeOperand = 2; 1191 1192 OMPInformationCache::RuntimeFunctionInfo &RFI = 1193 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call]; 1194 1195 if (!RFI.Declaration) 1196 return false; 1197 1198 bool Changed = false; 1199 auto DeleteCallCB = [&](Use &U, Function &) { 1200 CallInst *CI = getCallIfRegularCall(U); 1201 if (!CI) 1202 return false; 1203 auto *Fn = dyn_cast<Function>( 1204 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts()); 1205 if (!Fn) 1206 return false; 1207 if (!Fn->onlyReadsMemory()) 1208 return false; 1209 if (!Fn->hasFnAttribute(Attribute::WillReturn)) 1210 return false; 1211 1212 LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in " 1213 << CI->getCaller()->getName() << "\n"); 1214 1215 auto Remark = [&](OptimizationRemark OR) { 1216 return OR << "Removing parallel region with no side-effects."; 1217 }; 1218 emitRemark<OptimizationRemark>(CI, "OMP160", Remark); 1219 1220 CGUpdater.removeCallSite(*CI); 1221 CI->eraseFromParent(); 1222 Changed = true; 1223 ++NumOpenMPParallelRegionsDeleted; 1224 return true; 1225 }; 1226 1227 RFI.foreachUse(SCC, DeleteCallCB); 1228 1229 return Changed; 1230 } 1231 1232 /// Try to eliminate runtime calls by reusing existing ones. 
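  /// For example, two calls to omp_get_level() in the same function are
  /// expected to return the same value, so the second one can reuse the
  /// result of the first (see DeduplicableRuntimeCallIDs below).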
  bool deduplicateRuntimeCalls() {
    bool Changed = false;

    RuntimeFunction DeduplicableRuntimeCallIDs[] = {
        OMPRTL_omp_get_num_threads,
        OMPRTL_omp_in_parallel,
        OMPRTL_omp_get_cancellation,
        OMPRTL_omp_get_thread_limit,
        OMPRTL_omp_get_supported_active_levels,
        OMPRTL_omp_get_level,
        OMPRTL_omp_get_ancestor_thread_num,
        OMPRTL_omp_get_team_size,
        OMPRTL_omp_get_active_level,
        OMPRTL_omp_in_final,
        OMPRTL_omp_get_proc_bind,
        OMPRTL_omp_get_num_places,
        OMPRTL_omp_get_num_procs,
        OMPRTL_omp_get_place_num,
        OMPRTL_omp_get_partition_num_places,
        OMPRTL_omp_get_partition_place_nums};

    // Global-tid is handled separately.
    SmallSetVector<Value *, 16> GTIdArgs;
    collectGlobalThreadIdArguments(GTIdArgs);
    LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
                      << " global thread ID arguments\n");

    for (Function *F : SCC) {
      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);

      // __kmpc_global_thread_num is special as we can replace it with an
      // argument in enough cases to make it worth trying.
      Value *GTIdArg = nullptr;
      for (Argument &Arg : F->args())
        if (GTIdArgs.count(&Arg)) {
          GTIdArg = &Arg;
          break;
        }
      Changed |= deduplicateRuntimeCalls(
          *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
    }

    return Changed;
  }

  /// Tries to hide the latency of runtime calls that involve host to
  /// device memory transfers by splitting them into their "issue" and "wait"
  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
  /// moved downwards as much as possible. The "issue" issues the memory
  /// transfer asynchronously, returning a handle. The "wait" waits on the
  /// returned handle for the memory transfer to finish.
  bool hideMemTransfersLatency() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
    bool Changed = false;
    auto SplitMemTransfers = [&](Use &U, Function &Decl) {
      auto *RTCall = getCallIfRegularCall(U, &RFI);
      if (!RTCall)
        return false;

      OffloadArray OffloadArrays[3];
      if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
        return false;

      LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));

      // TODO: Check if the "issue" call can also be moved upwards.
      bool WasSplit = false;
      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
      if (WaitMovementPoint)
        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);

      Changed |= WasSplit;
      return WasSplit;
    };
    RFI.foreachUse(SCC, SplitMemTransfers);

    return Changed;
  }

  void analysisGlobalization() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];

    auto CheckGlobalization = [&](Use &U, Function &Decl) {
      if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
        auto Remark = [&](OptimizationRemarkMissed ORM) {
          return ORM
                 << "Found thread data sharing on the GPU. 
" 1322 << "Expect degraded performance due to data globalization."; 1323 }; 1324 emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark); 1325 } 1326 1327 return false; 1328 }; 1329 1330 RFI.foreachUse(SCC, CheckGlobalization); 1331 } 1332 1333 /// Maps the values stored in the offload arrays passed as arguments to 1334 /// \p RuntimeCall into the offload arrays in \p OAs. 1335 bool getValuesInOffloadArrays(CallInst &RuntimeCall, 1336 MutableArrayRef<OffloadArray> OAs) { 1337 assert(OAs.size() == 3 && "Need space for three offload arrays!"); 1338 1339 // A runtime call that involves memory offloading looks something like: 1340 // call void @__tgt_target_data_begin_mapper(arg0, arg1, 1341 // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes, 1342 // ...) 1343 // So, the idea is to access the allocas that allocate space for these 1344 // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes. 1345 // Therefore: 1346 // i8** %offload_baseptrs. 1347 Value *BasePtrsArg = 1348 RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum); 1349 // i8** %offload_ptrs. 1350 Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum); 1351 // i8** %offload_sizes. 1352 Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum); 1353 1354 // Get values stored in **offload_baseptrs. 1355 auto *V = getUnderlyingObject(BasePtrsArg); 1356 if (!isa<AllocaInst>(V)) 1357 return false; 1358 auto *BasePtrsArray = cast<AllocaInst>(V); 1359 if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall)) 1360 return false; 1361 1362 // Get values stored in **offload_baseptrs. 1363 V = getUnderlyingObject(PtrsArg); 1364 if (!isa<AllocaInst>(V)) 1365 return false; 1366 auto *PtrsArray = cast<AllocaInst>(V); 1367 if (!OAs[1].initialize(*PtrsArray, RuntimeCall)) 1368 return false; 1369 1370 // Get values stored in **offload_sizes. 1371 V = getUnderlyingObject(SizesArg); 1372 // If it's a [constant] global array don't analyze it. 1373 if (isa<GlobalValue>(V)) 1374 return isa<Constant>(V); 1375 if (!isa<AllocaInst>(V)) 1376 return false; 1377 1378 auto *SizesArray = cast<AllocaInst>(V); 1379 if (!OAs[2].initialize(*SizesArray, RuntimeCall)) 1380 return false; 1381 1382 return true; 1383 } 1384 1385 /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG. 1386 /// For now this is a way to test that the function getValuesInOffloadArrays 1387 /// is working properly. 1388 /// TODO: Move this to a unittest when unittests are available for OpenMPOpt. 1389 void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) { 1390 assert(OAs.size() == 3 && "There are three offload arrays to debug!"); 1391 1392 LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n"); 1393 std::string ValuesStr; 1394 raw_string_ostream Printer(ValuesStr); 1395 std::string Separator = " --- "; 1396 1397 for (auto *BP : OAs[0].StoredValues) { 1398 BP->print(Printer); 1399 Printer << Separator; 1400 } 1401 LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n"); 1402 ValuesStr.clear(); 1403 1404 for (auto *P : OAs[1].StoredValues) { 1405 P->print(Printer); 1406 Printer << Separator; 1407 } 1408 LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n"); 1409 ValuesStr.clear(); 1410 1411 for (auto *S : OAs[2].StoredValues) { 1412 S->print(Printer); 1413 Printer << Separator; 1414 } 1415 LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n"); 1416 } 1417 1418 /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be 1419 /// moved. 
  /// Returns nullptr if the movement is not possible, or not worth it.
  Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
    // FIXME: This traverses only the BasicBlock where RuntimeCall is.
    //        Make it traverse the CFG.

    Instruction *CurrentI = &RuntimeCall;
    bool IsWorthIt = false;
    while ((CurrentI = CurrentI->getNextNode())) {

      // TODO: Once we detect the regions to be offloaded we should use the
      //       alias analysis manager to check if CurrentI may modify one of
      //       the offloaded regions.
      if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
        if (IsWorthIt)
          return CurrentI;

        return nullptr;
      }

      // FIXME: For now, moving the "wait" over anything without side effects
      //        is considered worth it.
      IsWorthIt = true;
    }

    // Return end of BasicBlock.
    return RuntimeCall.getParent()->getTerminator();
  }

  /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
                               Instruction &WaitMovementPoint) {
    // Create a stack allocated handle (__tgt_async_info) at the beginning of
    // the function. It stores information about the async transfer, allowing
    // us to wait on it later.
    auto &IRBuilder = OMPInfoCache.OMPBuilder;
    auto *F = RuntimeCall.getCaller();
    Instruction *FirstInst = &(F->getEntryBlock().front());
    AllocaInst *Handle = new AllocaInst(
        IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);

    // Add "issue" runtime call declaration:
    // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
    //   i8**, i8**, i64*, i64*)
    FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_issue);

    // Replace the RuntimeCall call site with its asynchronous version.
    SmallVector<Value *, 16> Args;
    for (auto &Arg : RuntimeCall.args())
      Args.push_back(Arg.get());
    Args.push_back(Handle);

    CallInst *IssueCallsite =
        CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
    RuntimeCall.eraseFromParent();

    // Add "wait" runtime call declaration:
    // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
    FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_wait);

    Value *WaitParams[2] = {
        IssueCallsite->getArgOperand(
            OffloadArray::DeviceIDArgNum), // device_id.
        Handle                             // handle to wait on.
    };
    CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);

    return true;
  }

  static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
                                    bool GlobalOnly, bool &SingleChoice) {
    if (CurrentIdent == NextIdent)
      return CurrentIdent;

    // TODO: Figure out how to actually combine multiple debug locations. For
    //       now we just keep an existing one if there is a single choice.
    if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
      SingleChoice = !CurrentIdent;
      return NextIdent;
    }
    return nullptr;
  }

  /// Return a `struct ident_t*` value that represents the ones used in the
  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
  /// return value we create one from scratch.
  /// We also do not yet combine information, e.g., the source locations; see
  /// combinedIdentStruct.
  Value *
  getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
                                 Function &F, bool GlobalOnly) {
    bool SingleChoice = true;
    Value *Ident = nullptr;
    auto CombineIdentStruct = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || &F != &Caller)
        return false;
      Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
                                  /* GlobalOnly */ true, SingleChoice);
      return false;
    };
    RFI.foreachUse(SCC, CombineIdentStruct);

    if (!Ident || !SingleChoice) {
      // The IRBuilder uses the insertion block to get to the module; this is
      // unfortunate, but we work around it for now.
      if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
        OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
            &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
      // TODO: Use the debug locations of the calls instead.
      Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
      Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
    }
    return Ident;
  }

  /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
  /// \p ReplVal if given.
  bool deduplicateRuntimeCalls(Function &F,
                               OMPInformationCache::RuntimeFunctionInfo &RFI,
                               Value *ReplVal = nullptr) {
    auto *UV = RFI.getUseVector(F);
    if (!UV || UV->size() + (ReplVal != nullptr) < 2)
      return false;

    LLVM_DEBUG(
        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value\n" : "\n") << "\n");

    assert((!ReplVal || (isa<Argument>(ReplVal) &&
                         cast<Argument>(ReplVal)->getParent() == &F)) &&
           "Unexpected replacement value!");

    // TODO: Use dominance to find a good position instead.
    auto CanBeMoved = [this](CallBase &CB) {
      unsigned NumArgs = CB.getNumArgOperands();
      if (NumArgs == 0)
        return true;
      if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
        return false;
      for (unsigned u = 1; u < NumArgs; ++u)
        if (isa<Instruction>(CB.getArgOperand(u)))
          return false;
      return true;
    };

    if (!ReplVal) {
      for (Use *U : *UV)
        if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
          if (!CanBeMoved(*CI))
            continue;

          CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
          ReplVal = CI;
          break;
        }
      if (!ReplVal)
        return false;
    }

    // If we use a call as a replacement value we need to make sure the ident
    // is valid at the new location. For now we just pick a global one, either
    // existing and used by one of the calls, or created from scratch.
1585 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) { 1586 if (CI->getNumArgOperands() > 0 && 1587 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) { 1588 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F, 1589 /* GlobalOnly */ true); 1590 CI->setArgOperand(0, Ident); 1591 } 1592 } 1593 1594 bool Changed = false; 1595 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) { 1596 CallInst *CI = getCallIfRegularCall(U, &RFI); 1597 if (!CI || CI == ReplVal || &F != &Caller) 1598 return false; 1599 assert(CI->getCaller() == &F && "Unexpected call!"); 1600 1601 auto Remark = [&](OptimizationRemark OR) { 1602 return OR << "OpenMP runtime call " 1603 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated."; 1604 }; 1605 if (CI->getDebugLoc()) 1606 emitRemark<OptimizationRemark>(CI, "OMP170", Remark); 1607 else 1608 emitRemark<OptimizationRemark>(&F, "OMP170", Remark); 1609 1610 CGUpdater.removeCallSite(*CI); 1611 CI->replaceAllUsesWith(ReplVal); 1612 CI->eraseFromParent(); 1613 ++NumOpenMPRuntimeCallsDeduplicated; 1614 Changed = true; 1615 return true; 1616 }; 1617 RFI.foreachUse(SCC, ReplaceAndDeleteCB); 1618 1619 return Changed; 1620 } 1621 1622 /// Collect arguments that represent the global thread id in \p GTIdArgs. 1623 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> >IdArgs) { 1624 // TODO: Below we basically perform a fixpoint iteration with a pessimistic 1625 // initialization. We could define an AbstractAttribute instead and 1626 // run the Attributor here once it can be run as an SCC pass. 1627 1628 // Helper to check the argument \p ArgNo at all call sites of \p F for 1629 // a GTId. 1630 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) { 1631 if (!F.hasLocalLinkage()) 1632 return false; 1633 for (Use &U : F.uses()) { 1634 if (CallInst *CI = getCallIfRegularCall(U)) { 1635 Value *ArgOp = CI->getArgOperand(ArgNo); 1636 if (CI == &RefCI || GTIdArgs.count(ArgOp) || 1637 getCallIfRegularCall( 1638 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num])) 1639 continue; 1640 } 1641 return false; 1642 } 1643 return true; 1644 }; 1645 1646 // Helper to identify uses of a GTId as GTId arguments. 1647 auto AddUserArgs = [&](Value >Id) { 1648 for (Use &U : GTId.uses()) 1649 if (CallInst *CI = dyn_cast<CallInst>(U.getUser())) 1650 if (CI->isArgOperand(&U)) 1651 if (Function *Callee = CI->getCalledFunction()) 1652 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI)) 1653 GTIdArgs.insert(Callee->getArg(U.getOperandNo())); 1654 }; 1655 1656 // The argument users of __kmpc_global_thread_num calls are GTIds. 1657 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI = 1658 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]; 1659 1660 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) { 1661 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI)) 1662 AddUserArgs(*CI); 1663 return false; 1664 }); 1665 1666 // Transitively search for more arguments by looking at the users of the 1667 // ones we know already. During the search the GTIdArgs vector is extended 1668 // so we cannot cache the size nor can we use a range based for. 1669 for (unsigned u = 0; u < GTIdArgs.size(); ++u) 1670 AddUserArgs(*GTIdArgs[u]); 1671 } 1672 1673 /// Kernel (=GPU) optimizations and utility functions 1674 /// 1675 ///{{ 1676 1677 /// Check if \p F is a kernel, hence entry point for target offloading. 
1678 bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
1679
1680 /// Cache to remember the unique kernel for a function.
1681 DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
1682
1683 /// Find the unique kernel that will execute \p F, if any.
1684 Kernel getUniqueKernelFor(Function &F);
1685
1686 /// Find the unique kernel that will execute \p I, if any.
1687 Kernel getUniqueKernelFor(Instruction &I) {
1688 return getUniqueKernelFor(*I.getFunction());
1689 }
1690
1691 /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1692 /// the cases we can avoid taking the address of a function.
1693 bool rewriteDeviceCodeStateMachine();
1694
1695 ///
1696 ///}}
1697
1698 /// Emit a remark generically
1699 ///
1700 /// This template function can be used to generically emit a remark. The
1701 /// RemarkKind should be one of the following:
1702 /// - OptimizationRemark to indicate a successful optimization attempt
1703 /// - OptimizationRemarkMissed to report a failed optimization attempt
1704 /// - OptimizationRemarkAnalysis to provide additional information about an
1705 /// optimization attempt
1706 ///
1707 /// The remark is built using a callback function provided by the caller that
1708 /// takes a RemarkKind as input and returns a RemarkKind.
1709 template <typename RemarkKind, typename RemarkCallBack>
1710 void emitRemark(Instruction *I, StringRef RemarkName,
1711 RemarkCallBack &&RemarkCB) const {
1712 Function *F = I->getParent()->getParent();
1713 auto &ORE = OREGetter(F);
1714
1715 if (RemarkName.startswith("OMP"))
1716 ORE.emit([&]() {
1717 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
1718 << " [" << RemarkName << "]";
1719 });
1720 else
1721 ORE.emit(
1722 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
1723 }
1724
1725 /// Emit a remark on a function.
1726 template <typename RemarkKind, typename RemarkCallBack>
1727 void emitRemark(Function *F, StringRef RemarkName,
1728 RemarkCallBack &&RemarkCB) const {
1729 auto &ORE = OREGetter(F);
1730
1731 if (RemarkName.startswith("OMP"))
1732 ORE.emit([&]() {
1733 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
1734 << " [" << RemarkName << "]";
1735 });
1736 else
1737 ORE.emit(
1738 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
1739 }
1740
1741 /// The underlying module.
1742 Module &M;
1743
1744 /// The SCC we are operating on.
1745 SmallVectorImpl<Function *> &SCC;
1746
1747 /// Callback to update the call graph, the first argument is a removed call,
1748 /// the second an optional replacement call.
1749 CallGraphUpdater &CGUpdater;
1750
1751 /// Callback to get an OptimizationRemarkEmitter from a Function *
1752 OptimizationRemarkGetter OREGetter;
1753
1754 /// OpenMP-specific information cache. Also used for Attributor runs.
1755 OMPInformationCache &OMPInfoCache;
1756
1757 /// Attributor instance.
1758 Attributor &A;
1759
1760 /// Helper function to run Attributor on SCC.
1761 bool runAttributor(bool IsModulePass) {
1762 if (SCC.empty())
1763 return false;
1764
1765 registerAAs(IsModulePass);
1766
1767 ChangeStatus Changed = A.run();
1768
1769 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
1770 << " functions, result: " << Changed << ".\n");
1771
1772 return Changed == ChangeStatus::CHANGED;
1773 }
1774
1775 /// Populate the Attributor with abstract attribute opportunities in the
1776 /// function.
1777 void registerAAs(bool IsModulePass); 1778 }; 1779 1780 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) { 1781 if (!OMPInfoCache.ModuleSlice.count(&F)) 1782 return nullptr; 1783 1784 // Use a scope to keep the lifetime of the CachedKernel short. 1785 { 1786 Optional<Kernel> &CachedKernel = UniqueKernelMap[&F]; 1787 if (CachedKernel) 1788 return *CachedKernel; 1789 1790 // TODO: We should use an AA to create an (optimistic and callback 1791 // call-aware) call graph. For now we stick to simple patterns that 1792 // are less powerful, basically the worst fixpoint. 1793 if (isKernel(F)) { 1794 CachedKernel = Kernel(&F); 1795 return *CachedKernel; 1796 } 1797 1798 CachedKernel = nullptr; 1799 if (!F.hasLocalLinkage()) { 1800 1801 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html 1802 auto Remark = [&](OptimizationRemarkAnalysis ORA) { 1803 return ORA << "Potentially unknown OpenMP target region caller."; 1804 }; 1805 emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark); 1806 1807 return nullptr; 1808 } 1809 } 1810 1811 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel { 1812 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) { 1813 // Allow use in equality comparisons. 1814 if (Cmp->isEquality()) 1815 return getUniqueKernelFor(*Cmp); 1816 return nullptr; 1817 } 1818 if (auto *CB = dyn_cast<CallBase>(U.getUser())) { 1819 // Allow direct calls. 1820 if (CB->isCallee(&U)) 1821 return getUniqueKernelFor(*CB); 1822 1823 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI = 1824 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51]; 1825 // Allow the use in __kmpc_parallel_51 calls. 1826 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI)) 1827 return getUniqueKernelFor(*CB); 1828 return nullptr; 1829 } 1830 // Disallow every other use. 1831 return nullptr; 1832 }; 1833 1834 // TODO: In the future we want to track more than just a unique kernel. 1835 SmallPtrSet<Kernel, 2> PotentialKernels; 1836 OMPInformationCache::foreachUse(F, [&](const Use &U) { 1837 PotentialKernels.insert(GetUniqueKernelForUse(U)); 1838 }); 1839 1840 Kernel K = nullptr; 1841 if (PotentialKernels.size() == 1) 1842 K = *PotentialKernels.begin(); 1843 1844 // Cache the result. 1845 UniqueKernelMap[&F] = K; 1846 1847 return K; 1848 } 1849 1850 bool OpenMPOpt::rewriteDeviceCodeStateMachine() { 1851 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI = 1852 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51]; 1853 1854 bool Changed = false; 1855 if (!KernelParallelRFI) 1856 return Changed; 1857 1858 for (Function *F : SCC) { 1859 1860 // Check if the function is a use in a __kmpc_parallel_51 call at 1861 // all. 1862 bool UnknownUse = false; 1863 bool KernelParallelUse = false; 1864 unsigned NumDirectCalls = 0; 1865 1866 SmallVector<Use *, 2> ToBeReplacedStateMachineUses; 1867 OMPInformationCache::foreachUse(*F, [&](Use &U) { 1868 if (auto *CB = dyn_cast<CallBase>(U.getUser())) 1869 if (CB->isCallee(&U)) { 1870 ++NumDirectCalls; 1871 return; 1872 } 1873 1874 if (isa<ICmpInst>(U.getUser())) { 1875 ToBeReplacedStateMachineUses.push_back(&U); 1876 return; 1877 } 1878 1879 // Find wrapper functions that represent parallel kernels. 
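// Illustrative shape of such a call (argument names are made up; only the
// position of the wrapper function pointer, argument index 6, matters here):
//   call void @__kmpc_parallel_51(%struct.ident_t* %ident, i32 %gtid,
//                                 i32 %if_expr, i32 %num_threads,
//                                 i32 %proc_bind, i8* %outlined_fn,
//                                 i8* %wrapper_fn, i8** %args, i64 %nargs)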
1880 CallInst *CI =
1881 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
1882 const unsigned int WrapperFunctionArgNo = 6;
1883 if (!KernelParallelUse && CI &&
1884 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
1885 KernelParallelUse = true;
1886 ToBeReplacedStateMachineUses.push_back(&U);
1887 return;
1888 }
1889 UnknownUse = true;
1890 });
1891
1892 // Do not emit a remark if we haven't seen a __kmpc_parallel_51
1893 // use.
1894 if (!KernelParallelUse)
1895 continue;
1896
1897 // If this ever hits, we should investigate.
1898 // TODO: Checking the number of uses is not a necessary restriction and
1899 // should be lifted.
1900 if (UnknownUse || NumDirectCalls != 1 ||
1901 ToBeReplacedStateMachineUses.size() > 2) {
1902 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1903 return ORA << "Parallel region is used in "
1904 << (UnknownUse ? "unknown" : "unexpected")
1905 << " ways. Will not attempt to rewrite the state machine.";
1906 };
1907 emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
1908 continue;
1909 }
1910
1911 // Even if we have __kmpc_parallel_51 calls, we (for now) give
1912 // up if the function is not called from a unique kernel.
1913 Kernel K = getUniqueKernelFor(*F);
1914 if (!K) {
1915 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1916 return ORA << "Parallel region is not called from a unique kernel. "
1917 "Will not attempt to rewrite the state machine.";
1918 };
1919 emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
1920 continue;
1921 }
1922
1923 // We now know F is a parallel body function called only from the kernel K.
1924 // We also identified the state machine uses in which we replace the
1925 // function pointer by a new global symbol for identification purposes. This
1926 // ensures only direct calls to the function are left.
1927
1928 Module &M = *F->getParent();
1929 Type *Int8Ty = Type::getInt8Ty(M.getContext());
1930
1931 auto *ID = new GlobalVariable(
1932 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
1933 UndefValue::get(Int8Ty), F->getName() + ".ID");
1934
1935 for (Use *U : ToBeReplacedStateMachineUses)
1936 U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));
1937
1938 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
1939
1940 Changed = true;
1941 }
1942
1943 return Changed;
1944 }
1945
1946 /// Abstract Attribute for tracking ICV values.
1947 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
1948 using Base = StateWrapper<BooleanState, AbstractAttribute>;
1949 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
1950
1951 void initialize(Attributor &A) override {
1952 Function *F = getAnchorScope();
1953 if (!F || !A.isFunctionIPOAmendable(*F))
1954 indicatePessimisticFixpoint();
1955 }
1956
1957 /// Returns true if value is assumed to be tracked.
1958 bool isAssumedTracked() const { return getAssumed(); }
1959
1960 /// Returns true if value is known to be tracked.
1961 bool isKnownTracked() const { return getAssumed(); }
1962
1963 /// Create an abstract attribute view for the position \p IRP.
1964 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
1965
1966 /// Return the value with which \p I can be replaced for specific \p ICV.
1967 virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
1968 const Instruction *I,
1969 Attributor &A) const {
1970 return None;
1971 }
1972
1973 /// Return an assumed unique ICV value if a single candidate is found.
If
1974 /// there cannot be one, return a nullptr. If it is not clear yet, return the
1975 /// Optional::NoneType.
1976 virtual Optional<Value *>
1977 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
1978
1979 // Currently only nthreads is being tracked.
1980 // This array will only grow with time.
1981 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
1982
1983 /// See AbstractAttribute::getName()
1984 const std::string getName() const override { return "AAICVTracker"; }
1985
1986 /// See AbstractAttribute::getIdAddr()
1987 const char *getIdAddr() const override { return &ID; }
1988
1989 /// This function should return true if the type of the \p AA is AAICVTracker
1990 static bool classof(const AbstractAttribute *AA) {
1991 return (AA->getIdAddr() == &ID);
1992 }
1993
1994 static const char ID;
1995 };
1996
1997 struct AAICVTrackerFunction : public AAICVTracker {
1998 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
1999 : AAICVTracker(IRP, A) {}
2000
2001 // FIXME: come up with better string.
2002 const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2003
2004 // FIXME: come up with some stats.
2005 void trackStatistics() const override {}
2006
2007 /// We don't manifest anything for this AA.
2008 ChangeStatus manifest(Attributor &A) override {
2009 return ChangeStatus::UNCHANGED;
2010 }
2011
2012 // Map of ICV to their values at specific program point.
2013 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2014 InternalControlVar::ICV___last>
2015 ICVReplacementValuesMap;
2016
2017 ChangeStatus updateImpl(Attributor &A) override {
2018 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2019
2020 Function *F = getAnchorScope();
2021
2022 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2023
2024 for (InternalControlVar ICV : TrackableICVs) {
2025 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2026
2027 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2028 auto TrackValues = [&](Use &U, Function &) {
2029 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2030 if (!CI)
2031 return false;
2032
2033 // FIXME: handle setters with more than one argument.
2034 /// Track new value.
2035 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2036 HasChanged = ChangeStatus::CHANGED;
2037
2038 return false;
2039 };
2040
2041 auto CallCheck = [&](Instruction &I) {
2042 Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
2043 if (ReplVal.hasValue() &&
2044 ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2045 HasChanged = ChangeStatus::CHANGED;
2046
2047 return true;
2048 };
2049
2050 // Track all changes of an ICV.
2051 SetterRFI.foreachUse(TrackValues, F);
2052
2053 bool UsedAssumedInformation = false;
2054 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2055 UsedAssumedInformation,
2056 /* CheckBBLivenessOnly */ true);
2057
2058 /// TODO: Figure out a way to avoid adding entry in
2059 /// ICVReplacementValuesMap
2060 Instruction *Entry = &F->getEntryBlock().front();
2061 if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2062 ValuesMap.insert(std::make_pair(Entry, nullptr));
2063 }
2064
2065 return HasChanged;
2066 }
2067
2068 /// Helper to check if \p I is a call and get the value for it if it is
2069 /// unique.
2070 Optional<Value *> getValueForCall(Attributor &A, const Instruction *I, 2071 InternalControlVar &ICV) const { 2072 2073 const auto *CB = dyn_cast<CallBase>(I); 2074 if (!CB || CB->hasFnAttr("no_openmp") || 2075 CB->hasFnAttr("no_openmp_routines")) 2076 return None; 2077 2078 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2079 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter]; 2080 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter]; 2081 Function *CalledFunction = CB->getCalledFunction(); 2082 2083 // Indirect call, assume ICV changes. 2084 if (CalledFunction == nullptr) 2085 return nullptr; 2086 if (CalledFunction == GetterRFI.Declaration) 2087 return None; 2088 if (CalledFunction == SetterRFI.Declaration) { 2089 if (ICVReplacementValuesMap[ICV].count(I)) 2090 return ICVReplacementValuesMap[ICV].lookup(I); 2091 2092 return nullptr; 2093 } 2094 2095 // Since we don't know, assume it changes the ICV. 2096 if (CalledFunction->isDeclaration()) 2097 return nullptr; 2098 2099 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2100 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED); 2101 2102 if (ICVTrackingAA.isAssumedTracked()) 2103 return ICVTrackingAA.getUniqueReplacementValue(ICV); 2104 2105 // If we don't know, assume it changes. 2106 return nullptr; 2107 } 2108 2109 // We don't check unique value for a function, so return None. 2110 Optional<Value *> 2111 getUniqueReplacementValue(InternalControlVar ICV) const override { 2112 return None; 2113 } 2114 2115 /// Return the value with which \p I can be replaced for specific \p ICV. 2116 Optional<Value *> getReplacementValue(InternalControlVar ICV, 2117 const Instruction *I, 2118 Attributor &A) const override { 2119 const auto &ValuesMap = ICVReplacementValuesMap[ICV]; 2120 if (ValuesMap.count(I)) 2121 return ValuesMap.lookup(I); 2122 2123 SmallVector<const Instruction *, 16> Worklist; 2124 SmallPtrSet<const Instruction *, 16> Visited; 2125 Worklist.push_back(I); 2126 2127 Optional<Value *> ReplVal; 2128 2129 while (!Worklist.empty()) { 2130 const Instruction *CurrInst = Worklist.pop_back_val(); 2131 if (!Visited.insert(CurrInst).second) 2132 continue; 2133 2134 const BasicBlock *CurrBB = CurrInst->getParent(); 2135 2136 // Go up and look for all potential setters/calls that might change the 2137 // ICV. 2138 while ((CurrInst = CurrInst->getPrevNode())) { 2139 if (ValuesMap.count(CurrInst)) { 2140 Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst); 2141 // Unknown value, track new. 2142 if (!ReplVal.hasValue()) { 2143 ReplVal = NewReplVal; 2144 break; 2145 } 2146 2147 // If we found a new value, we can't know the icv value anymore. 2148 if (NewReplVal.hasValue()) 2149 if (ReplVal != NewReplVal) 2150 return nullptr; 2151 2152 break; 2153 } 2154 2155 Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV); 2156 if (!NewReplVal.hasValue()) 2157 continue; 2158 2159 // Unknown value, track new. 2160 if (!ReplVal.hasValue()) { 2161 ReplVal = NewReplVal; 2162 break; 2163 } 2164 2165 // if (NewReplVal.hasValue()) 2166 // We found a new value, we can't know the icv value anymore. 2167 if (ReplVal != NewReplVal) 2168 return nullptr; 2169 } 2170 2171 // If we are in the same BB and we have a value, we are done. 2172 if (CurrBB == I->getParent() && ReplVal.hasValue()) 2173 return ReplVal; 2174 2175 // Go through all predecessors and add terminators for analysis. 
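// Illustrative example (the CFG and values are made up, the calls follow the
// tracked nthreads ICV): for
//   bb0: call void @omp_set_num_threads(i32 4)   ; tracked setter, value 4
//   bb1: br label %bb2                           ; nothing ICV-related
//   bb2: %n = call i32 @omp_get_max_threads()    ; the query instruction \p I
// the backwards walk reaches the setter in bb0 through the terminators of
// bb1 and bb0 and yields the tracked value (i32 4) as the replacement.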
2176 for (const BasicBlock *Pred : predecessors(CurrBB)) 2177 if (const Instruction *Terminator = Pred->getTerminator()) 2178 Worklist.push_back(Terminator); 2179 } 2180 2181 return ReplVal; 2182 } 2183 }; 2184 2185 struct AAICVTrackerFunctionReturned : AAICVTracker { 2186 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A) 2187 : AAICVTracker(IRP, A) {} 2188 2189 // FIXME: come up with better string. 2190 const std::string getAsStr() const override { 2191 return "ICVTrackerFunctionReturned"; 2192 } 2193 2194 // FIXME: come up with some stats. 2195 void trackStatistics() const override {} 2196 2197 /// We don't manifest anything for this AA. 2198 ChangeStatus manifest(Attributor &A) override { 2199 return ChangeStatus::UNCHANGED; 2200 } 2201 2202 // Map of ICV to their values at specific program point. 2203 EnumeratedArray<Optional<Value *>, InternalControlVar, 2204 InternalControlVar::ICV___last> 2205 ICVReplacementValuesMap; 2206 2207 /// Return the value with which \p I can be replaced for specific \p ICV. 2208 Optional<Value *> 2209 getUniqueReplacementValue(InternalControlVar ICV) const override { 2210 return ICVReplacementValuesMap[ICV]; 2211 } 2212 2213 ChangeStatus updateImpl(Attributor &A) override { 2214 ChangeStatus Changed = ChangeStatus::UNCHANGED; 2215 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2216 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 2217 2218 if (!ICVTrackingAA.isAssumedTracked()) 2219 return indicatePessimisticFixpoint(); 2220 2221 for (InternalControlVar ICV : TrackableICVs) { 2222 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; 2223 Optional<Value *> UniqueICVValue; 2224 2225 auto CheckReturnInst = [&](Instruction &I) { 2226 Optional<Value *> NewReplVal = 2227 ICVTrackingAA.getReplacementValue(ICV, &I, A); 2228 2229 // If we found a second ICV value there is no unique returned value. 2230 if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal) 2231 return false; 2232 2233 UniqueICVValue = NewReplVal; 2234 2235 return true; 2236 }; 2237 2238 bool UsedAssumedInformation = false; 2239 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret}, 2240 UsedAssumedInformation, 2241 /* CheckBBLivenessOnly */ true)) 2242 UniqueICVValue = nullptr; 2243 2244 if (UniqueICVValue == ReplVal) 2245 continue; 2246 2247 ReplVal = UniqueICVValue; 2248 Changed = ChangeStatus::CHANGED; 2249 } 2250 2251 return Changed; 2252 } 2253 }; 2254 2255 struct AAICVTrackerCallSite : AAICVTracker { 2256 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A) 2257 : AAICVTracker(IRP, A) {} 2258 2259 void initialize(Attributor &A) override { 2260 Function *F = getAnchorScope(); 2261 if (!F || !A.isFunctionIPOAmendable(*F)) 2262 indicatePessimisticFixpoint(); 2263 2264 // We only initialize this AA for getters, so we need to know which ICV it 2265 // gets. 2266 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2267 for (InternalControlVar ICV : TrackableICVs) { 2268 auto ICVInfo = OMPInfoCache.ICVs[ICV]; 2269 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter]; 2270 if (Getter.Declaration == getAssociatedFunction()) { 2271 AssociatedICV = ICVInfo.Kind; 2272 return; 2273 } 2274 } 2275 2276 /// Unknown ICV. 
2277 indicatePessimisticFixpoint(); 2278 } 2279 2280 ChangeStatus manifest(Attributor &A) override { 2281 if (!ReplVal.hasValue() || !ReplVal.getValue()) 2282 return ChangeStatus::UNCHANGED; 2283 2284 A.changeValueAfterManifest(*getCtxI(), **ReplVal); 2285 A.deleteAfterManifest(*getCtxI()); 2286 2287 return ChangeStatus::CHANGED; 2288 } 2289 2290 // FIXME: come up with better string. 2291 const std::string getAsStr() const override { return "ICVTrackerCallSite"; } 2292 2293 // FIXME: come up with some stats. 2294 void trackStatistics() const override {} 2295 2296 InternalControlVar AssociatedICV; 2297 Optional<Value *> ReplVal; 2298 2299 ChangeStatus updateImpl(Attributor &A) override { 2300 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2301 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 2302 2303 // We don't have any information, so we assume it changes the ICV. 2304 if (!ICVTrackingAA.isAssumedTracked()) 2305 return indicatePessimisticFixpoint(); 2306 2307 Optional<Value *> NewReplVal = 2308 ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A); 2309 2310 if (ReplVal == NewReplVal) 2311 return ChangeStatus::UNCHANGED; 2312 2313 ReplVal = NewReplVal; 2314 return ChangeStatus::CHANGED; 2315 } 2316 2317 // Return the value with which associated value can be replaced for specific 2318 // \p ICV. 2319 Optional<Value *> 2320 getUniqueReplacementValue(InternalControlVar ICV) const override { 2321 return ReplVal; 2322 } 2323 }; 2324 2325 struct AAICVTrackerCallSiteReturned : AAICVTracker { 2326 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A) 2327 : AAICVTracker(IRP, A) {} 2328 2329 // FIXME: come up with better string. 2330 const std::string getAsStr() const override { 2331 return "ICVTrackerCallSiteReturned"; 2332 } 2333 2334 // FIXME: come up with some stats. 2335 void trackStatistics() const override {} 2336 2337 /// We don't manifest anything for this AA. 2338 ChangeStatus manifest(Attributor &A) override { 2339 return ChangeStatus::UNCHANGED; 2340 } 2341 2342 // Map of ICV to their values at specific program point. 2343 EnumeratedArray<Optional<Value *>, InternalControlVar, 2344 InternalControlVar::ICV___last> 2345 ICVReplacementValuesMap; 2346 2347 /// Return the value with which associated value can be replaced for specific 2348 /// \p ICV. 2349 Optional<Value *> 2350 getUniqueReplacementValue(InternalControlVar ICV) const override { 2351 return ICVReplacementValuesMap[ICV]; 2352 } 2353 2354 ChangeStatus updateImpl(Attributor &A) override { 2355 ChangeStatus Changed = ChangeStatus::UNCHANGED; 2356 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2357 *this, IRPosition::returned(*getAssociatedFunction()), 2358 DepClassTy::REQUIRED); 2359 2360 // We don't have any information, so we assume it changes the ICV. 
2361 if (!ICVTrackingAA.isAssumedTracked()) 2362 return indicatePessimisticFixpoint(); 2363 2364 for (InternalControlVar ICV : TrackableICVs) { 2365 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; 2366 Optional<Value *> NewReplVal = 2367 ICVTrackingAA.getUniqueReplacementValue(ICV); 2368 2369 if (ReplVal == NewReplVal) 2370 continue; 2371 2372 ReplVal = NewReplVal; 2373 Changed = ChangeStatus::CHANGED; 2374 } 2375 return Changed; 2376 } 2377 }; 2378 2379 struct AAExecutionDomainFunction : public AAExecutionDomain { 2380 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A) 2381 : AAExecutionDomain(IRP, A) {} 2382 2383 const std::string getAsStr() const override { 2384 return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) + 2385 "/" + std::to_string(NumBBs) + " BBs thread 0 only."; 2386 } 2387 2388 /// See AbstractAttribute::trackStatistics(). 2389 void trackStatistics() const override {} 2390 2391 void initialize(Attributor &A) override { 2392 Function *F = getAnchorScope(); 2393 for (const auto &BB : *F) 2394 SingleThreadedBBs.insert(&BB); 2395 NumBBs = SingleThreadedBBs.size(); 2396 } 2397 2398 ChangeStatus manifest(Attributor &A) override { 2399 LLVM_DEBUG({ 2400 for (const BasicBlock *BB : SingleThreadedBBs) 2401 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " " 2402 << BB->getName() << " is executed by a single thread.\n"; 2403 }); 2404 return ChangeStatus::UNCHANGED; 2405 } 2406 2407 ChangeStatus updateImpl(Attributor &A) override; 2408 2409 /// Check if an instruction is executed by a single thread. 2410 bool isExecutedByInitialThreadOnly(const Instruction &I) const override { 2411 return isExecutedByInitialThreadOnly(*I.getParent()); 2412 } 2413 2414 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override { 2415 return isValidState() && SingleThreadedBBs.contains(&BB); 2416 } 2417 2418 /// Set of basic blocks that are executed by a single thread. 2419 DenseSet<const BasicBlock *> SingleThreadedBBs; 2420 2421 /// Total number of basic blocks in this function. 2422 long unsigned NumBBs; 2423 }; 2424 2425 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { 2426 Function *F = getAnchorScope(); 2427 ReversePostOrderTraversal<Function *> RPOT(F); 2428 auto NumSingleThreadedBBs = SingleThreadedBBs.size(); 2429 2430 bool AllCallSitesKnown; 2431 auto PredForCallSite = [&](AbstractCallSite ACS) { 2432 const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>( 2433 *this, IRPosition::function(*ACS.getInstruction()->getFunction()), 2434 DepClassTy::REQUIRED); 2435 return ACS.isDirectCall() && 2436 ExecutionDomainAA.isExecutedByInitialThreadOnly( 2437 *ACS.getInstruction()); 2438 }; 2439 2440 if (!A.checkForAllCallSites(PredForCallSite, *this, 2441 /* RequiresAllCallSites */ true, 2442 AllCallSitesKnown)) 2443 SingleThreadedBBs.erase(&F->getEntryBlock()); 2444 2445 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2446 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; 2447 2448 // Check if the edge into the successor block compares the __kmpc_target_init 2449 // result with -1. If we are in non-SPMD-mode that signals only the main 2450 // thread will execute the edge. 
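// Illustrative IR pattern we try to match (names are made up):
//   %tid = call i32 @__kmpc_target_init(%struct.ident_t* @ident, i1 false, ...)
//   %is_main = icmp eq i32 %tid, -1
//   br i1 %is_main, label %main.only, label %workers
// Only the first successor (%main.only) is then known to be reached by the
// initial thread alone, and only when the IsSPMD argument is false.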
2451 auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) { 2452 if (!Edge || !Edge->isConditional()) 2453 return false; 2454 if (Edge->getSuccessor(0) != SuccessorBB) 2455 return false; 2456 2457 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition()); 2458 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality()) 2459 return false; 2460 2461 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1)); 2462 if (!C) 2463 return false; 2464 2465 // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!) 2466 if (C->isAllOnesValue()) { 2467 auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0)); 2468 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr; 2469 if (!CB) 2470 return false; 2471 const int InitIsSPMDArgNo = 1; 2472 auto *IsSPMDModeCI = 2473 dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo)); 2474 return IsSPMDModeCI && IsSPMDModeCI->isZero(); 2475 } 2476 2477 return false; 2478 }; 2479 2480 // Merge all the predecessor states into the current basic block. A basic 2481 // block is executed by a single thread if all of its predecessors are. 2482 auto MergePredecessorStates = [&](BasicBlock *BB) { 2483 if (pred_begin(BB) == pred_end(BB)) 2484 return SingleThreadedBBs.contains(BB); 2485 2486 bool IsInitialThread = true; 2487 for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB); 2488 PredBB != PredEndBB; ++PredBB) { 2489 if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()), 2490 BB)) 2491 IsInitialThread &= SingleThreadedBBs.contains(*PredBB); 2492 } 2493 2494 return IsInitialThread; 2495 }; 2496 2497 for (auto *BB : RPOT) { 2498 if (!MergePredecessorStates(BB)) 2499 SingleThreadedBBs.erase(BB); 2500 } 2501 2502 return (NumSingleThreadedBBs == SingleThreadedBBs.size()) 2503 ? ChangeStatus::UNCHANGED 2504 : ChangeStatus::CHANGED; 2505 } 2506 2507 /// Try to replace memory allocation calls called by a single thread with a 2508 /// static buffer of shared memory. 2509 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> { 2510 using Base = StateWrapper<BooleanState, AbstractAttribute>; 2511 AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {} 2512 2513 /// Create an abstract attribute view for the position \p IRP. 2514 static AAHeapToShared &createForPosition(const IRPosition &IRP, 2515 Attributor &A); 2516 2517 /// See AbstractAttribute::getName(). 2518 const std::string getName() const override { return "AAHeapToShared"; } 2519 2520 /// See AbstractAttribute::getIdAddr(). 2521 const char *getIdAddr() const override { return &ID; } 2522 2523 /// This function should return true if the type of the \p AA is 2524 /// AAHeapToShared. 2525 static bool classof(const AbstractAttribute *AA) { 2526 return (AA->getIdAddr() == &ID); 2527 } 2528 2529 /// Unique ID (due to the unique address) 2530 static const char ID; 2531 }; 2532 2533 struct AAHeapToSharedFunction : public AAHeapToShared { 2534 AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A) 2535 : AAHeapToShared(IRP, A) {} 2536 2537 const std::string getAsStr() const override { 2538 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) + 2539 " malloc calls eligible."; 2540 } 2541 2542 /// See AbstractAttribute::trackStatistics(). 
2543 void trackStatistics() const override {} 2544 2545 void initialize(Attributor &A) override { 2546 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2547 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; 2548 2549 for (User *U : RFI.Declaration->users()) 2550 if (CallBase *CB = dyn_cast<CallBase>(U)) 2551 MallocCalls.insert(CB); 2552 } 2553 2554 ChangeStatus manifest(Attributor &A) override { 2555 if (MallocCalls.empty()) 2556 return ChangeStatus::UNCHANGED; 2557 2558 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2559 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared]; 2560 2561 Function *F = getAnchorScope(); 2562 auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this, 2563 DepClassTy::OPTIONAL); 2564 2565 ChangeStatus Changed = ChangeStatus::UNCHANGED; 2566 for (CallBase *CB : MallocCalls) { 2567 // Skip replacing this if HeapToStack has already claimed it. 2568 if (HS && HS->isAssumedHeapToStack(*CB)) 2569 continue; 2570 2571 // Find the unique free call to remove it. 2572 SmallVector<CallBase *, 4> FreeCalls; 2573 for (auto *U : CB->users()) { 2574 CallBase *C = dyn_cast<CallBase>(U); 2575 if (C && C->getCalledFunction() == FreeCall.Declaration) 2576 FreeCalls.push_back(C); 2577 } 2578 if (FreeCalls.size() != 1) 2579 continue; 2580 2581 ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0)); 2582 2583 LLVM_DEBUG(dbgs() << TAG << "Replace globalization call in " 2584 << CB->getCaller()->getName() << " with " 2585 << AllocSize->getZExtValue() 2586 << " bytes of shared memory\n"); 2587 2588 // Create a new shared memory buffer of the same size as the allocation 2589 // and replace all the uses of the original allocation with it. 2590 Module *M = CB->getModule(); 2591 Type *Int8Ty = Type::getInt8Ty(M->getContext()); 2592 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue()); 2593 auto *SharedMem = new GlobalVariable( 2594 *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage, 2595 UndefValue::get(Int8ArrTy), CB->getName(), nullptr, 2596 GlobalValue::NotThreadLocal, 2597 static_cast<unsigned>(AddressSpace::Shared)); 2598 auto *NewBuffer = 2599 ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo()); 2600 2601 auto Remark = [&](OptimizationRemark OR) { 2602 return OR << "Replaced globalized variable with " 2603 << ore::NV("SharedMemory", AllocSize->getZExtValue()) 2604 << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ") 2605 << "of shared memory."; 2606 }; 2607 A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark); 2608 2609 SharedMem->setAlignment(MaybeAlign(32)); 2610 2611 A.changeValueAfterManifest(*CB, *NewBuffer); 2612 A.deleteAfterManifest(*CB); 2613 A.deleteAfterManifest(*FreeCalls.front()); 2614 2615 NumBytesMovedToSharedMemory += AllocSize->getZExtValue(); 2616 Changed = ChangeStatus::CHANGED; 2617 } 2618 2619 return Changed; 2620 } 2621 2622 ChangeStatus updateImpl(Attributor &A) override { 2623 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2624 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; 2625 Function *F = getAnchorScope(); 2626 2627 auto NumMallocCalls = MallocCalls.size(); 2628 2629 // Only consider malloc calls executed by a single thread with a constant. 
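// That is, the allocation size (first argument) must be a ConstantInt and the
// call must be known to execute on the initial thread only, e.g.
// (illustrative): %p = call i8* @__kmpc_alloc_shared(i64 24)
// Calls with runtime-dependent sizes or potentially multi-threaded execution
// are erased from MallocCalls below.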
2630 for (User *U : RFI.Declaration->users()) {
2631 const auto &ED = A.getAAFor<AAExecutionDomain>(
2632 *this, IRPosition::function(*F), DepClassTy::REQUIRED);
2633 if (CallBase *CB = dyn_cast<CallBase>(U))
2634 if (!isa<ConstantInt>(CB->getArgOperand(0)) ||
2635 !ED.isExecutedByInitialThreadOnly(*CB))
2636 MallocCalls.erase(CB);
2637 }
2638
2639 if (NumMallocCalls != MallocCalls.size())
2640 return ChangeStatus::CHANGED;
2641
2642 return ChangeStatus::UNCHANGED;
2643 }
2644
2645 /// Collection of all malloc calls in a function.
2646 SmallPtrSet<CallBase *, 4> MallocCalls;
2647 };
2648
2649 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
2650 using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
2651 AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2652
2653 /// Statistics are tracked as part of manifest for now.
2654 void trackStatistics() const override {}
2655
2656 /// See AbstractAttribute::getAsStr()
2657 const std::string getAsStr() const override {
2658 if (!isValidState())
2659 return "<invalid>";
2660 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
2661 : "generic") +
2662 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
2663 : "") +
2664 std::string(" #PRs: ") +
2665 std::to_string(ReachedKnownParallelRegions.size()) +
2666 ", #Unknown PRs: " +
2667 std::to_string(ReachedUnknownParallelRegions.size());
2668 }
2669
2670 /// Create an abstract attribute view for the position \p IRP.
2671 static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
2672
2673 /// See AbstractAttribute::getName()
2674 const std::string getName() const override { return "AAKernelInfo"; }
2675
2676 /// See AbstractAttribute::getIdAddr()
2677 const char *getIdAddr() const override { return &ID; }
2678
2679 /// This function should return true if the type of the \p AA is AAKernelInfo
2680 static bool classof(const AbstractAttribute *AA) {
2681 return (AA->getIdAddr() == &ID);
2682 }
2683
2684 static const char ID;
2685 };
2686
2687 /// The function kernel info abstract attribute, basically, what can we say
2688 /// about a function with regards to the KernelInfoState.
2689 struct AAKernelInfoFunction : AAKernelInfo {
2690 AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
2691 : AAKernelInfo(IRP, A) {}
2692
2693 /// See AbstractAttribute::initialize(...).
2694 void initialize(Attributor &A) override {
2695 // This is a high-level transform that might change the constant arguments
2696 // of the init and deinit calls. We need to tell the Attributor about this
2697 // to avoid other parts using the current constant value for simplification.
2698 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2699
2700 Function *Fn = getAnchorScope();
2701 if (!OMPInfoCache.Kernels.count(Fn))
2702 return;
2703
2704 // Add itself to the reaching kernel and set IsKernelEntry.
2705 ReachingKernelEntries.insert(Fn);
2706 IsKernelEntry = true;
2707
2708 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
2709 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2710 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
2711 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
2712
2713 // For kernels we perform more initialization work, first we find the init
2714 // and deinit calls.
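// Illustrative kernel entry/exit shape we expect (names are made up; the
// argument positions match the *ArgNo constants used further down):
//   %tid = call i32 @__kmpc_target_init(%struct.ident_t* @ident, i1 %IsSPMD,
//                                       i1 %UseStateMachine, i1 %RequiresFullRT)
//   ... kernel body ...
//   call void @__kmpc_target_deinit(%struct.ident_t* @ident, i1 %IsSPMD,
//                                   i1 %RequiresFullRT)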
2715 auto StoreCallBase = [](Use &U, 2716 OMPInformationCache::RuntimeFunctionInfo &RFI, 2717 CallBase *&Storage) { 2718 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI); 2719 assert(CB && 2720 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!"); 2721 assert(!Storage && 2722 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!"); 2723 Storage = CB; 2724 return false; 2725 }; 2726 InitRFI.foreachUse( 2727 [&](Use &U, Function &) { 2728 StoreCallBase(U, InitRFI, KernelInitCB); 2729 return false; 2730 }, 2731 Fn); 2732 DeinitRFI.foreachUse( 2733 [&](Use &U, Function &) { 2734 StoreCallBase(U, DeinitRFI, KernelDeinitCB); 2735 return false; 2736 }, 2737 Fn); 2738 2739 assert((KernelInitCB && KernelDeinitCB) && 2740 "Kernel without __kmpc_target_init or __kmpc_target_deinit!"); 2741 2742 // For kernels we might need to initialize/finalize the IsSPMD state and 2743 // we need to register a simplification callback so that the Attributor 2744 // knows the constant arguments to __kmpc_target_init and 2745 // __kmpc_target_deinit might actually change. 2746 2747 Attributor::SimplifictionCallbackTy StateMachineSimplifyCB = 2748 [&](const IRPosition &IRP, const AbstractAttribute *AA, 2749 bool &UsedAssumedInformation) -> Optional<Value *> { 2750 // IRP represents the "use generic state machine" argument of an 2751 // __kmpc_target_init call. We will answer this one with the internal 2752 // state. As long as we are not in an invalid state, we will create a 2753 // custom state machine so the value should be a `i1 false`. If we are 2754 // in an invalid state, we won't change the value that is in the IR. 2755 if (!isValidState()) 2756 return nullptr; 2757 if (AA) 2758 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); 2759 UsedAssumedInformation = !isAtFixpoint(); 2760 auto *FalseVal = 2761 ConstantInt::getBool(IRP.getAnchorValue().getContext(), 0); 2762 return FalseVal; 2763 }; 2764 2765 Attributor::SimplifictionCallbackTy IsSPMDModeSimplifyCB = 2766 [&](const IRPosition &IRP, const AbstractAttribute *AA, 2767 bool &UsedAssumedInformation) -> Optional<Value *> { 2768 // IRP represents the "SPMDCompatibilityTracker" argument of an 2769 // __kmpc_target_init or 2770 // __kmpc_target_deinit call. We will answer this one with the internal 2771 // state. 2772 if (!SPMDCompatibilityTracker.isValidState()) 2773 return nullptr; 2774 if (!SPMDCompatibilityTracker.isAtFixpoint()) { 2775 if (AA) 2776 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); 2777 UsedAssumedInformation = true; 2778 } else { 2779 UsedAssumedInformation = false; 2780 } 2781 auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(), 2782 SPMDCompatibilityTracker.isAssumed()); 2783 return Val; 2784 }; 2785 2786 Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB = 2787 [&](const IRPosition &IRP, const AbstractAttribute *AA, 2788 bool &UsedAssumedInformation) -> Optional<Value *> { 2789 // IRP represents the "RequiresFullRuntime" argument of an 2790 // __kmpc_target_init or __kmpc_target_deinit call. We will answer this 2791 // one with the internal state of the SPMDCompatibilityTracker, so if 2792 // generic then true, if SPMD then false. 
2793 if (!SPMDCompatibilityTracker.isValidState()) 2794 return nullptr; 2795 if (!SPMDCompatibilityTracker.isAtFixpoint()) { 2796 if (AA) 2797 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); 2798 UsedAssumedInformation = true; 2799 } else { 2800 UsedAssumedInformation = false; 2801 } 2802 auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(), 2803 !SPMDCompatibilityTracker.isAssumed()); 2804 return Val; 2805 }; 2806 2807 constexpr const int InitIsSPMDArgNo = 1; 2808 constexpr const int DeinitIsSPMDArgNo = 1; 2809 constexpr const int InitUseStateMachineArgNo = 2; 2810 constexpr const int InitRequiresFullRuntimeArgNo = 3; 2811 constexpr const int DeinitRequiresFullRuntimeArgNo = 2; 2812 A.registerSimplificationCallback( 2813 IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo), 2814 StateMachineSimplifyCB); 2815 A.registerSimplificationCallback( 2816 IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo), 2817 IsSPMDModeSimplifyCB); 2818 A.registerSimplificationCallback( 2819 IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo), 2820 IsSPMDModeSimplifyCB); 2821 A.registerSimplificationCallback( 2822 IRPosition::callsite_argument(*KernelInitCB, 2823 InitRequiresFullRuntimeArgNo), 2824 IsGenericModeSimplifyCB); 2825 A.registerSimplificationCallback( 2826 IRPosition::callsite_argument(*KernelDeinitCB, 2827 DeinitRequiresFullRuntimeArgNo), 2828 IsGenericModeSimplifyCB); 2829 2830 // Check if we know we are in SPMD-mode already. 2831 ConstantInt *IsSPMDArg = 2832 dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo)); 2833 if (IsSPMDArg && !IsSPMDArg->isZero()) 2834 SPMDCompatibilityTracker.indicateOptimisticFixpoint(); 2835 } 2836 2837 /// Modify the IR based on the KernelInfoState as the fixpoint iteration is 2838 /// finished now. 2839 ChangeStatus manifest(Attributor &A) override { 2840 // If we are not looking at a kernel with __kmpc_target_init and 2841 // __kmpc_target_deinit call we cannot actually manifest the information. 2842 if (!KernelInitCB || !KernelDeinitCB) 2843 return ChangeStatus::UNCHANGED; 2844 2845 // Known SPMD-mode kernels need no manifest changes. 2846 if (SPMDCompatibilityTracker.isKnown()) 2847 return ChangeStatus::UNCHANGED; 2848 2849 // If we can we change the execution mode to SPMD-mode otherwise we build a 2850 // custom state machine. 2851 if (!changeToSPMDMode(A)) 2852 buildCustomStateMachine(A); 2853 2854 return ChangeStatus::CHANGED; 2855 } 2856 2857 bool changeToSPMDMode(Attributor &A) { 2858 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2859 2860 if (!SPMDCompatibilityTracker.isAssumed()) { 2861 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) { 2862 if (!NonCompatibleI) 2863 continue; 2864 2865 // Skip diagnostics on calls to known OpenMP runtime functions for now. 2866 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI)) 2867 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction())) 2868 continue; 2869 2870 auto Remark = [&](OptimizationRemarkAnalysis ORA) { 2871 ORA << "Value has potential side effects preventing SPMD-mode " 2872 "execution"; 2873 if (isa<CallBase>(NonCompatibleI)) { 2874 ORA << ". 
Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to " 2875 "the called function to override"; 2876 } 2877 return ORA << "."; 2878 }; 2879 A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121", 2880 Remark); 2881 2882 LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: " 2883 << *NonCompatibleI << "\n"); 2884 } 2885 2886 return false; 2887 } 2888 2889 // Adjust the global exec mode flag that tells the runtime what mode this 2890 // kernel is executed in. 2891 Function *Kernel = getAnchorScope(); 2892 GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable( 2893 (Kernel->getName() + "_exec_mode").str()); 2894 assert(ExecMode && "Kernel without exec mode?"); 2895 assert(ExecMode->getInitializer() && 2896 ExecMode->getInitializer()->isOneValue() && 2897 "Initially non-SPMD kernel has SPMD exec mode!"); 2898 2899 // Set the global exec mode flag to indicate SPMD-Generic mode. 2900 constexpr int SPMDGeneric = 2; 2901 if (!ExecMode->getInitializer()->isZeroValue()) 2902 ExecMode->setInitializer( 2903 ConstantInt::get(ExecMode->getInitializer()->getType(), SPMDGeneric)); 2904 2905 // Next rewrite the init and deinit calls to indicate we use SPMD-mode now. 2906 const int InitIsSPMDArgNo = 1; 2907 const int DeinitIsSPMDArgNo = 1; 2908 const int InitUseStateMachineArgNo = 2; 2909 const int InitRequiresFullRuntimeArgNo = 3; 2910 const int DeinitRequiresFullRuntimeArgNo = 2; 2911 2912 auto &Ctx = getAnchorValue().getContext(); 2913 A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo), 2914 *ConstantInt::getBool(Ctx, 1)); 2915 A.changeUseAfterManifest( 2916 KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), 2917 *ConstantInt::getBool(Ctx, 0)); 2918 A.changeUseAfterManifest( 2919 KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo), 2920 *ConstantInt::getBool(Ctx, 1)); 2921 A.changeUseAfterManifest( 2922 KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo), 2923 *ConstantInt::getBool(Ctx, 0)); 2924 A.changeUseAfterManifest( 2925 KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo), 2926 *ConstantInt::getBool(Ctx, 0)); 2927 2928 ++NumOpenMPTargetRegionKernelsSPMD; 2929 2930 auto Remark = [&](OptimizationRemark OR) { 2931 return OR << "Transformed generic-mode kernel to SPMD-mode."; 2932 }; 2933 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark); 2934 return true; 2935 }; 2936 2937 ChangeStatus buildCustomStateMachine(Attributor &A) { 2938 assert(ReachedKnownParallelRegions.isValidState() && 2939 "Custom state machine with invalid parallel region states?"); 2940 2941 const int InitIsSPMDArgNo = 1; 2942 const int InitUseStateMachineArgNo = 2; 2943 2944 // Check if the current configuration is non-SPMD and generic state machine. 2945 // If we already have SPMD mode or a custom state machine we do not need to 2946 // go any further. If it is anything but a constant something is weird and 2947 // we give up. 2948 ConstantInt *UseStateMachine = dyn_cast<ConstantInt>( 2949 KernelInitCB->getArgOperand(InitUseStateMachineArgNo)); 2950 ConstantInt *IsSPMD = 2951 dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo)); 2952 2953 // If we are stuck with generic mode, try to create a custom device (=GPU) 2954 // state machine which is specialized for the parallel regions that are 2955 // reachable by the kernel. 2956 if (!UseStateMachine || UseStateMachine->isZero() || !IsSPMD || 2957 !IsSPMD->isZero()) 2958 return ChangeStatus::UNCHANGED; 2959 2960 // If not SPMD mode, indicate we use a custom state machine now. 
2961 auto &Ctx = getAnchorValue().getContext(); 2962 auto *FalseVal = ConstantInt::getBool(Ctx, 0); 2963 A.changeUseAfterManifest( 2964 KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal); 2965 2966 // If we don't actually need a state machine we are done here. This can 2967 // happen if there simply are no parallel regions. In the resulting kernel 2968 // all worker threads will simply exit right away, leaving the main thread 2969 // to do the work alone. 2970 if (ReachedKnownParallelRegions.empty() && 2971 ReachedUnknownParallelRegions.empty()) { 2972 ++NumOpenMPTargetRegionKernelsWithoutStateMachine; 2973 2974 auto Remark = [&](OptimizationRemark OR) { 2975 return OR << "Removing unused state machine from generic-mode kernel."; 2976 }; 2977 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark); 2978 2979 return ChangeStatus::CHANGED; 2980 } 2981 2982 // Keep track in the statistics of our new shiny custom state machine. 2983 if (ReachedUnknownParallelRegions.empty()) { 2984 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback; 2985 2986 auto Remark = [&](OptimizationRemark OR) { 2987 return OR << "Rewriting generic-mode kernel with a customized state " 2988 "machine."; 2989 }; 2990 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark); 2991 } else { 2992 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback; 2993 2994 auto Remark = [&](OptimizationRemarkAnalysis OR) { 2995 return OR << "Generic-mode kernel is executed with a customized state " 2996 "machine that requires a fallback."; 2997 }; 2998 A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark); 2999 3000 // Tell the user why we ended up with a fallback. 3001 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) { 3002 if (!UnknownParallelRegionCB) 3003 continue; 3004 auto Remark = [&](OptimizationRemarkAnalysis ORA) { 3005 return ORA << "Call may contain unknown parallel regions. Use " 3006 << "`__attribute__((assume(\"omp_no_parallelism\")))` to " 3007 "override."; 3008 }; 3009 A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB, 3010 "OMP133", Remark); 3011 } 3012 } 3013 3014 // Create all the blocks: 3015 // 3016 // InitCB = __kmpc_target_init(...) 3017 // bool IsWorker = InitCB >= 0; 3018 // if (IsWorker) { 3019 // SMBeginBB: __kmpc_barrier_simple_spmd(...); 3020 // void *WorkFn; 3021 // bool Active = __kmpc_kernel_parallel(&WorkFn); 3022 // if (!WorkFn) return; 3023 // SMIsActiveCheckBB: if (Active) { 3024 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>) 3025 // ParFn0(...); 3026 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>) 3027 // ParFn1(...); 3028 // ... 3029 // SMIfCascadeCurrentBB: else 3030 // ((WorkFnTy*)WorkFn)(...); 3031 // SMEndParallelBB: __kmpc_kernel_end_parallel(...); 3032 // } 3033 // SMDoneBB: __kmpc_barrier_simple_spmd(...); 3034 // goto SMBeginBB; 3035 // } 3036 // UserCodeEntryBB: // user code 3037 // __kmpc_target_deinit(...) 
3038 // 3039 Function *Kernel = getAssociatedFunction(); 3040 assert(Kernel && "Expected an associated function!"); 3041 3042 BasicBlock *InitBB = KernelInitCB->getParent(); 3043 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock( 3044 KernelInitCB->getNextNode(), "thread.user_code.check"); 3045 BasicBlock *StateMachineBeginBB = BasicBlock::Create( 3046 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB); 3047 BasicBlock *StateMachineFinishedBB = BasicBlock::Create( 3048 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB); 3049 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create( 3050 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB); 3051 BasicBlock *StateMachineIfCascadeCurrentBB = 3052 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check", 3053 Kernel, UserCodeEntryBB); 3054 BasicBlock *StateMachineEndParallelBB = 3055 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end", 3056 Kernel, UserCodeEntryBB); 3057 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create( 3058 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB); 3059 A.registerManifestAddedBasicBlock(*InitBB); 3060 A.registerManifestAddedBasicBlock(*UserCodeEntryBB); 3061 A.registerManifestAddedBasicBlock(*StateMachineBeginBB); 3062 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB); 3063 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB); 3064 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB); 3065 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB); 3066 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB); 3067 3068 const DebugLoc &DLoc = KernelInitCB->getDebugLoc(); 3069 ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc); 3070 3071 InitBB->getTerminator()->eraseFromParent(); 3072 Instruction *IsWorker = 3073 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB, 3074 ConstantInt::get(KernelInitCB->getType(), -1), 3075 "thread.is_worker", InitBB); 3076 IsWorker->setDebugLoc(DLoc); 3077 BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB); 3078 3079 // Create local storage for the work function pointer. 
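// Rough sketch of the worker protocol implemented below (assumed semantics):
// each worker passes the address of this slot to __kmpc_kernel_parallel,
// which returns whether the worker is active and stores the work function
// pointer into the slot; a null pointer signals termination of the worker
// loop, an inactive worker simply waits at the next barrier.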
3080 Type *VoidPtrTy = Type::getInt8PtrTy(Ctx); 3081 AllocaInst *WorkFnAI = new AllocaInst(VoidPtrTy, 0, "worker.work_fn.addr", 3082 &Kernel->getEntryBlock().front()); 3083 WorkFnAI->setDebugLoc(DLoc); 3084 3085 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 3086 OMPInfoCache.OMPBuilder.updateToLocation( 3087 OpenMPIRBuilder::LocationDescription( 3088 IRBuilder<>::InsertPoint(StateMachineBeginBB, 3089 StateMachineBeginBB->end()), 3090 DLoc)); 3091 3092 Value *Ident = KernelInitCB->getArgOperand(0); 3093 Value *GTid = KernelInitCB; 3094 3095 Module &M = *Kernel->getParent(); 3096 FunctionCallee BarrierFn = 3097 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( 3098 M, OMPRTL___kmpc_barrier_simple_spmd); 3099 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB) 3100 ->setDebugLoc(DLoc); 3101 3102 FunctionCallee KernelParallelFn = 3103 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( 3104 M, OMPRTL___kmpc_kernel_parallel); 3105 Instruction *IsActiveWorker = CallInst::Create( 3106 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB); 3107 IsActiveWorker->setDebugLoc(DLoc); 3108 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn", 3109 StateMachineBeginBB); 3110 WorkFn->setDebugLoc(DLoc); 3111 3112 FunctionType *ParallelRegionFnTy = FunctionType::get( 3113 Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)}, 3114 false); 3115 Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast( 3116 WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast", 3117 StateMachineBeginBB); 3118 3119 Instruction *IsDone = 3120 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, 3121 Constant::getNullValue(VoidPtrTy), "worker.is_done", 3122 StateMachineBeginBB); 3123 IsDone->setDebugLoc(DLoc); 3124 BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB, 3125 IsDone, StateMachineBeginBB) 3126 ->setDebugLoc(DLoc); 3127 3128 BranchInst::Create(StateMachineIfCascadeCurrentBB, 3129 StateMachineDoneBarrierBB, IsActiveWorker, 3130 StateMachineIsActiveCheckBB) 3131 ->setDebugLoc(DLoc); 3132 3133 Value *ZeroArg = 3134 Constant::getNullValue(ParallelRegionFnTy->getParamType(0)); 3135 3136 // Now that we have most of the CFG skeleton it is time for the if-cascade 3137 // that checks the function pointer we got from the runtime against the 3138 // parallel regions we expect, if there are any. 3139 for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) { 3140 auto *ParallelRegion = ReachedKnownParallelRegions[i]; 3141 BasicBlock *PRExecuteBB = BasicBlock::Create( 3142 Ctx, "worker_state_machine.parallel_region.execute", Kernel, 3143 StateMachineEndParallelBB); 3144 CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB) 3145 ->setDebugLoc(DLoc); 3146 BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB) 3147 ->setDebugLoc(DLoc); 3148 3149 BasicBlock *PRNextBB = 3150 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check", 3151 Kernel, StateMachineEndParallelBB); 3152 3153 // Check if we need to compare the pointer at all or if we can just 3154 // call the parallel region function. 
3155 Value *IsPR; 3156 if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) { 3157 Instruction *CmpI = ICmpInst::Create( 3158 ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion, 3159 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB); 3160 CmpI->setDebugLoc(DLoc); 3161 IsPR = CmpI; 3162 } else { 3163 IsPR = ConstantInt::getTrue(Ctx); 3164 } 3165 3166 BranchInst::Create(PRExecuteBB, PRNextBB, IsPR, 3167 StateMachineIfCascadeCurrentBB) 3168 ->setDebugLoc(DLoc); 3169 StateMachineIfCascadeCurrentBB = PRNextBB; 3170 } 3171 3172 // At the end of the if-cascade we place the indirect function pointer call 3173 // in case we might need it, that is if there can be parallel regions we 3174 // have not handled in the if-cascade above. 3175 if (!ReachedUnknownParallelRegions.empty()) { 3176 StateMachineIfCascadeCurrentBB->setName( 3177 "worker_state_machine.parallel_region.fallback.execute"); 3178 CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "", 3179 StateMachineIfCascadeCurrentBB) 3180 ->setDebugLoc(DLoc); 3181 } 3182 BranchInst::Create(StateMachineEndParallelBB, 3183 StateMachineIfCascadeCurrentBB) 3184 ->setDebugLoc(DLoc); 3185 3186 CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( 3187 M, OMPRTL___kmpc_kernel_end_parallel), 3188 {}, "", StateMachineEndParallelBB) 3189 ->setDebugLoc(DLoc); 3190 BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB) 3191 ->setDebugLoc(DLoc); 3192 3193 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB) 3194 ->setDebugLoc(DLoc); 3195 BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB) 3196 ->setDebugLoc(DLoc); 3197 3198 return ChangeStatus::CHANGED; 3199 } 3200 3201 /// Fixpoint iteration update function. Will be called every time a dependence 3202 /// changed its state (and in the beginning). 3203 ChangeStatus updateImpl(Attributor &A) override { 3204 KernelInfoState StateBefore = getState(); 3205 3206 // Callback to check a read/write instruction. 3207 auto CheckRWInst = [&](Instruction &I) { 3208 // We handle calls later. 3209 if (isa<CallBase>(I)) 3210 return true; 3211 // We only care about write effects. 3212 if (!I.mayWriteToMemory()) 3213 return true; 3214 if (auto *SI = dyn_cast<StoreInst>(&I)) { 3215 SmallVector<const Value *> Objects; 3216 getUnderlyingObjects(SI->getPointerOperand(), Objects); 3217 if (llvm::all_of(Objects, 3218 [](const Value *Obj) { return isa<AllocaInst>(Obj); })) 3219 return true; 3220 } 3221 // For now we give up on everything but stores. 3222 SPMDCompatibilityTracker.insert(&I); 3223 return true; 3224 }; 3225 3226 bool UsedAssumedInformationInCheckRWInst = false; 3227 if (!SPMDCompatibilityTracker.isAtFixpoint()) 3228 if (!A.checkForAllReadWriteInstructions( 3229 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst)) 3230 SPMDCompatibilityTracker.indicatePessimisticFixpoint(); 3231 3232 if (!IsKernelEntry) 3233 updateReachingKernelEntries(A); 3234 3235 // Callback to check a call instruction. 
3236 bool AllSPMDStatesWereFixed = true; 3237 auto CheckCallInst = [&](Instruction &I) { 3238 auto &CB = cast<CallBase>(I); 3239 auto &CBAA = A.getAAFor<AAKernelInfo>( 3240 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); 3241 getState() ^= CBAA.getState(); 3242 AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint(); 3243 return true; 3244 }; 3245 3246 bool UsedAssumedInformationInCheckCallInst = false; 3247 if (!A.checkForAllCallLikeInstructions( 3248 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) 3249 return indicatePessimisticFixpoint(); 3250 3251 // If we haven't used any assumed information for the SPMD state we can fix 3252 // it. 3253 if (!UsedAssumedInformationInCheckRWInst && 3254 !UsedAssumedInformationInCheckCallInst && AllSPMDStatesWereFixed) 3255 SPMDCompatibilityTracker.indicateOptimisticFixpoint(); 3256 3257 return StateBefore == getState() ? ChangeStatus::UNCHANGED 3258 : ChangeStatus::CHANGED; 3259 } 3260 3261 private: 3262 /// Update info regarding reaching kernels. 3263 void updateReachingKernelEntries(Attributor &A) { 3264 auto PredCallSite = [&](AbstractCallSite ACS) { 3265 Function *Caller = ACS.getInstruction()->getFunction(); 3266 3267 assert(Caller && "Caller is nullptr"); 3268 3269 auto &CAA = A.getOrCreateAAFor<AAKernelInfo>( 3270 IRPosition::function(*Caller), this, DepClassTy::REQUIRED); 3271 if (CAA.ReachingKernelEntries.isValidState()) { 3272 ReachingKernelEntries ^= CAA.ReachingKernelEntries; 3273 return true; 3274 } 3275 3276 // We lost track of the caller of the associated function, any kernel 3277 // could reach now. 3278 ReachingKernelEntries.indicatePessimisticFixpoint(); 3279 3280 return true; 3281 }; 3282 3283 bool AllCallSitesKnown; 3284 if (!A.checkForAllCallSites(PredCallSite, *this, 3285 true /* RequireAllCallSites */, 3286 AllCallSitesKnown)) 3287 ReachingKernelEntries.indicatePessimisticFixpoint(); 3288 } 3289 }; 3290 3291 /// The call site kernel info abstract attribute, basically, what can we say 3292 /// about a call site with regards to the KernelInfoState. For now this simply 3293 /// forwards the information from the callee. 3294 struct AAKernelInfoCallSite : AAKernelInfo { 3295 AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A) 3296 : AAKernelInfo(IRP, A) {} 3297 3298 /// See AbstractAttribute::initialize(...). 3299 void initialize(Attributor &A) override { 3300 AAKernelInfo::initialize(A); 3301 3302 CallBase &CB = cast<CallBase>(getAssociatedValue()); 3303 Function *Callee = getAssociatedFunction(); 3304 3305 // Helper to lookup an assumption string. 3306 auto HasAssumption = [](Function *Fn, StringRef AssumptionStr) { 3307 return Fn && hasAssumption(*Fn, AssumptionStr); 3308 }; 3309 3310 // Check for SPMD-mode assumptions. 3311 if (HasAssumption(Callee, "ompx_spmd_amenable")) 3312 SPMDCompatibilityTracker.indicateOptimisticFixpoint(); 3313 3314 // First weed out calls we do not care about, that is readonly/readnone 3315 // calls, intrinsics, and "no_openmp" calls. Neither of these can reach a 3316 // parallel region or anything else we are looking for. 3317 if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) { 3318 indicateOptimisticFixpoint(); 3319 return; 3320 } 3321 3322 // Next we check if we know the callee. If it is a known OpenMP function 3323 // we will handle them explicitly in the switch below. If it is not, we 3324 // will use an AAKernelInfo object on the callee to gather information and 3325 // merge that into the current state. The latter happens in the updateImpl. 
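// For reference, a sketch of the device runtime entry point matched in the
// switch below (the exact prototype lives in OMPKinds.def and may differ):
//   void __kmpc_parallel_51(ident_t *loc, i32 gtid, i32 if_expr,
//                           i32 num_threads, i32 proc_bind,
//                           void *outlined_fn, void *wrapper_fn,
//                           void **args, i64 nargs);
// Argument 6 is the wrapper function, which is why WrapperFunctionArgNo below
// is 6; that wrapper is what ends up in ReachedKnownParallelRegions and is
// later compared against the work function in the custom state machine.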
3326 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 3327 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee); 3328 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { 3329 // Unknown caller or declarations are not analyzable, we give up. 3330 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) { 3331 3332 // Unknown callees might contain parallel regions, except if they have 3333 // an appropriate assumption attached. 3334 if (!(HasAssumption(Callee, "omp_no_openmp") || 3335 HasAssumption(Callee, "omp_no_parallelism"))) 3336 ReachedUnknownParallelRegions.insert(&CB); 3337 3338 // If SPMDCompatibilityTracker is not fixed, we need to give up on the 3339 // idea we can run something unknown in SPMD-mode. 3340 if (!SPMDCompatibilityTracker.isAtFixpoint()) 3341 SPMDCompatibilityTracker.insert(&CB); 3342 3343 // We have updated the state for this unknown call properly, there won't 3344 // be any change so we indicate a fixpoint. 3345 indicateOptimisticFixpoint(); 3346 } 3347 // If the callee is known and can be used in IPO, we will update the state 3348 // based on the callee state in updateImpl. 3349 return; 3350 } 3351 3352 const unsigned int WrapperFunctionArgNo = 6; 3353 RuntimeFunction RF = It->getSecond(); 3354 switch (RF) { 3355 // All the functions we know are compatible with SPMD mode. 3356 case OMPRTL___kmpc_is_spmd_exec_mode: 3357 case OMPRTL___kmpc_for_static_fini: 3358 case OMPRTL___kmpc_global_thread_num: 3359 case OMPRTL___kmpc_single: 3360 case OMPRTL___kmpc_end_single: 3361 case OMPRTL___kmpc_master: 3362 case OMPRTL___kmpc_end_master: 3363 case OMPRTL___kmpc_barrier: 3364 break; 3365 case OMPRTL___kmpc_for_static_init_4: 3366 case OMPRTL___kmpc_for_static_init_4u: 3367 case OMPRTL___kmpc_for_static_init_8: 3368 case OMPRTL___kmpc_for_static_init_8u: { 3369 // Check the schedule and allow static schedule in SPMD mode. 3370 unsigned ScheduleArgOpNo = 2; 3371 auto *ScheduleTypeCI = 3372 dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo)); 3373 unsigned ScheduleTypeVal = 3374 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0; 3375 switch (OMPScheduleType(ScheduleTypeVal)) { 3376 case OMPScheduleType::Static: 3377 case OMPScheduleType::StaticChunked: 3378 case OMPScheduleType::Distribute: 3379 case OMPScheduleType::DistributeChunked: 3380 break; 3381 default: 3382 SPMDCompatibilityTracker.insert(&CB); 3383 break; 3384 }; 3385 } break; 3386 case OMPRTL___kmpc_target_init: 3387 KernelInitCB = &CB; 3388 break; 3389 case OMPRTL___kmpc_target_deinit: 3390 KernelDeinitCB = &CB; 3391 break; 3392 case OMPRTL___kmpc_parallel_51: 3393 if (auto *ParallelRegion = dyn_cast<Function>( 3394 CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) { 3395 ReachedKnownParallelRegions.insert(ParallelRegion); 3396 break; 3397 } 3398 // The condition above should usually get the parallel region function 3399 // pointer and record it. In the off chance it doesn't we assume the 3400 // worst. 3401 ReachedUnknownParallelRegions.insert(&CB); 3402 break; 3403 case OMPRTL___kmpc_omp_task: 3404 // We do not look into tasks right now, just give up. 3405 SPMDCompatibilityTracker.insert(&CB); 3406 ReachedUnknownParallelRegions.insert(&CB); 3407 break; 3408 default: 3409 // Unknown OpenMP runtime calls cannot be executed in SPMD-mode, 3410 // generally. 3411 SPMDCompatibilityTracker.insert(&CB); 3412 break; 3413 } 3414 // All other OpenMP runtime calls will not reach parallel regions so they 3415 // can be safely ignored for now. 
Since it is a known OpenMP runtime call we 3416 // have now modeled all effects and there is no need for any update. 3417 indicateOptimisticFixpoint(); 3418 } 3419 3420 ChangeStatus updateImpl(Attributor &A) override { 3421 // TODO: Once we have call site specific value information we can provide 3422 // call site specific liveness information and then it makes 3423 // sense to specialize attributes for call site arguments instead of 3424 // redirecting requests to the callee argument. 3425 Function *F = getAssociatedFunction(); 3426 const IRPosition &FnPos = IRPosition::function(*F); 3427 auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED); 3428 if (getState() == FnAA.getState()) 3429 return ChangeStatus::UNCHANGED; 3430 getState() = FnAA.getState(); 3431 return ChangeStatus::CHANGED; 3432 } 3433 }; 3434 3435 struct AAFoldRuntimeCall 3436 : public StateWrapper<BooleanState, AbstractAttribute> { 3437 using Base = StateWrapper<BooleanState, AbstractAttribute>; 3438 3439 AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {} 3440 3441 /// Statistics are tracked as part of manifest for now. 3442 void trackStatistics() const override {} 3443 3444 /// Create an abstract attribute view for the position \p IRP. 3445 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP, 3446 Attributor &A); 3447 3448 /// See AbstractAttribute::getName() 3449 const std::string getName() const override { return "AAFoldRuntimeCall"; } 3450 3451 /// See AbstractAttribute::getIdAddr() 3452 const char *getIdAddr() const override { return &ID; } 3453 3454 /// This function should return true if the type of the \p AA is 3455 /// AAFoldRuntimeCall 3456 static bool classof(const AbstractAttribute *AA) { 3457 return (AA->getIdAddr() == &ID); 3458 } 3459 3460 static const char ID; 3461 }; 3462 3463 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall { 3464 AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A) 3465 : AAFoldRuntimeCall(IRP, A) {} 3466 3467 /// See AbstractAttribute::getAsStr() 3468 const std::string getAsStr() const override { 3469 if (!isValidState()) 3470 return "<invalid>"; 3471 3472 std::string Str("simplified value: "); 3473 3474 if (!SimplifiedValue.hasValue()) 3475 return Str + std::string("none"); 3476 3477 if (!SimplifiedValue.getValue()) 3478 return Str + std::string("nullptr"); 3479 3480 if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue())) 3481 return Str + std::to_string(CI->getSExtValue()); 3482 3483 return Str + std::string("unknown"); 3484 } 3485 3486 void initialize(Attributor &A) override { 3487 Function *Callee = getAssociatedFunction(); 3488 3489 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 3490 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee); 3491 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() && 3492 "Expected a known OpenMP runtime function"); 3493 3494 RFKind = It->getSecond(); 3495 3496 CallBase &CB = cast<CallBase>(getAssociatedValue()); 3497 A.registerSimplificationCallback( 3498 IRPosition::callsite_returned(CB), 3499 [&](const IRPosition &IRP, const AbstractAttribute *AA, 3500 bool &UsedAssumedInformation) -> Optional<Value *> { 3501 assert((isValidState() || (SimplifiedValue.hasValue() && 3502 SimplifiedValue.getValue() == nullptr)) && 3503 "Unexpected invalid state!"); 3504 3505 if (!isAtFixpoint()) { 3506 UsedAssumedInformation = true; 3507 if (AA) 3508 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); 3509 } 3510 return
SimplifiedValue; 3511 }); 3512 } 3513 3514 ChangeStatus updateImpl(Attributor &A) override { 3515 ChangeStatus Changed = ChangeStatus::UNCHANGED; 3516 3517 switch (RFKind) { 3518 case OMPRTL___kmpc_is_spmd_exec_mode: 3519 Changed |= foldIsSPMDExecMode(A); 3520 break; 3521 case OMPRTL___kmpc_is_generic_main_thread_id: 3522 Changed |= foldIsGenericMainThread(A); 3523 break; 3524 default: 3525 llvm_unreachable("Unhandled OpenMP runtime function!"); 3526 } 3527 3528 return Changed; 3529 } 3530 3531 ChangeStatus manifest(Attributor &A) override { 3532 ChangeStatus Changed = ChangeStatus::UNCHANGED; 3533 3534 if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) { 3535 Instruction &CB = *getCtxI(); 3536 A.changeValueAfterManifest(CB, **SimplifiedValue); 3537 A.deleteAfterManifest(CB); 3538 3539 LLVM_DEBUG(dbgs() << TAG << "Folding runtime call: " << CB << " with " 3540 << **SimplifiedValue << "\n"); 3541 3542 Changed = ChangeStatus::CHANGED; 3543 } 3544 3545 return Changed; 3546 } 3547 3548 ChangeStatus indicatePessimisticFixpoint() override { 3549 SimplifiedValue = nullptr; 3550 return AAFoldRuntimeCall::indicatePessimisticFixpoint(); 3551 } 3552 3553 private: 3554 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible. 3555 ChangeStatus foldIsSPMDExecMode(Attributor &A) { 3556 Optional<Value *> SimplifiedValueBefore = SimplifiedValue; 3557 3558 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0; 3559 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0; 3560 auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( 3561 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 3562 3563 if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState()) 3564 return indicatePessimisticFixpoint(); 3565 3566 for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) { 3567 auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K), 3568 DepClassTy::REQUIRED); 3569 3570 if (!AA.isValidState()) { 3571 SimplifiedValue = nullptr; 3572 return indicatePessimisticFixpoint(); 3573 } 3574 3575 if (AA.SPMDCompatibilityTracker.isAssumed()) { 3576 if (AA.SPMDCompatibilityTracker.isAtFixpoint()) 3577 ++KnownSPMDCount; 3578 else 3579 ++AssumedSPMDCount; 3580 } else { 3581 if (AA.SPMDCompatibilityTracker.isAtFixpoint()) 3582 ++KnownNonSPMDCount; 3583 else 3584 ++AssumedNonSPMDCount; 3585 } 3586 } 3587 3588 if (KnownSPMDCount && KnownNonSPMDCount) 3589 return indicatePessimisticFixpoint(); 3590 3591 if (AssumedSPMDCount && AssumedNonSPMDCount) 3592 return indicatePessimisticFixpoint(); 3593 3594 auto &Ctx = getAnchorValue().getContext(); 3595 if (KnownSPMDCount || AssumedSPMDCount) { 3596 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 && 3597 "Expected only SPMD kernels!"); 3598 // All reaching kernels are in SPMD mode. Update all function calls to 3599 // __kmpc_is_spmd_exec_mode to 1. 3600 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true); 3601 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) { 3602 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 && 3603 "Expected only non-SPMD kernels!"); 3604 // All reaching kernels are in non-SPMD mode. Update all function 3605 // calls to __kmpc_is_spmd_exec_mode to 0. 3606 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false); 3607 } else { 3608 // We have empty reaching kernels, therefore we cannot tell if the 3609 // associated call site can be folded. At this moment, SimplifiedValue 3610 // must be none. 
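// Summary of the folding decision in this function, for reference:
//   all reaching kernels SPMD (known or assumed)     -> SimplifiedValue = i8 1
//   all reaching kernels non-SPMD (known or assumed) -> SimplifiedValue = i8 0
//   known SPMD mixed with known non-SPMD, or
//   assumed SPMD mixed with assumed non-SPMD         -> pessimistic fixpoint
//   no reaching kernels recorded (this branch)       -> SimplifiedValue stays none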
3611 assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none"); 3612 } 3613 3614 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED 3615 : ChangeStatus::CHANGED; 3616 } 3617 3618 /// Fold __kmpc_is_generic_main_thread_id into a constant if possible. 3619 ChangeStatus foldIsGenericMainThread(Attributor &A) { 3620 Optional<Value *> SimplifiedValueBefore = SimplifiedValue; 3621 3622 CallBase &CB = cast<CallBase>(getAssociatedValue()); 3623 Function *F = CB.getFunction(); 3624 const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>( 3625 *this, IRPosition::function(*F), DepClassTy::REQUIRED); 3626 3627 if (!ExecutionDomainAA.isValidState()) 3628 return indicatePessimisticFixpoint(); 3629 3630 auto &Ctx = getAnchorValue().getContext(); 3631 if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB)) 3632 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true); 3633 else 3634 return indicatePessimisticFixpoint(); 3635 3636 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED 3637 : ChangeStatus::CHANGED; 3638 } 3639 3640 /// An optional value the associated value is assumed to fold to. That is, we 3641 /// assume the associated value (which is a call) can be replaced by this 3642 /// simplified value. 3643 Optional<Value *> SimplifiedValue; 3644 3645 /// The runtime function kind of the callee of the associated call site. 3646 RuntimeFunction RFKind; 3647 }; 3648 3649 } // namespace 3650 3651 void OpenMPOpt::registerAAs(bool IsModulePass) { 3652 if (SCC.empty()) 3653 3654 return; 3655 if (IsModulePass) { 3656 // Ensure we create the AAKernelInfo AAs first and without triggering an 3657 // update. This will make sure we register all value simplification 3658 // callbacks before any other AA has the chance to create an AAValueSimplify 3659 // or similar. 3660 for (Function *Kernel : OMPInfoCache.Kernels) 3661 A.getOrCreateAAFor<AAKernelInfo>( 3662 IRPosition::function(*Kernel), /* QueryingAA */ nullptr, 3663 DepClassTy::NONE, /* ForceUpdate */ false, 3664 /* UpdateAfterInit */ false); 3665 3666 auto &IsMainRFI = 3667 OMPInfoCache.RFIs[OMPRTL___kmpc_is_generic_main_thread_id]; 3668 IsMainRFI.foreachUse(SCC, [&](Use &U, Function &F) { 3669 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &IsMainRFI); 3670 if (!CI) 3671 return false; 3672 A.getOrCreateAAFor<AAFoldRuntimeCall>( 3673 IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr, 3674 DepClassTy::NONE, /* ForceUpdate */ false, 3675 /* UpdateAfterInit */ false); 3676 return false; 3677 }); 3678 3679 auto &IsSPMDRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_is_spmd_exec_mode]; 3680 IsSPMDRFI.foreachUse(SCC, [&](Use &U, Function &) { 3681 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &IsSPMDRFI); 3682 if (!CI) 3683 return false; 3684 A.getOrCreateAAFor<AAFoldRuntimeCall>( 3685 IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr, 3686 DepClassTy::NONE, /* ForceUpdate */ false, 3687 /* UpdateAfterInit */ false); 3688 return false; 3689 }); 3690 } 3691 3692 // Create CallSite AA for all Getters. 
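// For illustration only: for the "nthreads" ICV the registered getter is
// omp_get_max_threads, so a call site like
//   %n = call i32 @omp_get_max_threads()
// gets an AAICVTracker seeded at its call-site position by the loop below;
// that tracker is what later enables duplicate getter calls to be folded and
// values from the corresponding setter (omp_set_num_threads) to be tracked.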
3693 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) { 3694 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)]; 3695 3696 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter]; 3697 3698 auto CreateAA = [&](Use &U, Function &Caller) { 3699 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI); 3700 if (!CI) 3701 return false; 3702 3703 auto &CB = cast<CallBase>(*CI); 3704 3705 IRPosition CBPos = IRPosition::callsite_function(CB); 3706 A.getOrCreateAAFor<AAICVTracker>(CBPos); 3707 return false; 3708 }; 3709 3710 GetterRFI.foreachUse(SCC, CreateAA); 3711 } 3712 auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; 3713 auto CreateAA = [&](Use &U, Function &F) { 3714 A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F)); 3715 return false; 3716 }; 3717 GlobalizationRFI.foreachUse(SCC, CreateAA); 3718 3719 // Create an ExecutionDomain AA for every function and a HeapToStack AA for 3720 // every function if there is a device kernel. 3721 for (auto *F : SCC) { 3722 if (!F->isDeclaration()) 3723 A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F)); 3724 if (isOpenMPDevice(M)) 3725 A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F)); 3726 } 3727 } 3728 3729 const char AAICVTracker::ID = 0; 3730 const char AAKernelInfo::ID = 0; 3731 const char AAExecutionDomain::ID = 0; 3732 const char AAHeapToShared::ID = 0; 3733 const char AAFoldRuntimeCall::ID = 0; 3734 3735 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP, 3736 Attributor &A) { 3737 AAICVTracker *AA = nullptr; 3738 switch (IRP.getPositionKind()) { 3739 case IRPosition::IRP_INVALID: 3740 case IRPosition::IRP_FLOAT: 3741 case IRPosition::IRP_ARGUMENT: 3742 case IRPosition::IRP_CALL_SITE_ARGUMENT: 3743 llvm_unreachable("ICVTracker can only be created for function position!"); 3744 case IRPosition::IRP_RETURNED: 3745 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A); 3746 break; 3747 case IRPosition::IRP_CALL_SITE_RETURNED: 3748 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A); 3749 break; 3750 case IRPosition::IRP_CALL_SITE: 3751 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A); 3752 break; 3753 case IRPosition::IRP_FUNCTION: 3754 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A); 3755 break; 3756 } 3757 3758 return *AA; 3759 } 3760 3761 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP, 3762 Attributor &A) { 3763 AAExecutionDomainFunction *AA = nullptr; 3764 switch (IRP.getPositionKind()) { 3765 case IRPosition::IRP_INVALID: 3766 case IRPosition::IRP_FLOAT: 3767 case IRPosition::IRP_ARGUMENT: 3768 case IRPosition::IRP_CALL_SITE_ARGUMENT: 3769 case IRPosition::IRP_RETURNED: 3770 case IRPosition::IRP_CALL_SITE_RETURNED: 3771 case IRPosition::IRP_CALL_SITE: 3772 llvm_unreachable( 3773 "AAExecutionDomain can only be created for function position!"); 3774 case IRPosition::IRP_FUNCTION: 3775 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A); 3776 break; 3777 } 3778 3779 return *AA; 3780 } 3781 3782 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP, 3783 Attributor &A) { 3784 AAHeapToSharedFunction *AA = nullptr; 3785 switch (IRP.getPositionKind()) { 3786 case IRPosition::IRP_INVALID: 3787 case IRPosition::IRP_FLOAT: 3788 case IRPosition::IRP_ARGUMENT: 3789 case IRPosition::IRP_CALL_SITE_ARGUMENT: 3790 case IRPosition::IRP_RETURNED: 3791 case IRPosition::IRP_CALL_SITE_RETURNED: 3792 case IRPosition::IRP_CALL_SITE: 3793 llvm_unreachable( 3794 "AAHeapToShared can only be 
created for function position!"); 3795 case IRPosition::IRP_FUNCTION: 3796 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A); 3797 break; 3798 } 3799 3800 return *AA; 3801 } 3802 3803 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP, 3804 Attributor &A) { 3805 AAKernelInfo *AA = nullptr; 3806 switch (IRP.getPositionKind()) { 3807 case IRPosition::IRP_INVALID: 3808 case IRPosition::IRP_FLOAT: 3809 case IRPosition::IRP_ARGUMENT: 3810 case IRPosition::IRP_RETURNED: 3811 case IRPosition::IRP_CALL_SITE_RETURNED: 3812 case IRPosition::IRP_CALL_SITE_ARGUMENT: 3813 llvm_unreachable("KernelInfo can only be created for function position!"); 3814 case IRPosition::IRP_CALL_SITE: 3815 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A); 3816 break; 3817 case IRPosition::IRP_FUNCTION: 3818 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A); 3819 break; 3820 } 3821 3822 return *AA; 3823 } 3824 3825 AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP, 3826 Attributor &A) { 3827 AAFoldRuntimeCall *AA = nullptr; 3828 switch (IRP.getPositionKind()) { 3829 case IRPosition::IRP_INVALID: 3830 case IRPosition::IRP_FLOAT: 3831 case IRPosition::IRP_ARGUMENT: 3832 case IRPosition::IRP_RETURNED: 3833 case IRPosition::IRP_FUNCTION: 3834 case IRPosition::IRP_CALL_SITE: 3835 case IRPosition::IRP_CALL_SITE_ARGUMENT: 3836 llvm_unreachable("AAFoldRuntimeCall can only be created for call site position!"); 3837 case IRPosition::IRP_CALL_SITE_RETURNED: 3838 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A); 3839 break; 3840 } 3841 3842 return *AA; 3843 } 3844 3845 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { 3846 if (!containsOpenMP(M)) 3847 return PreservedAnalyses::all(); 3848 if (DisableOpenMPOptimizations) 3849 return PreservedAnalyses::all(); 3850 3851 FunctionAnalysisManager &FAM = 3852 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 3853 KernelSet Kernels = getDeviceKernels(M); 3854 3855 auto IsCalled = [&](Function &F) { 3856 if (Kernels.contains(&F)) 3857 return true; 3858 for (const User *U : F.users()) 3859 if (!isa<BlockAddress>(U)) 3860 return true; 3861 return false; 3862 }; 3863 3864 auto EmitRemark = [&](Function &F) { 3865 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); 3866 ORE.emit([&]() { 3867 OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F); 3868 return ORA << "Could not internalize function. " 3869 << "Some optimizations may not be possible."; 3870 }); 3871 }; 3872 3873 // Create internal copies of each function if this is a kernel Module. This 3874 // allows interprocedural passes to see every call edge. 3875 DenseSet<const Function *> InternalizedFuncs; 3876 if (isOpenMPDevice(M)) 3877 for (Function &F : M) 3878 if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) && 3879 !DisableInternalization) { 3880 if (Attributor::internalizeFunction(F, /* Force */ true)) { 3881 InternalizedFuncs.insert(&F); 3882 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) { 3883 EmitRemark(F); 3884 } 3885 } 3886 3887 // Look at every function in the Module unless it was internalized.
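// For illustration only, assuming the usual Attributor::internalizeFunction
// behavior of cloning a function @foo into an internal copy (typically named
// "foo.internalized") and redirecting in-module call edges to that copy: the
// original @foo is kept for external callers but skipped here, so only the
// clone is placed in the SCC vector below and analyzed.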
3888 SmallVector<Function *, 16> SCC; 3889 for (Function &F : M) 3890 if (!F.isDeclaration() && !InternalizedFuncs.contains(&F)) 3891 SCC.push_back(&F); 3892 3893 if (SCC.empty()) 3894 return PreservedAnalyses::all(); 3895 3896 AnalysisGetter AG(FAM); 3897 3898 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & { 3899 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 3900 }; 3901 3902 BumpPtrAllocator Allocator; 3903 CallGraphUpdater CGUpdater; 3904 3905 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 3906 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels); 3907 3908 unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; 3909 Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false, 3910 MaxFixpointIterations, OREGetter, DEBUG_TYPE); 3911 3912 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 3913 bool Changed = OMPOpt.run(true); 3914 if (Changed) 3915 return PreservedAnalyses::none(); 3916 3917 return PreservedAnalyses::all(); 3918 } 3919 3920 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, 3921 CGSCCAnalysisManager &AM, 3922 LazyCallGraph &CG, 3923 CGSCCUpdateResult &UR) { 3924 if (!containsOpenMP(*C.begin()->getFunction().getParent())) 3925 return PreservedAnalyses::all(); 3926 if (DisableOpenMPOptimizations) 3927 return PreservedAnalyses::all(); 3928 3929 SmallVector<Function *, 16> SCC; 3930 // If there are kernels in the module, we have to run on all SCC's. 3931 for (LazyCallGraph::Node &N : C) { 3932 Function *Fn = &N.getFunction(); 3933 SCC.push_back(Fn); 3934 } 3935 3936 if (SCC.empty()) 3937 return PreservedAnalyses::all(); 3938 3939 Module &M = *C.begin()->getFunction().getParent(); 3940 3941 KernelSet Kernels = getDeviceKernels(M); 3942 3943 FunctionAnalysisManager &FAM = 3944 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); 3945 3946 AnalysisGetter AG(FAM); 3947 3948 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & { 3949 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 3950 }; 3951 3952 BumpPtrAllocator Allocator; 3953 CallGraphUpdater CGUpdater; 3954 CGUpdater.initialize(CG, C, AM, UR); 3955 3956 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 3957 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator, 3958 /*CGSCC*/ Functions, Kernels); 3959 3960 unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; 3961 Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true, 3962 MaxFixpointIterations, OREGetter, DEBUG_TYPE); 3963 3964 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 3965 bool Changed = OMPOpt.run(false); 3966 if (Changed) 3967 return PreservedAnalyses::none(); 3968 3969 return PreservedAnalyses::all(); 3970 } 3971 3972 namespace { 3973 3974 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass { 3975 CallGraphUpdater CGUpdater; 3976 static char ID; 3977 3978 OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) { 3979 initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry()); 3980 } 3981 3982 void getAnalysisUsage(AnalysisUsage &AU) const override { 3983 CallGraphSCCPass::getAnalysisUsage(AU); 3984 } 3985 3986 bool runOnSCC(CallGraphSCC &CGSCC) override { 3987 if (!containsOpenMP(CGSCC.getCallGraph().getModule())) 3988 return false; 3989 if (DisableOpenMPOptimizations || skipSCC(CGSCC)) 3990 return false; 3991 3992 SmallVector<Function *, 16> SCC; 3993 // If there are kernels in the module, we have to run on all SCC's. 
3994 for (CallGraphNode *CGN : CGSCC) { 3995 Function *Fn = CGN->getFunction(); 3996 if (!Fn || Fn->isDeclaration()) 3997 continue; 3998 SCC.push_back(Fn); 3999 } 4000 4001 if (SCC.empty()) 4002 return false; 4003 4004 Module &M = CGSCC.getCallGraph().getModule(); 4005 KernelSet Kernels = getDeviceKernels(M); 4006 4007 CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); 4008 CGUpdater.initialize(CG, CGSCC); 4009 4010 // Maintain a map of functions to avoid rebuilding the ORE 4011 DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap; 4012 auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & { 4013 std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F]; 4014 if (!ORE) 4015 ORE = std::make_unique<OptimizationRemarkEmitter>(F); 4016 return *ORE; 4017 }; 4018 4019 AnalysisGetter AG; 4020 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 4021 BumpPtrAllocator Allocator; 4022 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, 4023 Allocator, 4024 /*CGSCC*/ Functions, Kernels); 4025 4026 unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; 4027 Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true, 4028 MaxFixpointIterations, OREGetter, DEBUG_TYPE); 4029 4030 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 4031 return OMPOpt.run(false); 4032 } 4033 4034 bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); } 4035 }; 4036 4037 } // end anonymous namespace 4038 4039 KernelSet llvm::omp::getDeviceKernels(Module &M) { 4040 // TODO: Create a more cross-platform way of determining device kernels. 4041 NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations"); 4042 KernelSet Kernels; 4043 4044 if (!MD) 4045 return Kernels; 4046 4047 for (auto *Op : MD->operands()) { 4048 if (Op->getNumOperands() < 2) 4049 continue; 4050 MDString *KindID = dyn_cast<MDString>(Op->getOperand(1)); 4051 if (!KindID || KindID->getString() != "kernel") 4052 continue; 4053 4054 Function *KernelFn = 4055 mdconst::dyn_extract_or_null<Function>(Op->getOperand(0)); 4056 if (!KernelFn) 4057 continue; 4058 4059 ++NumOpenMPTargetRegionKernels; 4060 4061 Kernels.insert(KernelFn); 4062 } 4063 4064 return Kernels; 4065 } 4066 4067 bool llvm::omp::containsOpenMP(Module &M) { 4068 Metadata *MD = M.getModuleFlag("openmp"); 4069 if (!MD) 4070 return false; 4071 4072 return true; 4073 } 4074 4075 bool llvm::omp::isOpenMPDevice(Module &M) { 4076 Metadata *MD = M.getModuleFlag("openmp-device"); 4077 if (!MD) 4078 return false; 4079 4080 return true; 4081 } 4082 4083 char OpenMPOptCGSCCLegacyPass::ID = 0; 4084 4085 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", 4086 "OpenMP specific optimizations", false, false) 4087 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) 4088 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", 4089 "OpenMP specific optimizations", false, false) 4090 4091 Pass *llvm::createOpenMPOptCGSCCLegacyPass() { 4092 return new OpenMPOptCGSCCLegacyPass(); 4093 } 4094
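// For reference (illustrative, exact forms are produced by the frontend):
// getDeviceKernels() above matches NVVM annotations of roughly the shape
//   !nvvm.annotations = !{!0}
//   !0 = !{void ()* @__omp_offloading_..._kernel_l10, !"kernel", i32 1}
// while containsOpenMP()/isOpenMPDevice() key off module flags such as
//   !{i32 7, !"openmp", i32 50}
//   !{i32 7, !"openmp-device", i32 50}
// The new-PM passes defined in this file are registered as "openmp-opt"
// (module) and "openmp-opt-cgscc" (CGSCC) in PassRegistry.def; the legacy
// flavor above is reached via -openmp-opt-cgscc under the legacy pass manager.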