//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpenMP specific optimizations:
//
// - Deduplication of runtime calls, e.g., omp_get_thread_num.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/OpenMPOpt.h"

#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"

using namespace llvm::PatternMatch;
using namespace llvm;
using namespace omp;

#define DEBUG_TYPE "openmp-opt"

static cl::opt<bool> DisableOpenMPOptimizations(
    "openmp-opt-disable", cl::ZeroOrMore,
    cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
    cl::init(false));

static cl::opt<bool> EnableParallelRegionMerging(
    "openmp-opt-enable-merging", cl::ZeroOrMore,
    cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
    cl::init(false));

static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
                                    cl::Hidden);
static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
                                        cl::init(false), cl::Hidden);

static cl::opt<bool> HideMemoryTransferLatency(
    "openmp-hide-memory-transfer-latency",
    cl::desc("[WIP] Tries to hide the latency of host to device memory"
             " transfers"),
    cl::Hidden, cl::init(false));

STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");
STATISTIC(NumOpenMPTargetRegionKernels,
          "Number of OpenMP target region entry points (=kernels) identified");
STATISTIC(
    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
STATISTIC(NumOpenMPParallelRegionsMerged,
          "Number of OpenMP parallel regions merged");
STATISTIC(NumBytesMovedToSharedMemory,
          "Amount of memory pushed to shared memory");

#if !defined(NDEBUG)
static constexpr auto TAG = "[" DEBUG_TYPE "]";
#endif

namespace {

enum class AddressSpace : unsigned {
  Generic = 0,
  Global = 1,
  Shared = 3,
  Constant = 4,
  Local = 5,
};

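/// Forward declarations of the OpenMP-specific abstract attributes that are
/// requested in registerAAs() below.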
struct AAHeapToShared;

struct AAICVTracker;

/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
/// Attributor runs.
struct OMPInformationCache : public InformationCache {
  OMPInformationCache(Module &M, AnalysisGetter &AG,
                      BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
                      SmallPtrSetImpl<Kernel> &Kernels)
      : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
        Kernels(Kernels) {

    OMPBuilder.initialize();
    initializeRuntimeFunctions();
    initializeInternalControlVars();
  }

  /// Generic information that describes an internal control variable.
  struct InternalControlVarInfo {
    /// The kind, as described by InternalControlVar enum.
    InternalControlVar Kind;

    /// The name of the ICV.
    StringRef Name;

    /// Environment variable associated with this ICV.
    StringRef EnvVarName;

    /// Initial value kind.
    ICVInitValue InitKind;

    /// Initial value.
    ConstantInt *InitValue;

    /// Setter RTL function associated with this ICV.
    RuntimeFunction Setter;

    /// Getter RTL function associated with this ICV.
    RuntimeFunction Getter;

    /// RTL Function corresponding to the override clause of this ICV
    RuntimeFunction Clause;
  };

  /// Generic information that describes a runtime function
  struct RuntimeFunctionInfo {

    /// The kind, as described by the RuntimeFunction enum.
    RuntimeFunction Kind;

    /// The name of the function.
    StringRef Name;

    /// Flag to indicate a variadic function.
    bool IsVarArg;

    /// The return type of the function.
    Type *ReturnType;

    /// The argument types of the function.
    SmallVector<Type *, 8> ArgumentTypes;

    /// The declaration if available.
    Function *Declaration = nullptr;

    /// Uses of this runtime function per function containing the use.
    using UseVector = SmallVector<Use *, 16>;

    /// Clear UsesMap for runtime function.
    void clearUsesMap() { UsesMap.clear(); }

    /// Boolean conversion that is true if the runtime function was found.
    operator bool() const { return Declaration; }

    /// Return the vector of uses in function \p F.
    UseVector &getOrCreateUseVector(Function *F) {
      std::shared_ptr<UseVector> &UV = UsesMap[F];
      if (!UV)
        UV = std::make_shared<UseVector>();
      return *UV;
    }

    /// Return the vector of uses in function \p F or `nullptr` if there are
    /// none.
    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();
      return nullptr;
    }

    /// Return how many functions contain uses of this runtime function.
    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    /// Return the number of arguments (or the minimal number for variadic
    /// functions).
    size_t getNumArgs() const { return ArgumentTypes.size(); }

    /// Run the callback \p CB on each use and forget the use if the result is
    /// true. The callback will be fed the function in which the use was
    /// encountered as second argument.
    void foreachUse(SmallVectorImpl<Function *> &SCC,
                    function_ref<bool(Use &, Function &)> CB) {
      for (Function *F : SCC)
        foreachUse(CB, F);
    }

    /// Run the callback \p CB on each use within the function \p F and forget
    /// the use if the result is true.
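    /// Note that forgotten uses are removed by swapping them with the last
    /// vector element, so the relative order of the remaining uses is not
    /// preserved.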
    void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
      SmallVector<unsigned, 8> ToBeDeleted;
      ToBeDeleted.clear();

      unsigned Idx = 0;
      UseVector &UV = getOrCreateUseVector(F);

      for (Use *U : UV) {
        if (CB(*U, *F))
          ToBeDeleted.push_back(Idx);
        ++Idx;
      }

      // Remove the to-be-deleted indices in reverse order as prior
      // modifications will not modify the smaller indices.
      while (!ToBeDeleted.empty()) {
        unsigned Idx = ToBeDeleted.pop_back_val();
        UV[Idx] = UV.back();
        UV.pop_back();
      }
    }

  private:
    /// Map from functions to all uses of this runtime function contained in
    /// them.
    DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
  };

  /// An OpenMP-IR-Builder instance
  OpenMPIRBuilder OMPBuilder;

  /// Map from runtime function kind to the runtime function description.
  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
                  RuntimeFunction::OMPRTL___last>
      RFIs;

  /// Map from ICV kind to the ICV description.
  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVs;

  /// Helper to initialize all internal control variable information for those
  /// defined in OMPKinds.def.
  void initializeInternalControlVars() {
#define ICV_RT_SET(_Name, RTL) \
  { \
    auto &ICV = ICVs[_Name]; \
    ICV.Setter = RTL; \
  }
#define ICV_RT_GET(Name, RTL) \
  { \
    auto &ICV = ICVs[Name]; \
    ICV.Getter = RTL; \
  }
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
  { \
    auto &ICV = ICVs[Enum]; \
    ICV.Name = _Name; \
    ICV.Kind = Enum; \
    ICV.InitKind = Init; \
    ICV.EnvVarName = _EnvVarName; \
    switch (ICV.InitKind) { \
    case ICV_IMPLEMENTATION_DEFINED: \
      ICV.InitValue = nullptr; \
      break; \
    case ICV_ZERO: \
      ICV.InitValue = ConstantInt::get( \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
      break; \
    case ICV_FALSE: \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
      break; \
    case ICV_LAST: \
      break; \
    } \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  /// Returns true if the function declaration \p F matches the runtime
  /// function types, that is, return type \p RTFRetType, and argument types
  /// \p RTFArgTypes.
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
                                  SmallVector<Type *, 8> &RTFArgTypes) {
    // TODO: We should output information to the user (under debug output
    //       and via remarks).

    if (!F)
      return false;
    if (F->getReturnType() != RTFRetType)
      return false;
    if (F->arg_size() != RTFArgTypes.size())
      return false;

    auto RTFTyIt = RTFArgTypes.begin();
    for (Argument &Arg : F->args()) {
      if (Arg.getType() != *RTFTyIt)
        return false;

      ++RTFTyIt;
    }

    return true;
  }

  // Helper to collect all uses of the declaration in the UsesMap.
  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
    unsigned NumUses = 0;
    if (!RFI.Declaration)
      return NumUses;
    OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);

    if (CollectStats) {
      NumOpenMPRuntimeFunctionsIdentified += 1;
      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
    }

    // TODO: We directly convert uses into proper calls and unknown uses.
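    // Instruction users inside the module slice are recorded per containing
    // function; non-instruction users (e.g., constant expressions) are kept
    // under the nullptr key.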
    for (Use &U : RFI.Declaration->uses()) {
      if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
        if (ModuleSlice.count(UserI->getFunction())) {
          RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
          ++NumUses;
        }
      } else {
        RFI.getOrCreateUseVector(nullptr).push_back(&U);
        ++NumUses;
      }
    }
    return NumUses;
  }

  // Helper function to recollect uses of a runtime function.
  void recollectUsesForFunction(RuntimeFunction RTF) {
    auto &RFI = RFIs[RTF];
    RFI.clearUsesMap();
    collectUses(RFI, /*CollectStats*/ false);
  }

  // Helper function to recollect uses of all runtime functions.
  void recollectUses() {
    for (int Idx = 0; Idx < RFIs.size(); ++Idx)
      recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
  }

  /// Helper to initialize all runtime function information for those defined
  /// in OpenMPKinds.def.
  void initializeRuntimeFunctions() {
    Module &M = *((*ModuleSlice.begin())->getParent());

    // Helper macros for handling __VA_ARGS__ in OMP_RTL
#define OMP_TYPE(VarName, ...) \
  Type *VarName = OMPBuilder.VarName; \
  (void)VarName;

#define OMP_ARRAY_TYPE(VarName, ...) \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
  (void)VarName##Ty; \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
  (void)VarName##PtrTy;

#define OMP_FUNCTION_TYPE(VarName, ...) \
  FunctionType *VarName = OMPBuilder.VarName; \
  (void)VarName; \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
  (void)VarName##Ptr;

#define OMP_STRUCT_TYPE(VarName, ...) \
  StructType *VarName = OMPBuilder.VarName; \
  (void)VarName; \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
  (void)VarName##Ptr;

#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
  { \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
    Function *F = M.getFunction(_Name); \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
      auto &RFI = RFIs[_Enum]; \
      RFI.Kind = _Enum; \
      RFI.Name = _Name; \
      RFI.IsVarArg = _IsVarArg; \
      RFI.ReturnType = OMPBuilder._ReturnType; \
      RFI.ArgumentTypes = std::move(ArgsTypes); \
      RFI.Declaration = F; \
      unsigned NumUses = collectUses(RFI); \
      (void)NumUses; \
      LLVM_DEBUG({ \
        dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
               << " found\n"; \
        if (RFI.Declaration) \
          dbgs() << TAG << "-> got " << NumUses << " uses in " \
                 << RFI.getNumFunctionsWithUses() \
                 << " different functions.\n"; \
      }); \
    } \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"

    // TODO: We should attach the attributes defined in OMPKinds.def.
  }

  /// Collection of known kernels (\see Kernel) in the module.
  SmallPtrSetImpl<Kernel> &Kernels;
};

/// Used to map the values physically (in the IR) stored in an offload
/// array, to a vector in memory.
struct OffloadArray {
  /// Physical array (in the IR).
  AllocaInst *Array = nullptr;
  /// Mapped values.
  SmallVector<Value *, 8> StoredValues;
  /// Last stores made in the offload array.
  SmallVector<StoreInst *, 8> LastAccesses;

  OffloadArray() = default;

  /// Initializes the OffloadArray with the values stored in \p Array before
  /// instruction \p Before is reached. Returns false if the initialization
  /// fails.
  /// This MUST be used immediately after the construction of the object.
  bool initialize(AllocaInst &Array, Instruction &Before) {
    if (!Array.getAllocatedType()->isArrayTy())
      return false;

    if (!getValues(Array, Before))
      return false;

    this->Array = &Array;
    return true;
  }

  static const unsigned DeviceIDArgNum = 1;
  static const unsigned BasePtrsArgNum = 3;
  static const unsigned PtrsArgNum = 4;
  static const unsigned SizesArgNum = 5;

private:
  /// Traverses the BasicBlock where \p Array is, collecting the stores made to
  /// \p Array, leaving StoredValues with the values stored before the
  /// instruction \p Before is reached.
  bool getValues(AllocaInst &Array, Instruction &Before) {
    // Initialize container.
    const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
    StoredValues.assign(NumValues, nullptr);
    LastAccesses.assign(NumValues, nullptr);

    // TODO: This assumes the instruction \p Before is in the same
    //       BasicBlock as Array. Make it general, for any control flow graph.
    BasicBlock *BB = Array.getParent();
    if (BB != Before.getParent())
      return false;

    const DataLayout &DL = Array.getModule()->getDataLayout();
    const unsigned int PointerSize = DL.getPointerSize();

    for (Instruction &I : *BB) {
      if (&I == &Before)
        break;

      if (!isa<StoreInst>(&I))
        continue;

      auto *S = cast<StoreInst>(&I);
      int64_t Offset = -1;
      auto *Dst =
          GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
      if (Dst == &Array) {
        int64_t Idx = Offset / PointerSize;
        StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
        LastAccesses[Idx] = S;
      }
    }

    return isFilled();
  }

  /// Returns true if all values in StoredValues and
  /// LastAccesses are not nullptrs.
  bool isFilled() {
    const unsigned NumValues = StoredValues.size();
    for (unsigned I = 0; I < NumValues; ++I) {
      if (!StoredValues[I] || !LastAccesses[I])
        return false;
    }

    return true;
  }
};

struct OpenMPOpt {

  using OptimizationRemarkGetter =
      function_ref<OptimizationRemarkEmitter &(Function *)>;

  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
            OptimizationRemarkGetter OREGetter,
            OMPInformationCache &OMPInfoCache, Attributor &A)
      : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
        OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}

  /// Check if any remarks are enabled for openmp-opt
  bool remarksEnabled() {
    auto &Ctx = M.getContext();
    return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
  }

  /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
  bool run(bool IsModulePass) {
    if (SCC.empty())
      return false;

    bool Changed = false;

    LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
                      << " functions in a slice with "
                      << OMPInfoCache.ModuleSlice.size() << " functions\n");

    if (IsModulePass) {
      Changed |= runAttributor();

      // Recollect uses, in case Attributor deleted any.
      OMPInfoCache.recollectUses();

      if (remarksEnabled())
        analysisGlobalization();
    } else {
      if (PrintICVValues)
        printICVs();
      if (PrintOpenMPKernels)
        printKernels();

      Changed |= rewriteDeviceCodeStateMachine();

      Changed |= runAttributor();

      // Recollect uses, in case Attributor deleted any.
      OMPInfoCache.recollectUses();

      Changed |= deleteParallelRegions();
      if (HideMemoryTransferLatency)
        Changed |= hideMemTransfersLatency();
      Changed |= deduplicateRuntimeCalls();
      if (EnableParallelRegionMerging) {
        if (mergeParallelRegions()) {
          deduplicateRuntimeCalls();
          Changed = true;
        }
      }
    }

    return Changed;
  }

  /// Print initial ICV values for testing.
  /// FIXME: This should be done from the Attributor once it is added.
  void printICVs() const {
    InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
                                 ICV_proc_bind};

    for (Function *F : OMPInfoCache.ModuleSlice) {
      for (auto ICV : ICVs) {
        auto ICVInfo = OMPInfoCache.ICVs[ICV];
        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
          return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
                     << " Value: "
                     << (ICVInfo.InitValue
                             ? toString(ICVInfo.InitValue->getValue(), 10, true)
                             : "IMPLEMENTATION_DEFINED");
        };

        emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
      }
    }
  }

  /// Print OpenMP GPU kernels for testing.
  void printKernels() const {
    for (Function *F : SCC) {
      if (!OMPInfoCache.Kernels.count(F))
        continue;

      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA << "OpenMP GPU kernel "
                   << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
      };

      emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
    }
  }

  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
  /// given it has to be the callee or a nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(U.getUser());
    if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
        (!RFI || CI->getCalledFunction() == RFI->Declaration))
      return CI;
    return nullptr;
  }

  /// Return the call if \p V is a regular call. If \p RFI is given it has to
  /// be the callee or a nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(&V);
    if (CI && !CI->hasOperandBundles() &&
        (!RFI || CI->getCalledFunction() == RFI->Declaration))
      return CI;
    return nullptr;
  }

private:
  /// Merge parallel regions when it is safe.
  bool mergeParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;
    const unsigned CallbackFirstArgOperand = 3;
    using InsertPointTy = OpenMPIRBuilder::InsertPointTy;

    // Check if there are any __kmpc_fork_call calls to merge.
    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    // Unmergable calls that prevent merging a parallel region.
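    // Calls to these runtime functions set ICVs (proc_bind, num_threads) that
    // only apply to the next parallel region, so merging across them could
    // change semantics.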
    OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
    };

    bool Changed = false;
    LoopInfo *LI = nullptr;
    DominatorTree *DT = nullptr;

    SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;

    BasicBlock *StartBB = nullptr, *EndBB = nullptr;
    auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                         BasicBlock &ContinuationIP) {
      BasicBlock *CGStartBB = CodeGenIP.getBlock();
      BasicBlock *CGEndBB =
          SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
      assert(StartBB != nullptr && "StartBB should not be null");
      CGStartBB->getTerminator()->setSuccessor(0, StartBB);
      assert(EndBB != nullptr && "EndBB should not be null");
      EndBB->getTerminator()->setSuccessor(0, CGEndBB);
    };

    auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
                      Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
      ReplacementValue = &Inner;
      return CodeGenIP;
    };

    auto FiniCB = [&](InsertPointTy CodeGenIP) {};

    /// Create a sequential execution region within a merged parallel region,
    /// encapsulated in a master construct with a barrier for synchronization.
    auto CreateSequentialRegion = [&](Function *OuterFn,
                                      BasicBlock *OuterPredBB,
                                      Instruction *SeqStartI,
                                      Instruction *SeqEndI) {
      // Isolate the instructions of the sequential region to a separate
      // block.
      BasicBlock *ParentBB = SeqStartI->getParent();
      BasicBlock *SeqEndBB =
          SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
      BasicBlock *SeqAfterBB =
          SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
      BasicBlock *SeqStartBB =
          SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");

      assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
             "Expected a different CFG");
      const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
      ParentBB->getTerminator()->eraseFromParent();

      auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                           BasicBlock &ContinuationIP) {
        BasicBlock *CGStartBB = CodeGenIP.getBlock();
        BasicBlock *CGEndBB =
            SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
        assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
        CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
        assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
        SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
      };
      auto FiniCB = [&](InsertPointTy CodeGenIP) {};

      // Find outputs from the sequential region to outside users and
      // broadcast their values to them.
      for (Instruction &I : *SeqStartBB) {
        SmallPtrSet<Instruction *, 4> OutsideUsers;
        for (User *Usr : I.users()) {
          Instruction &UsrI = *cast<Instruction>(Usr);
          // Ignore outputs to LT intrinsics, code extraction for the merged
          // parallel region will fix them.
          if (UsrI.isLifetimeStartOrEnd())
            continue;

          if (UsrI.getParent() != SeqStartBB)
            OutsideUsers.insert(&UsrI);
        }

        if (OutsideUsers.empty())
          continue;

        // Emit an alloca in the outer region to store the broadcasted
        // value.
        const DataLayout &DL = M.getDataLayout();
        AllocaInst *AllocaI = new AllocaInst(
            I.getType(), DL.getAllocaAddrSpace(), nullptr,
            I.getName() + ".seq.output.alloc", &OuterFn->front().front());

        // Emit a store instruction in the sequential BB to update the
        // value.
        new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());

        // Emit a load instruction and replace the use of the output value
        // with it.
        for (Instruction *UsrI : OutsideUsers) {
          LoadInst *LoadI = new LoadInst(
              I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
          UsrI->replaceUsesOfWith(&I, LoadI);
        }
      }

      OpenMPIRBuilder::LocationDescription Loc(
          InsertPointTy(ParentBB, ParentBB->end()), DL);
      InsertPointTy SeqAfterIP =
          OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);

      OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);

      BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());

      LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
                        << "\n");
    };

    // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
    // contained in BB and only separated by instructions that can be
    // redundantly executed in parallel. The block BB is split before the first
    // call (in MergableCIs) and after the last so the entire region we merge
    // into a single parallel region is contained in a single basic block
    // without any other instructions. We use the OpenMPIRBuilder to outline
    // that block and call the resulting function via __kmpc_fork_call.
    auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
      // TODO: Change the interface to allow single CIs expanded, e.g., to
      //       include an outer loop.
      assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");

      auto Remark = [&](OptimizationRemark OR) {
        OR << "Parallel region at "
           << ore::NV("OpenMPParallelMergeFront",
                      MergableCIs.front()->getDebugLoc())
           << " merged with parallel regions at ";
        for (auto *CI : llvm::drop_begin(MergableCIs)) {
          OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
          if (CI != MergableCIs.back())
            OR << ", ";
        }
        return OR;
      };

      emitRemark<OptimizationRemark>(MergableCIs.front(),
                                     "OpenMPParallelRegionMerging", Remark);

      Function *OriginalFn = BB->getParent();
      LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
                        << " parallel regions in " << OriginalFn->getName()
                        << "\n");

      // Isolate the calls to merge in a separate block.
      EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
      BasicBlock *AfterBB =
          SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
      StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
                           "omp.par.merged");

      assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
      const DebugLoc DL = BB->getTerminator()->getDebugLoc();
      BB->getTerminator()->eraseFromParent();

      // Create sequential regions for sequential instructions that are
      // in-between mergable parallel regions.
      for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
           It != End; ++It) {
        Instruction *ForkCI = *It;
        Instruction *NextForkCI = *(It + 1);

        // Continue if there are no in-between instructions.
        if (ForkCI->getNextNode() == NextForkCI)
          continue;

        CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
                               NextForkCI->getPrevNode());
      }

      OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
                                               DL);
      IRBuilder<>::InsertPoint AllocaIP(
          &OriginalFn->getEntryBlock(),
          OriginalFn->getEntryBlock().getFirstInsertionPt());
      // Create the merged parallel region with default proc binding, to
      // avoid overriding binding settings, and without explicit cancellation.
      InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
          Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
          OMP_PROC_BIND_default, /* IsCancellable */ false);
      BranchInst::Create(AfterBB, AfterIP.getBlock());

      // Perform the actual outlining.
      OMPInfoCache.OMPBuilder.finalize(OriginalFn,
                                       /* AllowExtractorSinking */ true);

      Function *OutlinedFn = MergableCIs.front()->getCaller();

      // Replace the __kmpc_fork_call calls with direct calls to the outlined
      // callbacks.
      SmallVector<Value *, 8> Args;
      for (auto *CI : MergableCIs) {
        Value *Callee =
            CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
        FunctionType *FT =
            cast<FunctionType>(Callee->getType()->getPointerElementType());
        Args.clear();
        Args.push_back(OutlinedFn->getArg(0));
        Args.push_back(OutlinedFn->getArg(1));
        for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
             U < E; ++U)
          Args.push_back(CI->getArgOperand(U));

        CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
        if (CI->getDebugLoc())
          NewCI->setDebugLoc(CI->getDebugLoc());

        // Forward parameter attributes from the callback to the callee.
        for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
             U < E; ++U)
          for (const Attribute &A : CI->getAttributes().getParamAttributes(U))
            NewCI->addParamAttr(
                U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);

        // Emit an explicit barrier to replace the implicit fork-join barrier.
        if (CI != MergableCIs.back()) {
          // TODO: Remove barrier if the merged parallel region includes the
          //       'nowait' clause.
          OMPInfoCache.OMPBuilder.createBarrier(
              InsertPointTy(NewCI->getParent(),
                            NewCI->getNextNode()->getIterator()),
              OMPD_parallel);
        }

        auto Remark = [&](OptimizationRemark OR) {
          return OR << "Parallel region at "
                    << ore::NV("OpenMPParallelMerge", CI->getDebugLoc())
                    << " merged with "
                    << ore::NV("OpenMPParallelMergeFront",
                               MergableCIs.front()->getDebugLoc());
        };
        if (CI != MergableCIs.front())
          emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionMerging",
                                         Remark);

        CI->eraseFromParent();
      }

      assert(OutlinedFn != OriginalFn && "Outlining failed");
      CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
      CGUpdater.reanalyzeFunction(*OriginalFn);

      NumOpenMPParallelRegionsMerged += MergableCIs.size();

      return true;
    };

    // Helper function that identifies sequences of
    // __kmpc_fork_call uses in a basic block.
    auto DetectPRsCB = [&](Use &U, Function &F) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      BB2PRMap[CI->getParent()].insert(CI);

      return false;
    };

    BB2PRMap.clear();
    RFI.foreachUse(SCC, DetectPRsCB);
    SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
    // Find mergable parallel regions within a basic block that are
    // safe to merge, that is any in-between instructions can safely
    // execute in parallel after merging.
    // TODO: support merging across basic-blocks.
    for (auto &It : BB2PRMap) {
      auto &CIs = It.getSecond();
      if (CIs.size() < 2)
        continue;

      BasicBlock *BB = It.getFirst();
      SmallVector<CallInst *, 4> MergableCIs;

      /// Returns true if the instruction is mergable, false otherwise.
      /// A terminator instruction is unmergable by definition since merging
      /// works within a BB. Instructions before the mergable region are
      /// mergable if they are not calls to OpenMP runtime functions that may
      /// set different execution parameters for subsequent parallel regions.
      /// Instructions in-between parallel regions are mergable if they are not
      /// calls to any non-intrinsic function since that may call a
      /// non-mergable OpenMP runtime function.
      auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
        // We do not merge across BBs, hence return false (unmergable) if the
        // instruction is a terminator.
        if (I.isTerminator())
          return false;

        if (!isa<CallInst>(&I))
          return true;

        CallInst *CI = cast<CallInst>(&I);
        if (IsBeforeMergableRegion) {
          Function *CalledFunction = CI->getCalledFunction();
          if (!CalledFunction)
            return false;
          // Return false (unmergable) if the call before the parallel
          // region calls an explicit affinity (proc_bind) or number of
          // threads (num_threads) compiler-generated function. Those settings
          // may be incompatible with following parallel regions.
          // TODO: ICV tracking to detect compatibility.
          for (const auto &RFI : UnmergableCallsInfo) {
            if (CalledFunction == RFI.Declaration)
              return false;
          }
        } else {
          // Return false (unmergable) if there is a call instruction
          // in-between parallel regions when it is not an intrinsic. It
          // may call an unmergable OpenMP runtime function in its callpath.
          // TODO: Keep track of possible OpenMP calls in the callpath.
          if (!isa<IntrinsicInst>(CI))
            return false;
        }

        return true;
      };
      // Find maximal number of parallel region CIs that are safe to merge.
      for (auto It = BB->begin(), End = BB->end(); It != End;) {
        Instruction &I = *It;
        ++It;

        if (CIs.count(&I)) {
          MergableCIs.push_back(cast<CallInst>(&I));
          continue;
        }

        // Continue expanding if the instruction is mergable.
        if (IsMergable(I, MergableCIs.empty()))
          continue;

        // Forward the instruction iterator to skip the next parallel region
        // since there is an unmergable instruction which can affect it.
        for (; It != End; ++It) {
          Instruction &SkipI = *It;
          if (CIs.count(&SkipI)) {
            LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
                              << " due to " << I << "\n");
            ++It;
            break;
          }
        }

        // Store mergable regions found.
        if (MergableCIs.size() > 1) {
          MergableCIsVector.push_back(MergableCIs);
          LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
                            << " parallel regions in block " << BB->getName()
                            << " of function " << BB->getParent()->getName()
                            << "\n";);
        }

        MergableCIs.clear();
      }

      if (!MergableCIsVector.empty()) {
        Changed = true;

        for (auto &MergableCIs : MergableCIsVector)
          Merge(MergableCIs, BB);
        MergableCIsVector.clear();
      }
    }

    if (Changed) {
      /// Re-collect use for fork calls, emitted barrier calls, and
      /// any emitted master/end_master calls.
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
    }

    return Changed;
  }

  /// Try to delete parallel regions if possible.
  bool deleteParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    bool Changed = false;
    auto DeleteCallCB = [&](Use &U, Function &) {
      CallInst *CI = getCallIfRegularCall(U);
      if (!CI)
        return false;
      auto *Fn = dyn_cast<Function>(
          CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
      if (!Fn)
        return false;
      if (!Fn->onlyReadsMemory())
        return false;
      if (!Fn->hasFnAttribute(Attribute::WillReturn))
        return false;

      LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
                        << CI->getCaller()->getName() << "\n");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Parallel region in "
                  << ore::NV("OpenMPParallelDelete", CI->getCaller()->getName())
                  << " deleted";
      };
      emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionDeletion",
                                     Remark);

      CGUpdater.removeCallSite(*CI);
      CI->eraseFromParent();
      Changed = true;
      ++NumOpenMPParallelRegionsDeleted;
      return true;
    };

    RFI.foreachUse(SCC, DeleteCallCB);

    return Changed;
  }

  /// Try to eliminate runtime calls by reusing existing ones.
  bool deduplicateRuntimeCalls() {
    bool Changed = false;

    RuntimeFunction DeduplicableRuntimeCallIDs[] = {
        OMPRTL_omp_get_num_threads,
        OMPRTL_omp_in_parallel,
        OMPRTL_omp_get_cancellation,
        OMPRTL_omp_get_thread_limit,
        OMPRTL_omp_get_supported_active_levels,
        OMPRTL_omp_get_level,
        OMPRTL_omp_get_ancestor_thread_num,
        OMPRTL_omp_get_team_size,
        OMPRTL_omp_get_active_level,
        OMPRTL_omp_in_final,
        OMPRTL_omp_get_proc_bind,
        OMPRTL_omp_get_num_places,
        OMPRTL_omp_get_num_procs,
        OMPRTL_omp_get_place_num,
        OMPRTL_omp_get_partition_num_places,
        OMPRTL_omp_get_partition_place_nums};

    // Global-tid is handled separately.
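    // Arguments known to carry the result of __kmpc_global_thread_num can be
    // used directly as the replacement value when deduplicating that call.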
    SmallSetVector<Value *, 16> GTIdArgs;
    collectGlobalThreadIdArguments(GTIdArgs);
    LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
                      << " global thread ID arguments\n");

    for (Function *F : SCC) {
      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);

      // __kmpc_global_thread_num is special as we can replace it with an
      // argument in enough cases to make it worth trying.
      Value *GTIdArg = nullptr;
      for (Argument &Arg : F->args())
        if (GTIdArgs.count(&Arg)) {
          GTIdArg = &Arg;
          break;
        }
      Changed |= deduplicateRuntimeCalls(
          *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
    }

    return Changed;
  }

  /// Tries to hide the latency of runtime calls that involve host to
  /// device memory transfers by splitting them into their "issue" and "wait"
  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
  /// moved downwards as much as possible. The "issue" starts the memory
  /// transfer asynchronously, returning a handle. The "wait" waits on the
  /// returned handle for the memory transfer to finish.
  bool hideMemTransfersLatency() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
    bool Changed = false;
    auto SplitMemTransfers = [&](Use &U, Function &Decl) {
      auto *RTCall = getCallIfRegularCall(U, &RFI);
      if (!RTCall)
        return false;

      OffloadArray OffloadArrays[3];
      if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
        return false;

      LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));

      // TODO: Check if can be moved upwards.
      bool WasSplit = false;
      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
      if (WaitMovementPoint)
        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);

      Changed |= WasSplit;
      return WasSplit;
    };
    RFI.foreachUse(SCC, SplitMemTransfers);

    return Changed;
  }

  void analysisGlobalization() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];

    auto CheckGlobalization = [&](Use &U, Function &Decl) {
      if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
          return ORA
                 << "Found thread data sharing on the GPU. "
                 << "Expect degraded performance due to data globalization.";
        };
        emitRemark<OptimizationRemarkAnalysis>(CI, "OpenMPGlobalization",
                                               Remark);
      }

      return false;
    };

    RFI.foreachUse(SCC, CheckGlobalization);
  }

  /// Maps the values stored in the offload arrays passed as arguments to
  /// \p RuntimeCall into the offload arrays in \p OAs.
  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
                                MutableArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "Need space for three offload arrays!");

    // A runtime call that involves memory offloading looks something like:
    // call void @__tgt_target_data_begin_mapper(arg0, arg1,
    //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
    // ...)
    // So, the idea is to access the allocas that allocate space for these
    // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
    // Therefore:
    // i8** %offload_baseptrs.
    Value *BasePtrsArg =
        RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
    // i8** %offload_ptrs.
    Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
    // i8** %offload_sizes.
    Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);

    // Get values stored in **offload_baseptrs.
    auto *V = getUnderlyingObject(BasePtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *BasePtrsArray = cast<AllocaInst>(V);
    if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
      return false;

    // Get values stored in **offload_ptrs.
    V = getUnderlyingObject(PtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *PtrsArray = cast<AllocaInst>(V);
    if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
      return false;

    // Get values stored in **offload_sizes.
    V = getUnderlyingObject(SizesArg);
    // If it's a [constant] global array don't analyze it.
    if (isa<GlobalValue>(V))
      return isa<Constant>(V);
    if (!isa<AllocaInst>(V))
      return false;

    auto *SizesArray = cast<AllocaInst>(V);
    if (!OAs[2].initialize(*SizesArray, RuntimeCall))
      return false;

    return true;
  }

  /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
  /// For now this is a way to test that the function getValuesInOffloadArrays
  /// is working properly.
  /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
  void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "There are three offload arrays to debug!");

    LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
    std::string ValuesStr;
    raw_string_ostream Printer(ValuesStr);
    std::string Separator = " --- ";

    for (auto *BP : OAs[0].StoredValues) {
      BP->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
    ValuesStr.clear();

    for (auto *P : OAs[1].StoredValues) {
      P->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
    ValuesStr.clear();

    for (auto *S : OAs[2].StoredValues) {
      S->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
  }

  /// Returns the instruction where the "wait" counterpart \p RuntimeCall can
  /// be moved. Returns nullptr if the movement is not possible, or not worth
  /// it.
  Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
    // FIXME: This traverses only the BasicBlock where RuntimeCall is.
    //        Make it traverse the CFG.

    Instruction *CurrentI = &RuntimeCall;
    bool IsWorthIt = false;
    while ((CurrentI = CurrentI->getNextNode())) {

      // TODO: Once we detect the regions to be offloaded we should use the
      //       alias analysis manager to check if CurrentI may modify one of
      //       the offloaded regions.
      if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
        if (IsWorthIt)
          return CurrentI;

        return nullptr;
      }

      // FIXME: For now, moving the call over anything without side effects
      //        is considered worth it.
      IsWorthIt = true;
    }

    // Return end of BasicBlock.
    return RuntimeCall.getParent()->getTerminator();
  }

  /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
                               Instruction &WaitMovementPoint) {
    // Create a stack-allocated handle (__tgt_async_info) at the beginning of
    // the function. It stores information about the async transfer so we can
    // wait on it later.
    auto &IRBuilder = OMPInfoCache.OMPBuilder;
    auto *F = RuntimeCall.getCaller();
    Instruction *FirstInst = &(F->getEntryBlock().front());
    AllocaInst *Handle = new AllocaInst(
        IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);

    // Add "issue" runtime call declaration:
    // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
    //   i8**, i8**, i64*, i64*)
    FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_issue);

    // Replace the RuntimeCall call site with its asynchronous version.
    SmallVector<Value *, 16> Args;
    for (auto &Arg : RuntimeCall.args())
      Args.push_back(Arg.get());
    Args.push_back(Handle);

    CallInst *IssueCallsite =
        CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
    RuntimeCall.eraseFromParent();

    // Add "wait" runtime call declaration:
    // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
    FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_wait);

    Value *WaitParams[2] = {
        IssueCallsite->getArgOperand(
            OffloadArray::DeviceIDArgNum), // device_id.
        Handle                             // handle to wait on.
    };
    CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);

    return true;
  }

  static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
                                    bool GlobalOnly, bool &SingleChoice) {
    if (CurrentIdent == NextIdent)
      return CurrentIdent;

    // TODO: Figure out how to actually combine multiple debug locations. For
    //       now we just keep an existing one if there is a single choice.
    if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
      SingleChoice = !CurrentIdent;
      return NextIdent;
    }
    return nullptr;
  }

  /// Return a `struct ident_t*` value that represents the ones used in the
  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
  /// return value we create one from scratch. We also do not yet combine
  /// information, e.g., the source locations, see combinedIdentStruct.
  Value *
  getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
                                 Function &F, bool GlobalOnly) {
    bool SingleChoice = true;
    Value *Ident = nullptr;
    auto CombineIdentStruct = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || &F != &Caller)
        return false;
      Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
                                  /* GlobalOnly */ true, SingleChoice);
      return false;
    };
    RFI.foreachUse(SCC, CombineIdentStruct);

    if (!Ident || !SingleChoice) {
      // The IRBuilder uses the insertion block to get to the module, this is
      // unfortunate but we work around it for now.
      if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
        OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
            &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
      // TODO: Use the debug locations of the calls instead.
      Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
      Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
    }
    return Ident;
  }

  /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
  /// \p ReplVal if given.
  bool deduplicateRuntimeCalls(Function &F,
                               OMPInformationCache::RuntimeFunctionInfo &RFI,
                               Value *ReplVal = nullptr) {
    auto *UV = RFI.getUseVector(F);
    if (!UV || UV->size() + (ReplVal != nullptr) < 2)
      return false;

    LLVM_DEBUG(
        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value\n" : "\n") << "\n");

    assert((!ReplVal || (isa<Argument>(ReplVal) &&
                         cast<Argument>(ReplVal)->getParent() == &F)) &&
           "Unexpected replacement value!");

    // TODO: Use dominance to find a good position instead.
    auto CanBeMoved = [this](CallBase &CB) {
      unsigned NumArgs = CB.getNumArgOperands();
      if (NumArgs == 0)
        return true;
      if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
        return false;
      for (unsigned u = 1; u < NumArgs; ++u)
        if (isa<Instruction>(CB.getArgOperand(u)))
          return false;
      return true;
    };

    if (!ReplVal) {
      for (Use *U : *UV)
        if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
          if (!CanBeMoved(*CI))
            continue;

          auto Remark = [&](OptimizationRemark OR) {
            return OR << "OpenMP runtime call "
                      << ore::NV("OpenMPOptRuntime", RFI.Name)
                      << " moved to beginning of OpenMP region";
          };
          emitRemark<OptimizationRemark>(&F, "OpenMPRuntimeCodeMotion", Remark);

          CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
          ReplVal = CI;
          break;
        }
      if (!ReplVal)
        return false;
    }

    // If we use a call as a replacement value we need to make sure the ident
    // is valid at the new location. For now we just pick a global one, either
    // existing and used by one of the calls, or created from scratch.
    if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
      if (CI->getNumArgOperands() > 0 &&
          CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
        Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
                                                      /* GlobalOnly */ true);
        CI->setArgOperand(0, Ident);
      }
    }

    bool Changed = false;
    auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || CI == ReplVal || &F != &Caller)
        return false;
      assert(CI->getCaller() == &F && "Unexpected call!");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "OpenMP runtime call "
                  << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated";
      };
      emitRemark<OptimizationRemark>(&F, "OpenMPRuntimeDeduplicated", Remark);

      CGUpdater.removeCallSite(*CI);
      CI->replaceAllUsesWith(ReplVal);
      CI->eraseFromParent();
      ++NumOpenMPRuntimeCallsDeduplicated;
      Changed = true;
      return true;
    };
    RFI.foreachUse(SCC, ReplaceAndDeleteCB);

    return Changed;
  }

  /// Collect arguments that represent the global thread id in \p GTIdArgs.
  void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
    // TODO: Below we basically perform a fixpoint iteration with a pessimistic
    //       initialization. We could define an AbstractAttribute instead and
    //       run the Attributor here once it can be run as an SCC pass.

    // Helper to check the argument \p ArgNo at all call sites of \p F for
    // a GTId.
    auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
      if (!F.hasLocalLinkage())
        return false;
      for (Use &U : F.uses()) {
        if (CallInst *CI = getCallIfRegularCall(U)) {
          Value *ArgOp = CI->getArgOperand(ArgNo);
          if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
              getCallIfRegularCall(
                  *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
            continue;
        }
        return false;
      }
      return true;
    };

    // Helper to identify uses of a GTId as GTId arguments.
    auto AddUserArgs = [&](Value &GTId) {
      for (Use &U : GTId.uses())
        if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
          if (CI->isArgOperand(&U))
            if (Function *Callee = CI->getCalledFunction())
              if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
                GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
    };

    // The argument users of __kmpc_global_thread_num calls are GTIds.
    OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];

    GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
      if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
        AddUserArgs(*CI);
      return false;
    });

    // Transitively search for more arguments by looking at the users of the
    // ones we know already. During the search the GTIdArgs vector is extended
    // so we cannot cache the size nor can we use a range based for.
    for (unsigned u = 0; u < GTIdArgs.size(); ++u)
      AddUserArgs(*GTIdArgs[u]);
  }

  /// Kernel (=GPU) optimizations and utility functions
  ///
  ///{{

  /// Check if \p F is a kernel, hence entry point for target offloading.
  bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }

  /// Cache to remember the unique kernel for a function.
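  /// A cached value of nullptr means no unique kernel could be determined.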
  DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;

  /// Find the unique kernel that will execute \p F, if any.
  Kernel getUniqueKernelFor(Function &F);

  /// Find the unique kernel that will execute \p I, if any.
  Kernel getUniqueKernelFor(Instruction &I) {
    return getUniqueKernelFor(*I.getFunction());
  }

  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
  /// the cases we can avoid taking the address of a function.
  bool rewriteDeviceCodeStateMachine();

  ///
  ///}}

  /// Emit a remark generically
  ///
  /// This template function can be used to generically emit a remark. The
  /// RemarkKind should be one of the following:
  ///   - OptimizationRemark to indicate a successful optimization attempt
  ///   - OptimizationRemarkMissed to report a failed optimization attempt
  ///   - OptimizationRemarkAnalysis to provide additional information about an
  ///     optimization attempt
  ///
  /// The remark is built using a callback function provided by the caller that
  /// takes a RemarkKind as input and returns a RemarkKind.
  template <typename RemarkKind, typename RemarkCallBack>
  void emitRemark(Instruction *I, StringRef RemarkName,
                  RemarkCallBack &&RemarkCB) const {
    Function *F = I->getParent()->getParent();
    auto &ORE = OREGetter(F);

    ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
  }

  /// Emit a remark on a function.
  template <typename RemarkKind, typename RemarkCallBack>
  void emitRemark(Function *F, StringRef RemarkName,
                  RemarkCallBack &&RemarkCB) const {
    auto &ORE = OREGetter(F);

    ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
  }

  /// The underlying module.
  Module &M;

  /// The SCC we are operating on.
  SmallVectorImpl<Function *> &SCC;

  /// Callback to update the call graph, the first argument is a removed call,
  /// the second an optional replacement call.
  CallGraphUpdater &CGUpdater;

  /// Callback to get an OptimizationRemarkEmitter from a Function *
  OptimizationRemarkGetter OREGetter;

  /// OpenMP-specific information cache. Also used for Attributor runs.
  OMPInformationCache &OMPInfoCache;

  /// Attributor instance.
  Attributor &A;

  /// Helper function to run Attributor on SCC.
  bool runAttributor() {
    if (SCC.empty())
      return false;

    registerAAs();

    ChangeStatus Changed = A.run();

    LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
                      << " functions, result: " << Changed << ".\n");

    return Changed == ChangeStatus::CHANGED;
  }

  /// Populate the Attributor with abstract attribute opportunities in the
  /// function.
  void registerAAs() {
    if (SCC.empty())
      return;

    // Create CallSite AA for all Getters.
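    // Iterate over all real ICVs; the `- 1` skips the ICV___last sentinel.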
    for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
      auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];

      auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];

      auto CreateAA = [&](Use &U, Function &Caller) {
        CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
        if (!CI)
          return false;

        auto &CB = cast<CallBase>(*CI);

        IRPosition CBPos = IRPosition::callsite_function(CB);
        A.getOrCreateAAFor<AAICVTracker>(CBPos);
        return false;
      };

      GetterRFI.foreachUse(SCC, CreateAA);
    }
    auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
    auto CreateAA = [&](Use &U, Function &F) {
      A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
      return false;
    };
    GlobalizationRFI.foreachUse(SCC, CreateAA);

    // Create an ExecutionDomain AA for every function and a HeapToStack AA for
    // every function if there is a device kernel.
    for (auto *F : SCC) {
      if (!F->isDeclaration())
        A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
      if (!OMPInfoCache.Kernels.empty())
        A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
    }
  }
};

Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
  if (!OMPInfoCache.ModuleSlice.count(&F))
    return nullptr;

  // Use a scope to keep the lifetime of the CachedKernel short.
  {
    Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
    if (CachedKernel)
      return *CachedKernel;

    // TODO: We should use an AA to create an (optimistic and callback
    //       call-aware) call graph. For now we stick to simple patterns that
    //       are less powerful, basically the worst fixpoint.
    if (isKernel(F)) {
      CachedKernel = Kernel(&F);
      return *CachedKernel;
    }

    CachedKernel = nullptr;
    if (!F.hasLocalLinkage()) {

      // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA
               << "[OMP100] Potentially unknown OpenMP target region caller";
      };
      emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);

      return nullptr;
    }
  }

  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
    if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
      // Allow use in equality comparisons.
      if (Cmp->isEquality())
        return getUniqueKernelFor(*Cmp);
      return nullptr;
    }
    if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
      // Allow direct calls.
      if (CB->isCallee(&U))
        return getUniqueKernelFor(*CB);

      OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
          OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
      // Allow the use in __kmpc_parallel_51 calls.
      if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
        return getUniqueKernelFor(*CB);
      return nullptr;
    }
    // Disallow every other use.
    return nullptr;
  };

  // TODO: In the future we want to track more than just a unique kernel.
  SmallPtrSet<Kernel, 2> PotentialKernels;
  OMPInformationCache::foreachUse(F, [&](const Use &U) {
    PotentialKernels.insert(GetUniqueKernelForUse(U));
  });

  Kernel K = nullptr;
  if (PotentialKernels.size() == 1)
    K = *PotentialKernels.begin();

  // Cache the result.
1703 UniqueKernelMap[&F] = K; 1704 1705 return K; 1706 } 1707 1708 bool OpenMPOpt::rewriteDeviceCodeStateMachine() { 1709 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI = 1710 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51]; 1711 1712 bool Changed = false; 1713 if (!KernelParallelRFI) 1714 return Changed; 1715 1716 for (Function *F : SCC) { 1717 1718 // Check if the function is used in a __kmpc_parallel_51 call at 1719 // all. 1720 bool UnknownUse = false; 1721 bool KernelParallelUse = false; 1722 unsigned NumDirectCalls = 0; 1723 1724 SmallVector<Use *, 2> ToBeReplacedStateMachineUses; 1725 OMPInformationCache::foreachUse(*F, [&](Use &U) { 1726 if (auto *CB = dyn_cast<CallBase>(U.getUser())) 1727 if (CB->isCallee(&U)) { 1728 ++NumDirectCalls; 1729 return; 1730 } 1731 1732 if (isa<ICmpInst>(U.getUser())) { 1733 ToBeReplacedStateMachineUses.push_back(&U); 1734 return; 1735 } 1736 1737 // Find wrapper functions that represent parallel kernels. 1738 CallInst *CI = 1739 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI); 1740 const unsigned int WrapperFunctionArgNo = 6; 1741 if (!KernelParallelUse && CI && 1742 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) { 1743 KernelParallelUse = true; 1744 ToBeReplacedStateMachineUses.push_back(&U); 1745 return; 1746 } 1747 UnknownUse = true; 1748 }); 1749 1750 // Do not emit a remark if we haven't seen a __kmpc_parallel_51 1751 // use. 1752 if (!KernelParallelUse) 1753 continue; 1754 1755 { 1756 auto Remark = [&](OptimizationRemarkAnalysis ORA) { 1757 return ORA << "Found a parallel region that is called in a target " 1758 "region but not part of a combined target construct nor " 1759 "nested inside a target construct without intermediate " 1760 "code. This can lead to excessive register usage for " 1761 "unrelated target regions in the same translation unit " 1762 "due to spurious call edges assumed by ptxas."; 1763 }; 1764 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPParallelRegionInNonSPMD", 1765 Remark); 1766 } 1767 1768 // If this ever hits, we should investigate. 1769 // TODO: Checking the number of uses is not a necessary restriction and 1770 // should be lifted. 1771 if (UnknownUse || NumDirectCalls != 1 || 1772 ToBeReplacedStateMachineUses.size() != 2) { 1773 { 1774 auto Remark = [&](OptimizationRemarkAnalysis ORA) { 1775 return ORA << "Parallel region is used in " 1776 << (UnknownUse ? "unknown" : "unexpected") 1777 << " ways; will not attempt to rewrite the state machine."; 1778 }; 1779 emitRemark<OptimizationRemarkAnalysis>( 1780 F, "OpenMPParallelRegionInNonSPMD", Remark); 1781 } 1782 continue; 1783 } 1784 1785 // Even if we have __kmpc_parallel_51 calls, we (for now) give 1786 // up if the function is not called from a unique kernel. 1787 Kernel K = getUniqueKernelFor(*F); 1788 if (!K) { 1789 { 1790 auto Remark = [&](OptimizationRemarkAnalysis ORA) { 1791 return ORA << "Parallel region is not known to be called from a " 1792 "unique single target region, maybe the surrounding " 1793 "function has external linkage; will not attempt to " 1794 "rewrite the state machine use."; 1795 }; 1796 emitRemark<OptimizationRemarkAnalysis>( 1797 F, "OpenMPParallelRegionInMultipleKernels", Remark); 1798 } 1799 continue; 1800 } 1801 1802 // We now know F is a parallel body function called only from the kernel K. 1803 // We also identified the state machine uses in which we replace the 1804 // function pointer by a new global symbol for identification purposes. This 1805 // ensures only direct calls to the function are left.
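// Roughly, the state machine in the kernel changes from
//
//   __kmpc_parallel_51(..., /* WrapperFn */ &F, ...);
//   ...
//   if (WorkFn == &F) F(...);   // direct call, kept as-is
//
// to the same code with &F replaced by the private global F.ID created
// below, so F's address no longer escapes into the state machine and ptxas
// only sees the direct call. (Sketch only; the real IR uses a bitcast of
// the ".ID" global, not the function pointer itself.)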
1806 1807 { 1808 auto RemarkParallelRegion = [&](OptimizationRemarkAnalysis ORA) { 1809 return ORA << "Specialize parallel region that is only reached from a " 1810 "single target region to avoid spurious call edges and " 1811 "excessive register usage in other target regions. " 1812 "(parallel region ID: " 1813 << ore::NV("OpenMPParallelRegion", F->getName()) 1814 << ", kernel ID: " 1815 << ore::NV("OpenMPTargetRegion", K->getName()) << ")"; 1816 }; 1817 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPParallelRegionInNonSPMD", 1818 RemarkParallelRegion); 1819 auto RemarkKernel = [&](OptimizationRemarkAnalysis ORA) { 1820 return ORA << "Target region containing the parallel region that is " 1821 "specialized. (parallel region ID: " 1822 << ore::NV("OpenMPParallelRegion", F->getName()) 1823 << ", kernel ID: " 1824 << ore::NV("OpenMPTargetRegion", K->getName()) << ")"; 1825 }; 1826 emitRemark<OptimizationRemarkAnalysis>(K, "OpenMPParallelRegionInNonSPMD", 1827 RemarkKernel); 1828 } 1829 1830 Module &M = *F->getParent(); 1831 Type *Int8Ty = Type::getInt8Ty(M.getContext()); 1832 1833 auto *ID = new GlobalVariable( 1834 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage, 1835 UndefValue::get(Int8Ty), F->getName() + ".ID"); 1836 1837 for (Use *U : ToBeReplacedStateMachineUses) 1838 U->set(ConstantExpr::getBitCast(ID, U->get()->getType())); 1839 1840 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine; 1841 1842 Changed = true; 1843 } 1844 1845 return Changed; 1846 } 1847 1848 /// Abstract Attribute for tracking ICV values. 1849 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> { 1850 using Base = StateWrapper<BooleanState, AbstractAttribute>; 1851 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {} 1852 1853 void initialize(Attributor &A) override { 1854 Function *F = getAnchorScope(); 1855 if (!F || !A.isFunctionIPOAmendable(*F)) 1856 indicatePessimisticFixpoint(); 1857 } 1858 1859 /// Returns true if value is assumed to be tracked. 1860 bool isAssumedTracked() const { return getAssumed(); } 1861 1862 /// Returns true if value is known to be tracked. 1863 bool isKnownTracked() const { return getAssumed(); } 1864 1865 /// Create an abstract attribute view for the position \p IRP. 1866 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A); 1867 1868 /// Return the value with which \p I can be replaced for specific \p ICV. 1869 virtual Optional<Value *> getReplacementValue(InternalControlVar ICV, 1870 const Instruction *I, 1871 Attributor &A) const { 1872 return None; 1873 } 1874 1875 /// Return an assumed unique ICV value if a single candidate is found. If 1876 /// there cannot be one, return a nullptr. If it is not clear yet, return the 1877 /// Optional::NoneType. 1878 virtual Optional<Value *> 1879 getUniqueReplacementValue(InternalControlVar ICV) const = 0; 1880 1881 // Currently only nthreads is being tracked. 1882 // This array will only grow over time.
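// For nthreads the setter/getter pair is omp_set_num_threads /
// omp_get_max_threads (as recorded in the ICV table that fills
// OMPInfoCache.ICVs); further ICVs can be appended here once their
// setters and getters are modeled.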
1883 InternalControlVar TrackableICVs[1] = {ICV_nthreads}; 1884 1885 /// See AbstractAttribute::getName() 1886 const std::string getName() const override { return "AAICVTracker"; } 1887 1888 /// See AbstractAttribute::getIdAddr() 1889 const char *getIdAddr() const override { return &ID; } 1890 1891 /// This function should return true if the type of the \p AA is AAICVTracker 1892 static bool classof(const AbstractAttribute *AA) { 1893 return (AA->getIdAddr() == &ID); 1894 } 1895 1896 static const char ID; 1897 }; 1898 1899 struct AAICVTrackerFunction : public AAICVTracker { 1900 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A) 1901 : AAICVTracker(IRP, A) {} 1902 1903 // FIXME: come up with better string. 1904 const std::string getAsStr() const override { return "ICVTrackerFunction"; } 1905 1906 // FIXME: come up with some stats. 1907 void trackStatistics() const override {} 1908 1909 /// We don't manifest anything for this AA. 1910 ChangeStatus manifest(Attributor &A) override { 1911 return ChangeStatus::UNCHANGED; 1912 } 1913 1914 // Map of ICV to their values at specific program point. 1915 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar, 1916 InternalControlVar::ICV___last> 1917 ICVReplacementValuesMap; 1918 1919 ChangeStatus updateImpl(Attributor &A) override { 1920 ChangeStatus HasChanged = ChangeStatus::UNCHANGED; 1921 1922 Function *F = getAnchorScope(); 1923 1924 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 1925 1926 for (InternalControlVar ICV : TrackableICVs) { 1927 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter]; 1928 1929 auto &ValuesMap = ICVReplacementValuesMap[ICV]; 1930 auto TrackValues = [&](Use &U, Function &) { 1931 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U); 1932 if (!CI) 1933 return false; 1934 1935 // FIXME: handle setters with more than one argument. 1936 /// Track new value. 1937 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second) 1938 HasChanged = ChangeStatus::CHANGED; 1939 1940 return false; 1941 }; 1942 1943 auto CallCheck = [&](Instruction &I) { 1944 Optional<Value *> ReplVal = getValueForCall(A, &I, ICV); 1945 if (ReplVal.hasValue() && 1946 ValuesMap.insert(std::make_pair(&I, *ReplVal)).second) 1947 HasChanged = ChangeStatus::CHANGED; 1948 1949 return true; 1950 }; 1951 1952 // Track all changes of an ICV. 1953 SetterRFI.foreachUse(TrackValues, F); 1954 1955 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call}, 1956 /* CheckBBLivenessOnly */ true); 1957 1958 /// TODO: Figure out a way to avoid adding an entry to 1959 /// ICVReplacementValuesMap. 1960 Instruction *Entry = &F->getEntryBlock().front(); 1961 if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry)) 1962 ValuesMap.insert(std::make_pair(Entry, nullptr)); 1963 } 1964 1965 return HasChanged; 1966 } 1967 1968 /// Helper to check if \p I is a call and get the value for it if it is 1969 /// unique.
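/// The returned Optional encodes three states (see the cases below): None
/// means \p I is not known to change the ICV, nullptr means it changes the
/// ICV to an unknown value, and a concrete Value* is the value the ICV is
/// assumed to hold after \p I.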
1970 Optional<Value *> getValueForCall(Attributor &A, const Instruction *I, 1971 InternalControlVar &ICV) const { 1972 1973 const auto *CB = dyn_cast<CallBase>(I); 1974 if (!CB || CB->hasFnAttr("no_openmp") || 1975 CB->hasFnAttr("no_openmp_routines")) 1976 return None; 1977 1978 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 1979 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter]; 1980 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter]; 1981 Function *CalledFunction = CB->getCalledFunction(); 1982 1983 // Indirect call, assume ICV changes. 1984 if (CalledFunction == nullptr) 1985 return nullptr; 1986 if (CalledFunction == GetterRFI.Declaration) 1987 return None; 1988 if (CalledFunction == SetterRFI.Declaration) { 1989 if (ICVReplacementValuesMap[ICV].count(I)) 1990 return ICVReplacementValuesMap[ICV].lookup(I); 1991 1992 return nullptr; 1993 } 1994 1995 // Since we don't know, assume it changes the ICV. 1996 if (CalledFunction->isDeclaration()) 1997 return nullptr; 1998 1999 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2000 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED); 2001 2002 if (ICVTrackingAA.isAssumedTracked()) 2003 return ICVTrackingAA.getUniqueReplacementValue(ICV); 2004 2005 // If we don't know, assume it changes. 2006 return nullptr; 2007 } 2008 2009 // We don't check unique value for a function, so return None. 2010 Optional<Value *> 2011 getUniqueReplacementValue(InternalControlVar ICV) const override { 2012 return None; 2013 } 2014 2015 /// Return the value with which \p I can be replaced for specific \p ICV. 2016 Optional<Value *> getReplacementValue(InternalControlVar ICV, 2017 const Instruction *I, 2018 Attributor &A) const override { 2019 const auto &ValuesMap = ICVReplacementValuesMap[ICV]; 2020 if (ValuesMap.count(I)) 2021 return ValuesMap.lookup(I); 2022 2023 SmallVector<const Instruction *, 16> Worklist; 2024 SmallPtrSet<const Instruction *, 16> Visited; 2025 Worklist.push_back(I); 2026 2027 Optional<Value *> ReplVal; 2028 2029 while (!Worklist.empty()) { 2030 const Instruction *CurrInst = Worklist.pop_back_val(); 2031 if (!Visited.insert(CurrInst).second) 2032 continue; 2033 2034 const BasicBlock *CurrBB = CurrInst->getParent(); 2035 2036 // Go up and look for all potential setters/calls that might change the 2037 // ICV. 2038 while ((CurrInst = CurrInst->getPrevNode())) { 2039 if (ValuesMap.count(CurrInst)) { 2040 Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst); 2041 // Unknown value, track new. 2042 if (!ReplVal.hasValue()) { 2043 ReplVal = NewReplVal; 2044 break; 2045 } 2046 2047 // If we found a new value, we can't know the ICV value anymore. 2048 if (NewReplVal.hasValue()) 2049 if (ReplVal != NewReplVal) 2050 return nullptr; 2051 2052 break; 2053 } 2054 2055 Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV); 2056 if (!NewReplVal.hasValue()) 2057 continue; 2058 2059 // Unknown value, track new. 2060 if (!ReplVal.hasValue()) { 2061 ReplVal = NewReplVal; 2062 break; 2063 } 2064 2065 2066 // We found a new value, we can't know the ICV value anymore. 2067 if (ReplVal != NewReplVal) 2068 return nullptr; 2069 } 2070 2071 // If we are in the same BB and we have a value, we are done. 2072 if (CurrBB == I->getParent() && ReplVal.hasValue()) 2073 return ReplVal; 2074 2075 // Go through all predecessors and add terminators for analysis.
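// Seeding the worklist with each predecessor's terminator lets the backward
// walk above resume from the end of that predecessor block.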
2076 for (const BasicBlock *Pred : predecessors(CurrBB)) 2077 if (const Instruction *Terminator = Pred->getTerminator()) 2078 Worklist.push_back(Terminator); 2079 } 2080 2081 return ReplVal; 2082 } 2083 }; 2084 2085 struct AAICVTrackerFunctionReturned : AAICVTracker { 2086 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A) 2087 : AAICVTracker(IRP, A) {} 2088 2089 // FIXME: come up with better string. 2090 const std::string getAsStr() const override { 2091 return "ICVTrackerFunctionReturned"; 2092 } 2093 2094 // FIXME: come up with some stats. 2095 void trackStatistics() const override {} 2096 2097 /// We don't manifest anything for this AA. 2098 ChangeStatus manifest(Attributor &A) override { 2099 return ChangeStatus::UNCHANGED; 2100 } 2101 2102 // Map of ICV to their values at specific program point. 2103 EnumeratedArray<Optional<Value *>, InternalControlVar, 2104 InternalControlVar::ICV___last> 2105 ICVReplacementValuesMap; 2106 2107 /// Return the value with which \p I can be replaced for specific \p ICV. 2108 Optional<Value *> 2109 getUniqueReplacementValue(InternalControlVar ICV) const override { 2110 return ICVReplacementValuesMap[ICV]; 2111 } 2112 2113 ChangeStatus updateImpl(Attributor &A) override { 2114 ChangeStatus Changed = ChangeStatus::UNCHANGED; 2115 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2116 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 2117 2118 if (!ICVTrackingAA.isAssumedTracked()) 2119 return indicatePessimisticFixpoint(); 2120 2121 for (InternalControlVar ICV : TrackableICVs) { 2122 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; 2123 Optional<Value *> UniqueICVValue; 2124 2125 auto CheckReturnInst = [&](Instruction &I) { 2126 Optional<Value *> NewReplVal = 2127 ICVTrackingAA.getReplacementValue(ICV, &I, A); 2128 2129 // If we found a second ICV value there is no unique returned value. 2130 if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal) 2131 return false; 2132 2133 UniqueICVValue = NewReplVal; 2134 2135 return true; 2136 }; 2137 2138 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret}, 2139 /* CheckBBLivenessOnly */ true)) 2140 UniqueICVValue = nullptr; 2141 2142 if (UniqueICVValue == ReplVal) 2143 continue; 2144 2145 ReplVal = UniqueICVValue; 2146 Changed = ChangeStatus::CHANGED; 2147 } 2148 2149 return Changed; 2150 } 2151 }; 2152 2153 struct AAICVTrackerCallSite : AAICVTracker { 2154 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A) 2155 : AAICVTracker(IRP, A) {} 2156 2157 void initialize(Attributor &A) override { 2158 Function *F = getAnchorScope(); 2159 if (!F || !A.isFunctionIPOAmendable(*F)) 2160 indicatePessimisticFixpoint(); 2161 2162 // We only initialize this AA for getters, so we need to know which ICV it 2163 // gets. 2164 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2165 for (InternalControlVar ICV : TrackableICVs) { 2166 auto ICVInfo = OMPInfoCache.ICVs[ICV]; 2167 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter]; 2168 if (Getter.Declaration == getAssociatedFunction()) { 2169 AssociatedICV = ICVInfo.Kind; 2170 return; 2171 } 2172 } 2173 2174 /// Unknown ICV. 
2175 indicatePessimisticFixpoint(); 2176 } 2177 2178 ChangeStatus manifest(Attributor &A) override { 2179 if (!ReplVal.hasValue() || !ReplVal.getValue()) 2180 return ChangeStatus::UNCHANGED; 2181 2182 A.changeValueAfterManifest(*getCtxI(), **ReplVal); 2183 A.deleteAfterManifest(*getCtxI()); 2184 2185 return ChangeStatus::CHANGED; 2186 } 2187 2188 // FIXME: come up with better string. 2189 const std::string getAsStr() const override { return "ICVTrackerCallSite"; } 2190 2191 // FIXME: come up with some stats. 2192 void trackStatistics() const override {} 2193 2194 InternalControlVar AssociatedICV; 2195 Optional<Value *> ReplVal; 2196 2197 ChangeStatus updateImpl(Attributor &A) override { 2198 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2199 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 2200 2201 // We don't have any information, so we assume it changes the ICV. 2202 if (!ICVTrackingAA.isAssumedTracked()) 2203 return indicatePessimisticFixpoint(); 2204 2205 Optional<Value *> NewReplVal = 2206 ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A); 2207 2208 if (ReplVal == NewReplVal) 2209 return ChangeStatus::UNCHANGED; 2210 2211 ReplVal = NewReplVal; 2212 return ChangeStatus::CHANGED; 2213 } 2214 2215 // Return the value with which associated value can be replaced for specific 2216 // \p ICV. 2217 Optional<Value *> 2218 getUniqueReplacementValue(InternalControlVar ICV) const override { 2219 return ReplVal; 2220 } 2221 }; 2222 2223 struct AAICVTrackerCallSiteReturned : AAICVTracker { 2224 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A) 2225 : AAICVTracker(IRP, A) {} 2226 2227 // FIXME: come up with better string. 2228 const std::string getAsStr() const override { 2229 return "ICVTrackerCallSiteReturned"; 2230 } 2231 2232 // FIXME: come up with some stats. 2233 void trackStatistics() const override {} 2234 2235 /// We don't manifest anything for this AA. 2236 ChangeStatus manifest(Attributor &A) override { 2237 return ChangeStatus::UNCHANGED; 2238 } 2239 2240 // Map of ICV to their values at specific program point. 2241 EnumeratedArray<Optional<Value *>, InternalControlVar, 2242 InternalControlVar::ICV___last> 2243 ICVReplacementValuesMap; 2244 2245 /// Return the value with which associated value can be replaced for specific 2246 /// \p ICV. 2247 Optional<Value *> 2248 getUniqueReplacementValue(InternalControlVar ICV) const override { 2249 return ICVReplacementValuesMap[ICV]; 2250 } 2251 2252 ChangeStatus updateImpl(Attributor &A) override { 2253 ChangeStatus Changed = ChangeStatus::UNCHANGED; 2254 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2255 *this, IRPosition::returned(*getAssociatedFunction()), 2256 DepClassTy::REQUIRED); 2257 2258 // We don't have any information, so we assume it changes the ICV. 
2259 if (!ICVTrackingAA.isAssumedTracked()) 2260 return indicatePessimisticFixpoint(); 2261 2262 for (InternalControlVar ICV : TrackableICVs) { 2263 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; 2264 Optional<Value *> NewReplVal = 2265 ICVTrackingAA.getUniqueReplacementValue(ICV); 2266 2267 if (ReplVal == NewReplVal) 2268 continue; 2269 2270 ReplVal = NewReplVal; 2271 Changed = ChangeStatus::CHANGED; 2272 } 2273 return Changed; 2274 } 2275 }; 2276 2277 struct AAExecutionDomainFunction : public AAExecutionDomain { 2278 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A) 2279 : AAExecutionDomain(IRP, A) {} 2280 2281 const std::string getAsStr() const override { 2282 return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) + 2283 "/" + std::to_string(NumBBs) + " BBs thread 0 only."; 2284 } 2285 2286 /// See AbstractAttribute::trackStatistics(). 2287 void trackStatistics() const override {} 2288 2289 void initialize(Attributor &A) override { 2290 Function *F = getAnchorScope(); 2291 for (const auto &BB : *F) 2292 SingleThreadedBBs.insert(&BB); 2293 NumBBs = SingleThreadedBBs.size(); 2294 } 2295 2296 ChangeStatus manifest(Attributor &A) override { 2297 LLVM_DEBUG({ 2298 for (const BasicBlock *BB : SingleThreadedBBs) 2299 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " " 2300 << BB->getName() << " is executed by a single thread.\n"; 2301 }); 2302 return ChangeStatus::UNCHANGED; 2303 } 2304 2305 ChangeStatus updateImpl(Attributor &A) override; 2306 2307 /// Check if an instruction is executed by a single thread. 2308 bool isExecutedByInitialThreadOnly(const Instruction &I) const override { 2309 return isExecutedByInitialThreadOnly(*I.getParent()); 2310 } 2311 2312 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override { 2313 return SingleThreadedBBs.contains(&BB); 2314 } 2315 2316 /// Set of basic blocks that are executed by a single thread. 2317 DenseSet<const BasicBlock *> SingleThreadedBBs; 2318 2319 /// Total number of basic blocks in this function. 2320 long unsigned NumBBs; 2321 }; 2322 2323 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { 2324 Function *F = getAnchorScope(); 2325 ReversePostOrderTraversal<Function *> RPOT(F); 2326 auto NumSingleThreadedBBs = SingleThreadedBBs.size(); 2327 2328 bool AllCallSitesKnown; 2329 auto PredForCallSite = [&](AbstractCallSite ACS) { 2330 const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>( 2331 *this, IRPosition::function(*ACS.getInstruction()->getFunction()), 2332 DepClassTy::REQUIRED); 2333 return ExecutionDomainAA.isExecutedByInitialThreadOnly( 2334 *ACS.getInstruction()); 2335 }; 2336 2337 if (!A.checkForAllCallSites(PredForCallSite, *this, 2338 /* RequiresAllCallSites */ true, 2339 AllCallSitesKnown)) 2340 SingleThreadedBBs.erase(&F->getEntryBlock()); 2341 2342 // Check if the edge into the successor block compares the result of a 2343 // thread-id call to a constant zero. 2344 // TODO: Use AAValueSimplify to simplify and propagate constants. 2345 // TODO: Check more than a single use for thread IDs. 2346 auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) { 2347 if (!Edge || !Edge->isConditional()) 2348 return false; 2349 if (Edge->getSuccessor(0) != SuccessorBB) 2350 return false; 2351 2352 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition()); 2353 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality()) 2354 return false; 2355 2356 // Temporarily match the pattern generated by clang for teams regions.
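// In pseudo C the matched condition is
//
//   thread_id.x == (block_size.x - 1) & ~(warp_size - 1)
//
// i.e. the check clang emits for the first lane of the last warp, which acts
// as the main thread of a generic-mode target region.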
2357 // TODO: Remove this once the new runtime is in place. 2358 ConstantInt *One, *NegOne; 2359 CmpInst::Predicate Pred; 2360 auto &&m_ThreadID = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_tid_x>(); 2361 auto &&m_WarpSize = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_warpsize>(); 2362 auto &&m_BlockSize = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_ntid_x>(); 2363 if (match(Cmp, m_Cmp(Pred, m_ThreadID, 2364 m_And(m_Sub(m_BlockSize, m_ConstantInt(One)), 2365 m_Xor(m_Sub(m_WarpSize, m_ConstantInt(One)), 2366 m_ConstantInt(NegOne)))))) 2367 if (One->isOne() && NegOne->isMinusOne() && 2368 Pred == CmpInst::Predicate::ICMP_EQ) 2369 return true; 2370 2371 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1)); 2372 if (!C || !C->isZero()) 2373 return false; 2374 2375 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0))) 2376 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x) 2377 return true; 2378 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0))) 2379 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x) 2380 return true; 2381 2382 return false; 2383 }; 2384 2385 // Merge all the predecessor states into the current basic block. A basic 2386 // block is executed by a single thread if all of its predecessors are. 2387 auto MergePredecessorStates = [&](BasicBlock *BB) { 2388 if (pred_begin(BB) == pred_end(BB)) 2389 return SingleThreadedBBs.contains(BB); 2390 2391 bool IsInitialThread = true; 2392 for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB); 2393 PredBB != PredEndBB; ++PredBB) { 2394 if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()), 2395 BB)) 2396 IsInitialThread &= SingleThreadedBBs.contains(*PredBB); 2397 } 2398 2399 return IsInitialThread; 2400 }; 2401 2402 for (auto *BB : RPOT) { 2403 if (!MergePredecessorStates(BB)) 2404 SingleThreadedBBs.erase(BB); 2405 } 2406 2407 return (NumSingleThreadedBBs == SingleThreadedBBs.size()) 2408 ? ChangeStatus::UNCHANGED 2409 : ChangeStatus::CHANGED; 2410 } 2411 2412 /// Try to replace memory allocation calls called by a single thread with a 2413 /// static buffer of shared memory. 2414 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> { 2415 using Base = StateWrapper<BooleanState, AbstractAttribute>; 2416 AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {} 2417 2418 /// Create an abstract attribute view for the position \p IRP. 2419 static AAHeapToShared &createForPosition(const IRPosition &IRP, 2420 Attributor &A); 2421 2422 /// See AbstractAttribute::getName(). 2423 const std::string getName() const override { return "AAHeapToShared"; } 2424 2425 /// See AbstractAttribute::getIdAddr(). 2426 const char *getIdAddr() const override { return &ID; } 2427 2428 /// This function should return true if the type of the \p AA is 2429 /// AAHeapToShared. 2430 static bool classof(const AbstractAttribute *AA) { 2431 return (AA->getIdAddr() == &ID); 2432 } 2433 2434 /// Unique ID (due to the unique address) 2435 static const char ID; 2436 }; 2437 2438 struct AAHeapToSharedFunction : public AAHeapToShared { 2439 AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A) 2440 : AAHeapToShared(IRP, A) {} 2441 2442 const std::string getAsStr() const override { 2443 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) + 2444 " malloc calls eligible."; 2445 } 2446 2447 /// See AbstractAttribute::trackStatistics(). 
2448 void trackStatistics() const override {} 2449 2450 void initialize(Attributor &A) override { 2451 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2452 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; 2453 2454 for (User *U : RFI.Declaration->users()) 2455 if (CallBase *CB = dyn_cast<CallBase>(U)) 2456 MallocCalls.insert(CB); 2457 } 2458 2459 ChangeStatus manifest(Attributor &A) override { 2460 if (MallocCalls.empty()) 2461 return ChangeStatus::UNCHANGED; 2462 2463 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2464 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared]; 2465 2466 Function *F = getAnchorScope(); 2467 auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this, 2468 DepClassTy::OPTIONAL); 2469 2470 ChangeStatus Changed = ChangeStatus::UNCHANGED; 2471 for (CallBase *CB : MallocCalls) { 2472 // Skip replacing this if HeapToStack has already claimed it. 2473 if (HS && HS->isKnownHeapToStack(*CB)) 2474 continue; 2475 2476 // Find the unique free call to remove it. 2477 SmallVector<CallBase *, 4> FreeCalls; 2478 for (auto *U : CB->users()) { 2479 CallBase *C = dyn_cast<CallBase>(U); 2480 if (C && C->getCalledFunction() == FreeCall.Declaration) 2481 FreeCalls.push_back(C); 2482 } 2483 if (FreeCalls.size() != 1) 2484 continue; 2485 2486 ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0)); 2487 2488 LLVM_DEBUG(dbgs() << TAG << "Replace globalization call in " 2489 << CB->getCaller()->getName() << " with " 2490 << AllocSize->getZExtValue() 2491 << " bytes of shared memory\n"); 2492 2493 // Create a new shared memory buffer of the same size as the allocation 2494 // and replace all the uses of the original allocation with it. 2495 Module *M = CB->getModule(); 2496 Type *Int8Ty = Type::getInt8Ty(M->getContext()); 2497 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue()); 2498 auto *SharedMem = new GlobalVariable( 2499 *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage, 2500 UndefValue::get(Int8ArrTy), CB->getName(), nullptr, 2501 GlobalValue::NotThreadLocal, 2502 static_cast<unsigned>(AddressSpace::Shared)); 2503 auto *NewBuffer = 2504 ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo()); 2505 2506 SharedMem->setAlignment(MaybeAlign(32)); 2507 2508 A.changeValueAfterManifest(*CB, *NewBuffer); 2509 A.deleteAfterManifest(*CB); 2510 A.deleteAfterManifest(*FreeCalls.front()); 2511 2512 NumBytesMovedToSharedMemory += AllocSize->getZExtValue(); 2513 Changed = ChangeStatus::CHANGED; 2514 } 2515 2516 return Changed; 2517 } 2518 2519 ChangeStatus updateImpl(Attributor &A) override { 2520 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2521 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; 2522 Function *F = getAnchorScope(); 2523 2524 auto NumMallocCalls = MallocCalls.size(); 2525 2526 // Only consider malloc calls executed by a single thread with a constant. 2527 for (User *U : RFI.Declaration->users()) { 2528 const auto &ED = A.getAAFor<AAExecutionDomain>( 2529 *this, IRPosition::function(*F), DepClassTy::REQUIRED); 2530 if (CallBase *CB = dyn_cast<CallBase>(U)) 2531 if (!dyn_cast<ConstantInt>(CB->getArgOperand(0)) || 2532 !ED.isExecutedByInitialThreadOnly(*CB)) 2533 MallocCalls.erase(CB); 2534 } 2535 2536 if (NumMallocCalls != MallocCalls.size()) 2537 return ChangeStatus::CHANGED; 2538 2539 return ChangeStatus::UNCHANGED; 2540 } 2541 2542 /// Collection of all malloc calls in a function. 
2543 SmallPtrSet<CallBase *, 4> MallocCalls; 2544 }; 2545 2546 } // namespace 2547 2548 const char AAICVTracker::ID = 0; 2549 const char AAExecutionDomain::ID = 0; 2550 const char AAHeapToShared::ID = 0; 2551 2552 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP, 2553 Attributor &A) { 2554 AAICVTracker *AA = nullptr; 2555 switch (IRP.getPositionKind()) { 2556 case IRPosition::IRP_INVALID: 2557 case IRPosition::IRP_FLOAT: 2558 case IRPosition::IRP_ARGUMENT: 2559 case IRPosition::IRP_CALL_SITE_ARGUMENT: 2560 llvm_unreachable("ICVTracker can only be created for function position!"); 2561 case IRPosition::IRP_RETURNED: 2562 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A); 2563 break; 2564 case IRPosition::IRP_CALL_SITE_RETURNED: 2565 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A); 2566 break; 2567 case IRPosition::IRP_CALL_SITE: 2568 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A); 2569 break; 2570 case IRPosition::IRP_FUNCTION: 2571 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A); 2572 break; 2573 } 2574 2575 return *AA; 2576 } 2577 2578 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP, 2579 Attributor &A) { 2580 AAExecutionDomainFunction *AA = nullptr; 2581 switch (IRP.getPositionKind()) { 2582 case IRPosition::IRP_INVALID: 2583 case IRPosition::IRP_FLOAT: 2584 case IRPosition::IRP_ARGUMENT: 2585 case IRPosition::IRP_CALL_SITE_ARGUMENT: 2586 case IRPosition::IRP_RETURNED: 2587 case IRPosition::IRP_CALL_SITE_RETURNED: 2588 case IRPosition::IRP_CALL_SITE: 2589 llvm_unreachable( 2590 "AAExecutionDomain can only be created for function position!"); 2591 case IRPosition::IRP_FUNCTION: 2592 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A); 2593 break; 2594 } 2595 2596 return *AA; 2597 } 2598 2599 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP, 2600 Attributor &A) { 2601 AAHeapToSharedFunction *AA = nullptr; 2602 switch (IRP.getPositionKind()) { 2603 case IRPosition::IRP_INVALID: 2604 case IRPosition::IRP_FLOAT: 2605 case IRPosition::IRP_ARGUMENT: 2606 case IRPosition::IRP_CALL_SITE_ARGUMENT: 2607 case IRPosition::IRP_RETURNED: 2608 case IRPosition::IRP_CALL_SITE_RETURNED: 2609 case IRPosition::IRP_CALL_SITE: 2610 llvm_unreachable( 2611 "AAHeapToShared can only be created for function position!"); 2612 case IRPosition::IRP_FUNCTION: 2613 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A); 2614 break; 2615 } 2616 2617 return *AA; 2618 } 2619 2620 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { 2621 if (!containsOpenMP(M, OMPInModule)) 2622 return PreservedAnalyses::all(); 2623 2624 if (DisableOpenMPOptimizations) 2625 return PreservedAnalyses::all(); 2626 2627 // Create internal copies of each function if this is a kernel Module. 2628 DenseSet<const Function *> InternalizedFuncs; 2629 if (!OMPInModule.getKernels().empty()) 2630 for (Function &F : M) 2631 if (!F.isDeclaration() && !OMPInModule.getKernels().contains(&F)) 2632 if (Attributor::internalizeFunction(F, /* Force */ true)) 2633 InternalizedFuncs.insert(&F); 2634 2635 // Look at every function definition in the Module that wasn't internalized. 
2636 SmallVector<Function *, 16> SCC; 2637 for (Function &F : M) 2638 if (!F.isDeclaration() && !InternalizedFuncs.contains(&F)) 2639 SCC.push_back(&F); 2640 2641 if (SCC.empty()) 2642 return PreservedAnalyses::all(); 2643 2644 FunctionAnalysisManager &FAM = 2645 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 2646 2647 AnalysisGetter AG(FAM); 2648 2649 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & { 2650 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 2651 }; 2652 2653 BumpPtrAllocator Allocator; 2654 CallGraphUpdater CGUpdater; 2655 2656 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 2657 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, 2658 OMPInModule.getKernels()); 2659 2660 Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false); 2661 2662 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 2663 bool Changed = OMPOpt.run(true); 2664 if (Changed) 2665 return PreservedAnalyses::none(); 2666 2667 return PreservedAnalyses::all(); 2668 } 2669 2670 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, 2671 CGSCCAnalysisManager &AM, 2672 LazyCallGraph &CG, 2673 CGSCCUpdateResult &UR) { 2674 if (!containsOpenMP(*C.begin()->getFunction().getParent(), OMPInModule)) 2675 return PreservedAnalyses::all(); 2676 2677 if (DisableOpenMPOptimizations) 2678 return PreservedAnalyses::all(); 2679 2680 SmallVector<Function *, 16> SCC; 2681 // If there are kernels in the module, we have to run on all SCC's. 2682 bool SCCIsInteresting = !OMPInModule.getKernels().empty(); 2683 for (LazyCallGraph::Node &N : C) { 2684 Function *Fn = &N.getFunction(); 2685 SCC.push_back(Fn); 2686 2687 // Do we already know that the SCC contains kernels, 2688 // or that OpenMP functions are called from this SCC? 2689 if (SCCIsInteresting) 2690 continue; 2691 // If not, let's check that. 2692 SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn); 2693 } 2694 2695 if (!SCCIsInteresting || SCC.empty()) 2696 return PreservedAnalyses::all(); 2697 2698 FunctionAnalysisManager &FAM = 2699 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); 2700 2701 AnalysisGetter AG(FAM); 2702 2703 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & { 2704 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 2705 }; 2706 2707 BumpPtrAllocator Allocator; 2708 CallGraphUpdater CGUpdater; 2709 CGUpdater.initialize(CG, C, AM, UR); 2710 2711 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 2712 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator, 2713 /*CGSCC*/ Functions, OMPInModule.getKernels()); 2714 2715 Attributor A(Functions, InfoCache, CGUpdater, nullptr, false); 2716 2717 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 2718 bool Changed = OMPOpt.run(false); 2719 if (Changed) 2720 return PreservedAnalyses::none(); 2721 2722 return PreservedAnalyses::all(); 2723 } 2724 2725 namespace { 2726 2727 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass { 2728 CallGraphUpdater CGUpdater; 2729 OpenMPInModule OMPInModule; 2730 static char ID; 2731 2732 OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) { 2733 initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry()); 2734 } 2735 2736 void getAnalysisUsage(AnalysisUsage &AU) const override { 2737 CallGraphSCCPass::getAnalysisUsage(AU); 2738 } 2739 2740 bool doInitialization(CallGraph &CG) override { 2741 // Disable the pass if there is no OpenMP (runtime call) in the module. 
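// containsOpenMP caches its result in OMPInModule (see isKnown()), so the
// repeated query at the start of runOnSCC below is cheap.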
2742 containsOpenMP(CG.getModule(), OMPInModule); 2743 return false; 2744 } 2745 2746 bool runOnSCC(CallGraphSCC &CGSCC) override { 2747 if (!containsOpenMP(CGSCC.getCallGraph().getModule(), OMPInModule)) 2748 return false; 2749 if (DisableOpenMPOptimizations || skipSCC(CGSCC)) 2750 return false; 2751 2752 SmallVector<Function *, 16> SCC; 2753 // If there are kernels in the module, we have to run on all SCC's. 2754 bool SCCIsInteresting = !OMPInModule.getKernels().empty(); 2755 for (CallGraphNode *CGN : CGSCC) { 2756 Function *Fn = CGN->getFunction(); 2757 if (!Fn || Fn->isDeclaration()) 2758 continue; 2759 SCC.push_back(Fn); 2760 2761 // Do we already know that the SCC contains kernels, 2762 // or that OpenMP functions are called from this SCC? 2763 if (SCCIsInteresting) 2764 continue; 2765 // If not, let's check that. 2766 SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn); 2767 } 2768 2769 if (!SCCIsInteresting || SCC.empty()) 2770 return false; 2771 2772 CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); 2773 CGUpdater.initialize(CG, CGSCC); 2774 2775 // Maintain a map of functions to avoid rebuilding the ORE 2776 DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap; 2777 auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & { 2778 std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F]; 2779 if (!ORE) 2780 ORE = std::make_unique<OptimizationRemarkEmitter>(F); 2781 return *ORE; 2782 }; 2783 2784 AnalysisGetter AG; 2785 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 2786 BumpPtrAllocator Allocator; 2787 OMPInformationCache InfoCache( 2788 *(Functions.back()->getParent()), AG, Allocator, 2789 /*CGSCC*/ Functions, OMPInModule.getKernels()); 2790 2791 Attributor A(Functions, InfoCache, CGUpdater, nullptr, false); 2792 2793 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 2794 return OMPOpt.run(false); 2795 } 2796 2797 bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); } 2798 }; 2799 2800 } // end anonymous namespace 2801 2802 void OpenMPInModule::identifyKernels(Module &M) { 2803 2804 NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations"); 2805 if (!MD) 2806 return; 2807 2808 for (auto *Op : MD->operands()) { 2809 if (Op->getNumOperands() < 2) 2810 continue; 2811 MDString *KindID = dyn_cast<MDString>(Op->getOperand(1)); 2812 if (!KindID || KindID->getString() != "kernel") 2813 continue; 2814 2815 Function *KernelFn = 2816 mdconst::dyn_extract_or_null<Function>(Op->getOperand(0)); 2817 if (!KernelFn) 2818 continue; 2819 2820 ++NumOpenMPTargetRegionKernels; 2821 2822 Kernels.insert(KernelFn); 2823 } 2824 } 2825 2826 bool llvm::omp::containsOpenMP(Module &M, OpenMPInModule &OMPInModule) { 2827 if (OMPInModule.isKnown()) 2828 return OMPInModule; 2829 2830 auto RecordFunctionsContainingUsesOf = [&](Function *F) { 2831 for (User *U : F->users()) 2832 if (auto *I = dyn_cast<Instruction>(U)) 2833 OMPInModule.FuncsWithOMPRuntimeCalls.insert(I->getFunction()); 2834 }; 2835 2836 // MSVC doesn't like long if-else chains for some reason and instead just 2837 // issues an error. Work around it.. 2838 do { 2839 #define OMP_RTL(_Enum, _Name, ...) \ 2840 if (Function *F = M.getFunction(_Name)) { \ 2841 RecordFunctionsContainingUsesOf(F); \ 2842 OMPInModule = true; \ 2843 } 2844 #include "llvm/Frontend/OpenMP/OMPKinds.def" 2845 } while (false); 2846 2847 // Identify kernels once. TODO: We should split the OMPInformationCache into a 2848 // module and an SCC part. 
The kernel information, among other things, could 2849 // go into the module part. 2850 if (OMPInModule.isKnown() && OMPInModule) { 2851 OMPInModule.identifyKernels(M); 2852 return true; 2853 } 2854 2855 return OMPInModule = false; 2856 } 2857 2858 char OpenMPOptCGSCCLegacyPass::ID = 0; 2859 2860 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", 2861 "OpenMP specific optimizations", false, false) 2862 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) 2863 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", 2864 "OpenMP specific optimizations", false, false) 2865 2866 Pass *llvm::createOpenMPOptCGSCCLegacyPass() { 2867 return new OpenMPOptCGSCCLegacyPass(); 2868 } 2869