//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpenMP specific optimizations:
//
// - Deduplication of runtime calls, e.g., omp_get_thread_num.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/OpenMPOpt.h"

#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"

using namespace llvm;
using namespace omp;

#define DEBUG_TYPE "openmp-opt"

static cl::opt<bool> DisableOpenMPOptimizations(
    "openmp-opt-disable", cl::ZeroOrMore,
    cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
    cl::init(false));

static cl::opt<bool> EnableParallelRegionMerging(
    "openmp-opt-enable-merging", cl::ZeroOrMore,
    cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
    cl::init(false));

static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
                                    cl::Hidden);
static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
                                        cl::init(false), cl::Hidden);

static cl::opt<bool> HideMemoryTransferLatency(
    "openmp-hide-memory-transfer-latency",
    cl::desc("[WIP] Tries to hide the latency of host to device memory"
             " transfers"),
    cl::Hidden, cl::init(false));

STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");
STATISTIC(NumOpenMPTargetRegionKernels,
          "Number of OpenMP target region entry points (=kernels) identified");
STATISTIC(
    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
STATISTIC(NumOpenMPParallelRegionsMerged,
          "Number of OpenMP parallel regions merged");

#if !defined(NDEBUG)
static constexpr auto TAG = "[" DEBUG_TYPE "]";
#endif

namespace {

struct AAICVTracker;

/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
/// Attributor runs.
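///
/// A rough usage sketch (the cache is constructed by the pass entry points,
/// not shown in this section; RFI lookups of this shape appear throughout
/// OpenMPOpt below, `Callback` is a placeholder name):
///
///   OMPInformationCache InfoCache(M, AG, Allocator, CGSCC, Kernels);
///   auto &ForkRFI = InfoCache.RFIs[OMPRTL___kmpc_fork_call];
///   if (ForkRFI.Declaration)
///     ForkRFI.foreachUse(SCC, Callback); // Callback: bool(Use &, Function &).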
struct OMPInformationCache : public InformationCache {
  OMPInformationCache(Module &M, AnalysisGetter &AG,
                      BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
                      SmallPtrSetImpl<Kernel> &Kernels)
      : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
        Kernels(Kernels) {

    OMPBuilder.initialize();
    initializeRuntimeFunctions();
    initializeInternalControlVars();
  }

  /// Generic information that describes an internal control variable.
  struct InternalControlVarInfo {
    /// The kind, as described by the InternalControlVar enum.
    InternalControlVar Kind;

    /// The name of the ICV.
    StringRef Name;

    /// Environment variable associated with this ICV.
    StringRef EnvVarName;

    /// Initial value kind.
    ICVInitValue InitKind;

    /// Initial value.
    ConstantInt *InitValue;

    /// Setter RTL function associated with this ICV.
    RuntimeFunction Setter;

    /// Getter RTL function associated with this ICV.
    RuntimeFunction Getter;

    /// RTL function corresponding to the override clause of this ICV.
    RuntimeFunction Clause;
  };

  /// Generic information that describes a runtime function.
  struct RuntimeFunctionInfo {

    /// The kind, as described by the RuntimeFunction enum.
    RuntimeFunction Kind;

    /// The name of the function.
    StringRef Name;

    /// Flag to indicate a variadic function.
    bool IsVarArg;

    /// The return type of the function.
    Type *ReturnType;

    /// The argument types of the function.
    SmallVector<Type *, 8> ArgumentTypes;

    /// The declaration if available.
    Function *Declaration = nullptr;

    /// Uses of this runtime function per function containing the use.
    using UseVector = SmallVector<Use *, 16>;

    /// Clear UsesMap for runtime function.
    void clearUsesMap() { UsesMap.clear(); }

    /// Boolean conversion that is true if the runtime function was found.
    operator bool() const { return Declaration; }

    /// Return the vector of uses in function \p F.
    UseVector &getOrCreateUseVector(Function *F) {
      std::shared_ptr<UseVector> &UV = UsesMap[F];
      if (!UV)
        UV = std::make_shared<UseVector>();
      return *UV;
    }

    /// Return the vector of uses in function \p F or `nullptr` if there are
    /// none.
    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();
      return nullptr;
    }

    /// Return how many functions contain uses of this runtime function.
    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    /// Return the number of arguments (or the minimal number for variadic
    /// functions).
    size_t getNumArgs() const { return ArgumentTypes.size(); }

    /// Run the callback \p CB on each use and forget the use if the result is
    /// true. The callback will be fed the function in which the use was
    /// encountered as second argument.
    void foreachUse(SmallVectorImpl<Function *> &SCC,
                    function_ref<bool(Use &, Function &)> CB) {
      for (Function *F : SCC)
        foreachUse(CB, F);
    }

    /// Run the callback \p CB on each use within the function \p F and forget
    /// the use if the result is true.
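    ///
    /// A minimal illustrative callback, shaped like the ones defined later in
    /// this file in OpenMPOpt (e.g. DeleteCallCB); returning true drops the
    /// use from the map:
    ///
    ///   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
    ///     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
    ///     return CI != nullptr; // true => this use is forgotten.
    ///   });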
    void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
      SmallVector<unsigned, 8> ToBeDeleted;
      ToBeDeleted.clear();

      unsigned Idx = 0;
      UseVector &UV = getOrCreateUseVector(F);

      for (Use *U : UV) {
        if (CB(*U, *F))
          ToBeDeleted.push_back(Idx);
        ++Idx;
      }

      // Remove the to-be-deleted indices in reverse order as prior
      // modifications will not modify the smaller indices.
      while (!ToBeDeleted.empty()) {
        unsigned Idx = ToBeDeleted.pop_back_val();
        UV[Idx] = UV.back();
        UV.pop_back();
      }
    }

  private:
    /// Map from functions to all uses of this runtime function contained in
    /// them.
    DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
  };

  /// An OpenMP-IR-Builder instance
  OpenMPIRBuilder OMPBuilder;

  /// Map from runtime function kind to the runtime function description.
  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
                  RuntimeFunction::OMPRTL___last>
      RFIs;

  /// Map from ICV kind to the ICV description.
  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVs;

  /// Helper to initialize all internal control variable information for those
  /// defined in OMPKinds.def.
  void initializeInternalControlVars() {
#define ICV_RT_SET(_Name, RTL)                                                 \
  {                                                                            \
    auto &ICV = ICVs[_Name];                                                   \
    ICV.Setter = RTL;                                                          \
  }
#define ICV_RT_GET(Name, RTL)                                                  \
  {                                                                            \
    auto &ICV = ICVs[Name];                                                    \
    ICV.Getter = RTL;                                                          \
  }
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
  {                                                                            \
    auto &ICV = ICVs[Enum];                                                    \
    ICV.Name = _Name;                                                          \
    ICV.Kind = Enum;                                                           \
    ICV.InitKind = Init;                                                       \
    ICV.EnvVarName = _EnvVarName;                                              \
    switch (ICV.InitKind) {                                                    \
    case ICV_IMPLEMENTATION_DEFINED:                                           \
      ICV.InitValue = nullptr;                                                 \
      break;                                                                   \
    case ICV_ZERO:                                                             \
      ICV.InitValue = ConstantInt::get(                                        \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
      break;                                                                   \
    case ICV_FALSE:                                                            \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
      break;                                                                   \
    case ICV_LAST:                                                             \
      break;                                                                   \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  /// Returns true if the function declaration \p F matches the runtime
  /// function types, that is, return type \p RTFRetType, and argument types
  /// \p RTFArgTypes.
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
                                  SmallVector<Type *, 8> &RTFArgTypes) {
    // TODO: We should output information to the user (under debug output
    //       and via remarks).

    if (!F)
      return false;
    if (F->getReturnType() != RTFRetType)
      return false;
    if (F->arg_size() != RTFArgTypes.size())
      return false;

    auto RTFTyIt = RTFArgTypes.begin();
    for (Argument &Arg : F->args()) {
      if (Arg.getType() != *RTFTyIt)
        return false;

      ++RTFTyIt;
    }

    return true;
  }

  // Helper to collect all uses of the declaration in the UsesMap.
  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
    unsigned NumUses = 0;
    if (!RFI.Declaration)
      return NumUses;
    OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);

    if (CollectStats) {
      NumOpenMPRuntimeFunctionsIdentified += 1;
      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
    }

    // TODO: We directly convert uses into proper calls and unknown uses.
    for (Use &U : RFI.Declaration->uses()) {
      if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
        if (ModuleSlice.count(UserI->getFunction())) {
          RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
          ++NumUses;
        }
      } else {
        RFI.getOrCreateUseVector(nullptr).push_back(&U);
        ++NumUses;
      }
    }
    return NumUses;
  }

  // Helper function to recollect uses of a runtime function.
  void recollectUsesForFunction(RuntimeFunction RTF) {
    auto &RFI = RFIs[RTF];
    RFI.clearUsesMap();
    collectUses(RFI, /*CollectStats*/ false);
  }

  // Helper function to recollect uses of all runtime functions.
  void recollectUses() {
    for (int Idx = 0; Idx < RFIs.size(); ++Idx)
      recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
  }

  /// Helper to initialize all runtime function information for those defined
  /// in OpenMPKinds.def.
  void initializeRuntimeFunctions() {
    Module &M = *((*ModuleSlice.begin())->getParent());

    // Helper macros for handling __VA_ARGS__ in OMP_RTL
#define OMP_TYPE(VarName, ...)                                                 \
  Type *VarName = OMPBuilder.VarName;                                          \
  (void)VarName;

#define OMP_ARRAY_TYPE(VarName, ...)                                           \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
  (void)VarName##Ty;                                                           \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
  (void)VarName##PtrTy;

#define OMP_FUNCTION_TYPE(VarName, ...)                                        \
  FunctionType *VarName = OMPBuilder.VarName;                                  \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_STRUCT_TYPE(VarName, ...)                                          \
  StructType *VarName = OMPBuilder.VarName;                                    \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
  {                                                                            \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
    Function *F = M.getFunction(_Name);                                        \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
      auto &RFI = RFIs[_Enum];                                                 \
      RFI.Kind = _Enum;                                                        \
      RFI.Name = _Name;                                                        \
      RFI.IsVarArg = _IsVarArg;                                                \
      RFI.ReturnType = OMPBuilder._ReturnType;                                 \
      RFI.ArgumentTypes = std::move(ArgsTypes);                                \
      RFI.Declaration = F;                                                     \
      unsigned NumUses = collectUses(RFI);                                     \
      (void)NumUses;                                                           \
      LLVM_DEBUG({                                                             \
        dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
               << " found\n";                                                  \
        if (RFI.Declaration)                                                   \
          dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
                 << RFI.getNumFunctionsWithUses()                              \
                 << " different functions.\n";                                 \
      });                                                                      \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"

    // TODO: We should attach the attributes defined in OMPKinds.def.
  }

  /// Collection of known kernels (\see Kernel) in the module.
  SmallPtrSetImpl<Kernel> &Kernels;
};

/// Used to map the values physically (in the IR) stored in an offload
/// array, to a vector in memory.
struct OffloadArray {
  /// Physical array (in the IR).
  AllocaInst *Array = nullptr;
  /// Mapped values.
  SmallVector<Value *, 8> StoredValues;
  /// Last stores made in the offload array.
  SmallVector<StoreInst *, 8> LastAccesses;

  OffloadArray() = default;

  /// Initializes the OffloadArray with the values stored in \p Array before
  /// instruction \p Before is reached. Returns false if the initialization
  /// fails.
  /// This MUST be used immediately after the construction of the object.
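  ///
  /// Sketch of the intended call pattern (this mirrors getValuesInOffloadArrays
  /// further down; the variable names here are illustrative only):
  ///
  ///   OffloadArray OA;
  ///   if (!OA.initialize(*BasePtrsAlloca, RuntimeCall))
  ///     return false; // Unsupported shape, bail out.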
  bool initialize(AllocaInst &Array, Instruction &Before) {
    if (!Array.getAllocatedType()->isArrayTy())
      return false;

    if (!getValues(Array, Before))
      return false;

    this->Array = &Array;
    return true;
  }

  static const unsigned DeviceIDArgNum = 1;
  static const unsigned BasePtrsArgNum = 3;
  static const unsigned PtrsArgNum = 4;
  static const unsigned SizesArgNum = 5;

private:
  /// Traverses the BasicBlock where \p Array is, collecting the stores made to
  /// \p Array, leaving StoredValues with the values stored before the
  /// instruction \p Before is reached.
  bool getValues(AllocaInst &Array, Instruction &Before) {
    // Initialize container.
    const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
    StoredValues.assign(NumValues, nullptr);
    LastAccesses.assign(NumValues, nullptr);

    // TODO: This assumes the instruction \p Before is in the same
    //       BasicBlock as Array. Make it general, for any control flow graph.
    BasicBlock *BB = Array.getParent();
    if (BB != Before.getParent())
      return false;

    const DataLayout &DL = Array.getModule()->getDataLayout();
    const unsigned int PointerSize = DL.getPointerSize();

    for (Instruction &I : *BB) {
      if (&I == &Before)
        break;

      if (!isa<StoreInst>(&I))
        continue;

      auto *S = cast<StoreInst>(&I);
      int64_t Offset = -1;
      auto *Dst =
          GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
      if (Dst == &Array) {
        int64_t Idx = Offset / PointerSize;
        StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
        LastAccesses[Idx] = S;
      }
    }

    return isFilled();
  }

  /// Returns true if all values in StoredValues and
  /// LastAccesses are not nullptrs.
  bool isFilled() {
    const unsigned NumValues = StoredValues.size();
    for (unsigned I = 0; I < NumValues; ++I) {
      if (!StoredValues[I] || !LastAccesses[I])
        return false;
    }

    return true;
  }
};

struct OpenMPOpt {

  using OptimizationRemarkGetter =
      function_ref<OptimizationRemarkEmitter &(Function *)>;

  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
            OptimizationRemarkGetter OREGetter,
            OMPInformationCache &OMPInfoCache, Attributor &A)
      : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
        OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}

  /// Check if any remarks are enabled for openmp-opt
  bool remarksEnabled() {
    auto &Ctx = M.getContext();
    return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
  }

  /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
  bool run(bool IsModulePass) {
    if (SCC.empty())
      return false;

    bool Changed = false;

    LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
                      << " functions in a slice with "
                      << OMPInfoCache.ModuleSlice.size() << " functions\n");

    if (IsModulePass) {
      Changed |= runAttributor();

      if (remarksEnabled())
        analysisGlobalization();
    } else {
      if (PrintICVValues)
        printICVs();
      if (PrintOpenMPKernels)
        printKernels();

      Changed |= rewriteDeviceCodeStateMachine();

      Changed |= runAttributor();

      // Recollect uses, in case Attributor deleted any.
      OMPInfoCache.recollectUses();

      Changed |= deleteParallelRegions();
      if (HideMemoryTransferLatency)
        Changed |= hideMemTransfersLatency();
      Changed |= deduplicateRuntimeCalls();
      if (EnableParallelRegionMerging) {
        if (mergeParallelRegions()) {
          deduplicateRuntimeCalls();
          Changed = true;
        }
      }
    }

    return Changed;
  }

  /// Print initial ICV values for testing.
  /// FIXME: This should be done from the Attributor once it is added.
  void printICVs() const {
    InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
                                 ICV_proc_bind};

    for (Function *F : OMPInfoCache.ModuleSlice) {
      for (auto ICV : ICVs) {
        auto ICVInfo = OMPInfoCache.ICVs[ICV];
        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
          return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
                     << " Value: "
                     << (ICVInfo.InitValue
                             ? toString(ICVInfo.InitValue->getValue(), 10, true)
                             : "IMPLEMENTATION_DEFINED");
        };

        emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
      }
    }
  }

  /// Print OpenMP GPU kernels for testing.
  void printKernels() const {
    for (Function *F : SCC) {
      if (!OMPInfoCache.Kernels.count(F))
        continue;

      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA << "OpenMP GPU kernel "
                   << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
      };

      emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
    }
  }

  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
  /// given it has to be the callee or a nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(U.getUser());
    if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
        (!RFI || CI->getCalledFunction() == RFI->Declaration))
      return CI;
    return nullptr;
  }

  /// Return the call if \p V is a regular call. If \p RFI is given it has to be
  /// the callee or a nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(&V);
    if (CI && !CI->hasOperandBundles() &&
        (!RFI || CI->getCalledFunction() == RFI->Declaration))
      return CI;
    return nullptr;
  }

private:
  /// Merge parallel regions when it is safe.
  bool mergeParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;
    const unsigned CallbackFirstArgOperand = 3;
    using InsertPointTy = OpenMPIRBuilder::InsertPointTy;

    // Check if there are any __kmpc_fork_call calls to merge.
    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    // Unmergable calls that prevent merging a parallel region.
    OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
    };

    bool Changed = false;
    LoopInfo *LI = nullptr;
    DominatorTree *DT = nullptr;

    SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;

    BasicBlock *StartBB = nullptr, *EndBB = nullptr;
    auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                         BasicBlock &ContinuationIP) {
      BasicBlock *CGStartBB = CodeGenIP.getBlock();
      BasicBlock *CGEndBB =
          SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
      assert(StartBB != nullptr && "StartBB should not be null");
      CGStartBB->getTerminator()->setSuccessor(0, StartBB);
      assert(EndBB != nullptr && "EndBB should not be null");
      EndBB->getTerminator()->setSuccessor(0, CGEndBB);
    };

    auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
                      Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
      ReplacementValue = &Inner;
      return CodeGenIP;
    };

    auto FiniCB = [&](InsertPointTy CodeGenIP) {};

    /// Create a sequential execution region within a merged parallel region,
    /// encapsulated in a master construct with a barrier for synchronization.
    auto CreateSequentialRegion = [&](Function *OuterFn,
                                      BasicBlock *OuterPredBB,
                                      Instruction *SeqStartI,
                                      Instruction *SeqEndI) {
      // Isolate the instructions of the sequential region to a separate
      // block.
      BasicBlock *ParentBB = SeqStartI->getParent();
      BasicBlock *SeqEndBB =
          SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
      BasicBlock *SeqAfterBB =
          SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
      BasicBlock *SeqStartBB =
          SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");

      assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
             "Expected a different CFG");
      const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
      ParentBB->getTerminator()->eraseFromParent();

      auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                           BasicBlock &ContinuationIP) {
        BasicBlock *CGStartBB = CodeGenIP.getBlock();
        BasicBlock *CGEndBB =
            SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
        assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
        CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
        assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
        SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
      };
      auto FiniCB = [&](InsertPointTy CodeGenIP) {};

      // Find outputs from the sequential region to outside users and
      // broadcast their values to them.
      for (Instruction &I : *SeqStartBB) {
        SmallPtrSet<Instruction *, 4> OutsideUsers;
        for (User *Usr : I.users()) {
          Instruction &UsrI = *cast<Instruction>(Usr);
          // Ignore outputs to lifetime intrinsics, code extraction for the
          // merged parallel region will fix them.
          if (UsrI.isLifetimeStartOrEnd())
            continue;

          if (UsrI.getParent() != SeqStartBB)
            OutsideUsers.insert(&UsrI);
        }

        if (OutsideUsers.empty())
          continue;

        // Emit an alloca in the outer region to store the broadcasted
        // value.
        const DataLayout &DL = M.getDataLayout();
        AllocaInst *AllocaI = new AllocaInst(
            I.getType(), DL.getAllocaAddrSpace(), nullptr,
            I.getName() + ".seq.output.alloc", &OuterFn->front().front());

        // Emit a store instruction in the sequential BB to update the
        // value.
        new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());

        // Emit a load instruction and replace the use of the output value
        // with it.
        for (Instruction *UsrI : OutsideUsers) {
          LoadInst *LoadI = new LoadInst(
              I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
          UsrI->replaceUsesOfWith(&I, LoadI);
        }
      }

      OpenMPIRBuilder::LocationDescription Loc(
          InsertPointTy(ParentBB, ParentBB->end()), DL);
      InsertPointTy SeqAfterIP =
          OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);

      OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);

      BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());

      LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
                        << "\n");
    };

    // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
    // contained in BB and only separated by instructions that can be
    // redundantly executed in parallel. The block BB is split before the first
    // call (in MergableCIs) and after the last so the entire region we merge
    // into a single parallel region is contained in a single basic block
    // without any other instructions. We use the OpenMPIRBuilder to outline
    // that block and call the resulting function via __kmpc_fork_call.
    auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
      // TODO: Change the interface to allow single CIs expanded, e.g., to
      //       include an outer loop.
      assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");

      auto Remark = [&](OptimizationRemark OR) {
        OR << "Parallel region at "
           << ore::NV("OpenMPParallelMergeFront",
                      MergableCIs.front()->getDebugLoc())
           << " merged with parallel regions at ";
        for (auto *CI : llvm::drop_begin(MergableCIs)) {
          OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
          if (CI != MergableCIs.back())
            OR << ", ";
        }
        return OR;
      };

      emitRemark<OptimizationRemark>(MergableCIs.front(),
                                     "OpenMPParallelRegionMerging", Remark);

      Function *OriginalFn = BB->getParent();
      LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
                        << " parallel regions in " << OriginalFn->getName()
                        << "\n");

      // Isolate the calls to merge in a separate block.
      EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
      BasicBlock *AfterBB =
          SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
      StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
                           "omp.par.merged");

      assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
      const DebugLoc DL = BB->getTerminator()->getDebugLoc();
      BB->getTerminator()->eraseFromParent();

      // Create sequential regions for sequential instructions that are
      // in-between mergable parallel regions.
      for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
           It != End; ++It) {
        Instruction *ForkCI = *It;
        Instruction *NextForkCI = *(It + 1);

        // Continue if there are no in-between instructions.
        if (ForkCI->getNextNode() == NextForkCI)
          continue;

        CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
                               NextForkCI->getPrevNode());
      }

      OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
                                               DL);
      IRBuilder<>::InsertPoint AllocaIP(
          &OriginalFn->getEntryBlock(),
          OriginalFn->getEntryBlock().getFirstInsertionPt());
      // Create the merged parallel region with default proc binding, to
      // avoid overriding binding settings, and without explicit cancellation.
      InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
          Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
          OMP_PROC_BIND_default, /* IsCancellable */ false);
      BranchInst::Create(AfterBB, AfterIP.getBlock());

      // Perform the actual outlining.
      OMPInfoCache.OMPBuilder.finalize(OriginalFn,
                                       /* AllowExtractorSinking */ true);

      Function *OutlinedFn = MergableCIs.front()->getCaller();

      // Replace the __kmpc_fork_call calls with direct calls to the outlined
      // callbacks.
      SmallVector<Value *, 8> Args;
      for (auto *CI : MergableCIs) {
        Value *Callee =
            CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
        FunctionType *FT =
            cast<FunctionType>(Callee->getType()->getPointerElementType());
        Args.clear();
        Args.push_back(OutlinedFn->getArg(0));
        Args.push_back(OutlinedFn->getArg(1));
        for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
             U < E; ++U)
          Args.push_back(CI->getArgOperand(U));

        CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
        if (CI->getDebugLoc())
          NewCI->setDebugLoc(CI->getDebugLoc());

        // Forward parameter attributes from the callback to the callee.
        for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
             U < E; ++U)
          for (const Attribute &A : CI->getAttributes().getParamAttributes(U))
            NewCI->addParamAttr(
                U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);

        // Emit an explicit barrier to replace the implicit fork-join barrier.
        if (CI != MergableCIs.back()) {
          // TODO: Remove barrier if the merged parallel region includes the
          //       'nowait' clause.
          OMPInfoCache.OMPBuilder.createBarrier(
              InsertPointTy(NewCI->getParent(),
                            NewCI->getNextNode()->getIterator()),
              OMPD_parallel);
        }

        auto Remark = [&](OptimizationRemark OR) {
          return OR << "Parallel region at "
                    << ore::NV("OpenMPParallelMerge", CI->getDebugLoc())
                    << " merged with "
                    << ore::NV("OpenMPParallelMergeFront",
                               MergableCIs.front()->getDebugLoc());
        };
        if (CI != MergableCIs.front())
          emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionMerging",
                                         Remark);

        CI->eraseFromParent();
      }

      assert(OutlinedFn != OriginalFn && "Outlining failed");
      CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
      CGUpdater.reanalyzeFunction(*OriginalFn);

      NumOpenMPParallelRegionsMerged += MergableCIs.size();

      return true;
    };

    // Helper function that identifies sequences of
    // __kmpc_fork_call uses in a basic block.
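    //
    // Illustrative shape of a candidate block (IR details elided/approximate):
    //   call void @__kmpc_fork_call(%struct.ident_t* @loc, i32 0, ... @outlined.0)
    //   %tmp = add i32 %a, %b                       ; mergable in-between code
    //   call void @__kmpc_fork_call(%struct.ident_t* @loc, i32 0, ... @outlined.1)
    // Both calls end up in the same BB2PRMap entry and may later be merged.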
    auto DetectPRsCB = [&](Use &U, Function &F) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      BB2PRMap[CI->getParent()].insert(CI);

      return false;
    };

    BB2PRMap.clear();
    RFI.foreachUse(SCC, DetectPRsCB);
    SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
    // Find mergable parallel regions within a basic block that are
    // safe to merge, that is any in-between instructions can safely
    // execute in parallel after merging.
    // TODO: support merging across basic-blocks.
    for (auto &It : BB2PRMap) {
      auto &CIs = It.getSecond();
      if (CIs.size() < 2)
        continue;

      BasicBlock *BB = It.getFirst();
      SmallVector<CallInst *, 4> MergableCIs;

      /// Returns true if the instruction is mergable, false otherwise.
      /// A terminator instruction is unmergable by definition since merging
      /// works within a BB. Instructions before the mergable region are
      /// mergable if they are not calls to OpenMP runtime functions that may
      /// set different execution parameters for subsequent parallel regions.
      /// Instructions in-between parallel regions are mergable if they are not
      /// calls to any non-intrinsic function since that may call a non-mergable
      /// OpenMP runtime function.
      auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
        // We do not merge across BBs, hence return false (unmergable) if the
        // instruction is a terminator.
        if (I.isTerminator())
          return false;

        if (!isa<CallInst>(&I))
          return true;

        CallInst *CI = cast<CallInst>(&I);
        if (IsBeforeMergableRegion) {
          Function *CalledFunction = CI->getCalledFunction();
          if (!CalledFunction)
            return false;
          // Return false (unmergable) if the call before the parallel
          // region calls an explicit affinity (proc_bind) or number of
          // threads (num_threads) compiler-generated function. Those settings
          // may be incompatible with following parallel regions.
          // TODO: ICV tracking to detect compatibility.
          for (const auto &RFI : UnmergableCallsInfo) {
            if (CalledFunction == RFI.Declaration)
              return false;
          }
        } else {
          // Return false (unmergable) if there is a call instruction
          // in-between parallel regions when it is not an intrinsic. It
          // may call an unmergable OpenMP runtime function in its callpath.
          // TODO: Keep track of possible OpenMP calls in the callpath.
          if (!isa<IntrinsicInst>(CI))
            return false;
        }

        return true;
      };
      // Find maximal number of parallel region CIs that are safe to merge.
      for (auto It = BB->begin(), End = BB->end(); It != End;) {
        Instruction &I = *It;
        ++It;

        if (CIs.count(&I)) {
          MergableCIs.push_back(cast<CallInst>(&I));
          continue;
        }

        // Continue expanding if the instruction is mergable.
        if (IsMergable(I, MergableCIs.empty()))
          continue;

        // Forward the instruction iterator to skip the next parallel region
        // since there is an unmergable instruction which can affect it.
        for (; It != End; ++It) {
          Instruction &SkipI = *It;
          if (CIs.count(&SkipI)) {
            LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
                              << " due to " << I << "\n");
            ++It;
            break;
          }
        }

        // Store mergable regions found.
        if (MergableCIs.size() > 1) {
          MergableCIsVector.push_back(MergableCIs);
          LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
                            << " parallel regions in block " << BB->getName()
                            << " of function " << BB->getParent()->getName()
                            << "\n";);
        }

        MergableCIs.clear();
      }

      if (!MergableCIsVector.empty()) {
        Changed = true;

        for (auto &MergableCIs : MergableCIsVector)
          Merge(MergableCIs, BB);
        MergableCIsVector.clear();
      }
    }

    if (Changed) {
      /// Re-collect uses for fork calls, emitted barrier calls, and
      /// any emitted master/end_master calls.
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
    }

    return Changed;
  }

  /// Try to delete parallel regions if possible.
  bool deleteParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    bool Changed = false;
    auto DeleteCallCB = [&](Use &U, Function &) {
      CallInst *CI = getCallIfRegularCall(U);
      if (!CI)
        return false;
      auto *Fn = dyn_cast<Function>(
          CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
      if (!Fn)
        return false;
      if (!Fn->onlyReadsMemory())
        return false;
      if (!Fn->hasFnAttribute(Attribute::WillReturn))
        return false;

      LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
                        << CI->getCaller()->getName() << "\n");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Parallel region in "
                  << ore::NV("OpenMPParallelDelete", CI->getCaller()->getName())
                  << " deleted";
      };
      emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionDeletion",
                                     Remark);

      CGUpdater.removeCallSite(*CI);
      CI->eraseFromParent();
      Changed = true;
      ++NumOpenMPParallelRegionsDeleted;
      return true;
    };

    RFI.foreachUse(SCC, DeleteCallCB);

    return Changed;
  }

  /// Try to eliminate runtime calls by reusing existing ones.
  bool deduplicateRuntimeCalls() {
    bool Changed = false;

    RuntimeFunction DeduplicableRuntimeCallIDs[] = {
        OMPRTL_omp_get_num_threads,
        OMPRTL_omp_in_parallel,
        OMPRTL_omp_get_cancellation,
        OMPRTL_omp_get_thread_limit,
        OMPRTL_omp_get_supported_active_levels,
        OMPRTL_omp_get_level,
        OMPRTL_omp_get_ancestor_thread_num,
        OMPRTL_omp_get_team_size,
        OMPRTL_omp_get_active_level,
        OMPRTL_omp_in_final,
        OMPRTL_omp_get_proc_bind,
        OMPRTL_omp_get_num_places,
        OMPRTL_omp_get_num_procs,
        OMPRTL_omp_get_place_num,
        OMPRTL_omp_get_partition_num_places,
        OMPRTL_omp_get_partition_place_nums};

    // Global-tid is handled separately.
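    // The global thread id produced by __kmpc_global_thread_num is commonly
    // threaded through call chains as an argument; collectGlobalThreadIdArguments
    // (below) gathers such arguments so calls can be replaced by them.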
    SmallSetVector<Value *, 16> GTIdArgs;
    collectGlobalThreadIdArguments(GTIdArgs);
    LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
                      << " global thread ID arguments\n");

    for (Function *F : SCC) {
      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);

      // __kmpc_global_thread_num is special as we can replace it with an
      // argument in enough cases to make it worth trying.
      Value *GTIdArg = nullptr;
      for (Argument &Arg : F->args())
        if (GTIdArgs.count(&Arg)) {
          GTIdArg = &Arg;
          break;
        }
      Changed |= deduplicateRuntimeCalls(
          *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
    }

    return Changed;
  }

  /// Tries to hide the latency of runtime calls that involve host to
  /// device memory transfers by splitting them into their "issue" and "wait"
  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
  /// moved downwards as much as possible. The "issue" issues the memory
  /// transfer asynchronously, returning a handle. The "wait" waits on the
  /// returned handle for the memory transfer to finish.
  bool hideMemTransfersLatency() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
    bool Changed = false;
    auto SplitMemTransfers = [&](Use &U, Function &Decl) {
      auto *RTCall = getCallIfRegularCall(U, &RFI);
      if (!RTCall)
        return false;

      OffloadArray OffloadArrays[3];
      if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
        return false;

      LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));

      // TODO: Check if the call can be moved upwards.
      bool WasSplit = false;
      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
      if (WaitMovementPoint)
        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);

      Changed |= WasSplit;
      return WasSplit;
    };
    RFI.foreachUse(SCC, SplitMemTransfers);

    return Changed;
  }

  void analysisGlobalization() {
    RuntimeFunction GlobalizationRuntimeIDs[] = {
        OMPRTL___kmpc_data_sharing_coalesced_push_stack,
        OMPRTL___kmpc_data_sharing_push_stack};

    for (const auto GlobalizationCallID : GlobalizationRuntimeIDs) {
      auto &RFI = OMPInfoCache.RFIs[GlobalizationCallID];

      auto CheckGlobalization = [&](Use &U, Function &Decl) {
        if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
          auto Remark = [&](OptimizationRemarkAnalysis ORA) {
            return ORA
                   << "Found thread data sharing on the GPU. "
                   << "Expect degraded performance due to data globalization.";
          };
          emitRemark<OptimizationRemarkAnalysis>(CI, "OpenMPGlobalization",
                                                 Remark);
        }

        return false;
      };

      RFI.foreachUse(SCC, CheckGlobalization);
    }
  }

  /// Maps the values stored in the offload arrays passed as arguments to
  /// \p RuntimeCall into the offload arrays in \p OAs.
  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
                                MutableArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "Need space for three offload arrays!");

    // A runtime call that involves memory offloading looks something like:
    // call void @__tgt_target_data_begin_mapper(arg0, arg1,
    //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
    // ...)
    // So, the idea is to access the allocas that allocate space for these
    // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
    // Therefore:
    // i8** %offload_baseptrs.
    Value *BasePtrsArg =
        RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
    // i8** %offload_ptrs.
    Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
    // i8** %offload_sizes.
    Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);

    // Get values stored in **offload_baseptrs.
    auto *V = getUnderlyingObject(BasePtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *BasePtrsArray = cast<AllocaInst>(V);
    if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
      return false;

    // Get values stored in **offload_ptrs.
    V = getUnderlyingObject(PtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *PtrsArray = cast<AllocaInst>(V);
    if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
      return false;

    // Get values stored in **offload_sizes.
    V = getUnderlyingObject(SizesArg);
    // If it's a [constant] global array don't analyze it.
    if (isa<GlobalValue>(V))
      return isa<Constant>(V);
    if (!isa<AllocaInst>(V))
      return false;

    auto *SizesArray = cast<AllocaInst>(V);
    if (!OAs[2].initialize(*SizesArray, RuntimeCall))
      return false;

    return true;
  }

  /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
  /// For now this is a way to test that the function getValuesInOffloadArrays
  /// is working properly.
  /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
  void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "There are three offload arrays to debug!");

    LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
    std::string ValuesStr;
    raw_string_ostream Printer(ValuesStr);
    std::string Separator = " --- ";

    for (auto *BP : OAs[0].StoredValues) {
      BP->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
    ValuesStr.clear();

    for (auto *P : OAs[1].StoredValues) {
      P->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
    ValuesStr.clear();

    for (auto *S : OAs[2].StoredValues) {
      S->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
  }

  /// Returns the instruction where the "wait" counterpart of \p RuntimeCall can
  /// be moved. Returns nullptr if the movement is not possible, or not worth it.
  Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
    // FIXME: This traverses only the BasicBlock where RuntimeCall is.
    //        Make it traverse the CFG.

    Instruction *CurrentI = &RuntimeCall;
    bool IsWorthIt = false;
    while ((CurrentI = CurrentI->getNextNode())) {

      // TODO: Once we detect the regions to be offloaded we should use the
      //       alias analysis manager to check if CurrentI may modify one of
      //       the offloaded regions.
      if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
        if (IsWorthIt)
          return CurrentI;

        return nullptr;
      }

      // FIXME: For now, if we move it over anything without side effects it
      //        is worth it.
      IsWorthIt = true;
    }

    // Return end of BasicBlock.
    return RuntimeCall.getParent()->getTerminator();
  }

  /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
                               Instruction &WaitMovementPoint) {
    // Create stack allocated handle (__tgt_async_info) at the beginning of the
    // function. Used for storing information of the async transfer, allowing
    // us to wait on it later.
    auto &IRBuilder = OMPInfoCache.OMPBuilder;
    auto *F = RuntimeCall.getCaller();
    Instruction *FirstInst = &(F->getEntryBlock().front());
    AllocaInst *Handle = new AllocaInst(
        IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);

    // Add "issue" runtime call declaration:
    // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
    //   i8**, i8**, i64*, i64*)
    FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_issue);

    // Change RuntimeCall call site for its asynchronous version.
    SmallVector<Value *, 16> Args;
    for (auto &Arg : RuntimeCall.args())
      Args.push_back(Arg.get());
    Args.push_back(Handle);

    CallInst *IssueCallsite =
        CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
    RuntimeCall.eraseFromParent();

    // Add "wait" runtime call declaration:
    // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
    FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_wait);

    Value *WaitParams[2] = {
        IssueCallsite->getArgOperand(
            OffloadArray::DeviceIDArgNum), // device_id.
        Handle                             // handle to wait on.
    };
    CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);

    return true;
  }

  static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
                                    bool GlobalOnly, bool &SingleChoice) {
    if (CurrentIdent == NextIdent)
      return CurrentIdent;

    // TODO: Figure out how to actually combine multiple debug locations. For
    //       now we just keep an existing one if there is a single choice.
    if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
      SingleChoice = !CurrentIdent;
      return NextIdent;
    }
    return nullptr;
  }

  /// Return a `struct ident_t*` value that represents the ones used in the
  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
  /// return value we create one from scratch. We also do not yet combine
  /// information, e.g., the source locations, see combinedIdentStruct.
  Value *
  getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
                                 Function &F, bool GlobalOnly) {
    bool SingleChoice = true;
    Value *Ident = nullptr;
    auto CombineIdentStruct = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || &F != &Caller)
        return false;
      Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
                                  /* GlobalOnly */ true, SingleChoice);
      return false;
    };
    RFI.foreachUse(SCC, CombineIdentStruct);

    if (!Ident || !SingleChoice) {
      // The IRBuilder uses the insertion block to get to the module, this is
      // unfortunate but we work around it for now.
      if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
        OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
            &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
      // TODO: Use the debug locations of the calls instead.
      Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
      Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
    }
    return Ident;
  }

  /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
  /// \p ReplVal if given.
  bool deduplicateRuntimeCalls(Function &F,
                               OMPInformationCache::RuntimeFunctionInfo &RFI,
                               Value *ReplVal = nullptr) {
    auto *UV = RFI.getUseVector(F);
    if (!UV || UV->size() + (ReplVal != nullptr) < 2)
      return false;

    LLVM_DEBUG(
        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value\n" : "\n") << "\n");

    assert((!ReplVal || (isa<Argument>(ReplVal) &&
                         cast<Argument>(ReplVal)->getParent() == &F)) &&
           "Unexpected replacement value!");

    // TODO: Use dominance to find a good position instead.
    auto CanBeMoved = [this](CallBase &CB) {
      unsigned NumArgs = CB.getNumArgOperands();
      if (NumArgs == 0)
        return true;
      if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
        return false;
      for (unsigned u = 1; u < NumArgs; ++u)
        if (isa<Instruction>(CB.getArgOperand(u)))
          return false;
      return true;
    };

    if (!ReplVal) {
      for (Use *U : *UV)
        if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
          if (!CanBeMoved(*CI))
            continue;

          auto Remark = [&](OptimizationRemark OR) {
            return OR << "OpenMP runtime call "
                      << ore::NV("OpenMPOptRuntime", RFI.Name)
                      << " moved to beginning of OpenMP region";
          };
          emitRemark<OptimizationRemark>(&F, "OpenMPRuntimeCodeMotion", Remark);

          CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
          ReplVal = CI;
          break;
        }
      if (!ReplVal)
        return false;
    }

    // If we use a call as a replacement value we need to make sure the ident is
    // valid at the new location. For now we just pick a global one, either
    // existing and used by one of the calls, or created from scratch.
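    // Note: the "ident" is the first argument of most OpenMP runtime calls, a
    // pointer to a struct ident_t carrying source-location information, which
    // is why only a global ident is safe at the hoisted position.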
    if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
      if (CI->getNumArgOperands() > 0 &&
          CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
        Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
                                                      /* GlobalOnly */ true);
        CI->setArgOperand(0, Ident);
      }
    }

    bool Changed = false;
    auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || CI == ReplVal || &F != &Caller)
        return false;
      assert(CI->getCaller() == &F && "Unexpected call!");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "OpenMP runtime call "
                  << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated";
      };
      emitRemark<OptimizationRemark>(&F, "OpenMPRuntimeDeduplicated", Remark);

      CGUpdater.removeCallSite(*CI);
      CI->replaceAllUsesWith(ReplVal);
      CI->eraseFromParent();
      ++NumOpenMPRuntimeCallsDeduplicated;
      Changed = true;
      return true;
    };
    RFI.foreachUse(SCC, ReplaceAndDeleteCB);

    return Changed;
  }

  /// Collect arguments that represent the global thread id in \p GTIdArgs.
  void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
    // TODO: Below we basically perform a fixpoint iteration with a pessimistic
    //       initialization. We could define an AbstractAttribute instead and
    //       run the Attributor here once it can be run as an SCC pass.

    // Helper to check the argument \p ArgNo at all call sites of \p F for
    // a GTId.
    auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
      if (!F.hasLocalLinkage())
        return false;
      for (Use &U : F.uses()) {
        if (CallInst *CI = getCallIfRegularCall(U)) {
          Value *ArgOp = CI->getArgOperand(ArgNo);
          if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
              getCallIfRegularCall(
                  *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
            continue;
        }
        return false;
      }
      return true;
    };

    // Helper to identify uses of a GTId as GTId arguments.
    auto AddUserArgs = [&](Value &GTId) {
      for (Use &U : GTId.uses())
        if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
          if (CI->isArgOperand(&U))
            if (Function *Callee = CI->getCalledFunction())
              if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
                GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
    };

    // The argument users of __kmpc_global_thread_num calls are GTIds.
    OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];

    GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
      if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
        AddUserArgs(*CI);
      return false;
    });

    // Transitively search for more arguments by looking at the users of the
    // ones we know already. During the search the GTIdArgs vector is extended
    // so we cannot cache the size nor can we use a range based for.
    for (unsigned u = 0; u < GTIdArgs.size(); ++u)
      AddUserArgs(*GTIdArgs[u]);
  }

  /// Kernel (=GPU) optimizations and utility functions
  ///
  ///{{

  /// Check if \p F is a kernel, hence entry point for target offloading.
  bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }

  /// Cache to remember the unique kernel for a function.
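  /// A nullptr mapping is cached as well and means "no unique kernel is known"
  /// (see getUniqueKernelFor below).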
  DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;

  /// Find the unique kernel that will execute \p F, if any.
  Kernel getUniqueKernelFor(Function &F);

  /// Find the unique kernel that will execute \p I, if any.
  Kernel getUniqueKernelFor(Instruction &I) {
    return getUniqueKernelFor(*I.getFunction());
  }

  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
  /// the cases we can avoid taking the address of a function.
  bool rewriteDeviceCodeStateMachine();

  ///
  ///}}

  /// Emit a remark generically
  ///
  /// This template function can be used to generically emit a remark. The
  /// RemarkKind should be one of the following:
  ///   - OptimizationRemark to indicate a successful optimization attempt
  ///   - OptimizationRemarkMissed to report a failed optimization attempt
  ///   - OptimizationRemarkAnalysis to provide additional information about an
  ///     optimization attempt
  ///
  /// The remark is built using a callback function provided by the caller that
  /// takes a RemarkKind as input and returns a RemarkKind.
  template <typename RemarkKind, typename RemarkCallBack>
  void emitRemark(Instruction *I, StringRef RemarkName,
                  RemarkCallBack &&RemarkCB) const {
    Function *F = I->getParent()->getParent();
    auto &ORE = OREGetter(F);

    ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
  }

  /// Emit a remark on a function.
  template <typename RemarkKind, typename RemarkCallBack>
  void emitRemark(Function *F, StringRef RemarkName,
                  RemarkCallBack &&RemarkCB) const {
    auto &ORE = OREGetter(F);

    ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
  }

  /// The underlying module.
  Module &M;

  /// The SCC we are operating on.
  SmallVectorImpl<Function *> &SCC;

  /// Callback to update the call graph, the first argument is a removed call,
  /// the second an optional replacement call.
  CallGraphUpdater &CGUpdater;

  /// Callback to get an OptimizationRemarkEmitter from a Function *
  OptimizationRemarkGetter OREGetter;

  /// OpenMP-specific information cache. Also used for Attributor runs.
  OMPInformationCache &OMPInfoCache;

  /// Attributor instance.
  Attributor &A;

  /// Helper function to run Attributor on SCC.
  bool runAttributor() {
    if (SCC.empty())
      return false;

    registerAAs();

    ChangeStatus Changed = A.run();

    LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
                      << " functions, result: " << Changed << ".\n");

    return Changed == ChangeStatus::CHANGED;
  }

  /// Populate the Attributor with abstract attribute opportunities in the
  /// function.
  void registerAAs() {
    if (SCC.empty())
      return;

    // Create CallSite AA for all Getters.
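    // For every call site of an ICV getter we seed an AAICVTracker so the
    // Attributor can later reason about the value such a call returns.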
    for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
      auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];

      auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];

      auto CreateAA = [&](Use &U, Function &Caller) {
        CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
        if (!CI)
          return false;

        auto &CB = cast<CallBase>(*CI);

        IRPosition CBPos = IRPosition::callsite_function(CB);
        A.getOrCreateAAFor<AAICVTracker>(CBPos);
        return false;
      };

      GetterRFI.foreachUse(SCC, CreateAA);
    }

    for (auto &F : M) {
      if (!F.isDeclaration())
        A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
    }
  }
};

Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
  if (!OMPInfoCache.ModuleSlice.count(&F))
    return nullptr;

  // Use a scope to keep the lifetime of the CachedKernel short.
  {
    Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
    if (CachedKernel)
      return *CachedKernel;

    // TODO: We should use an AA to create an (optimistic and callback
    //       call-aware) call graph. For now we stick to simple patterns that
    //       are less powerful, basically the worst fixpoint.
    if (isKernel(F)) {
      CachedKernel = Kernel(&F);
      return *CachedKernel;
    }

    CachedKernel = nullptr;
    if (!F.hasLocalLinkage()) {

      // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA
               << "[OMP100] Potentially unknown OpenMP target region caller";
      };
      emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);

      return nullptr;
    }
  }

  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
    if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
      // Allow use in equality comparisons.
      if (Cmp->isEquality())
        return getUniqueKernelFor(*Cmp);
      return nullptr;
    }
    if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
      // Allow direct calls.
      if (CB->isCallee(&U))
        return getUniqueKernelFor(*CB);

      OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
          OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
      // Allow the use in __kmpc_parallel_51 calls.
      if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
        return getUniqueKernelFor(*CB);
      return nullptr;
    }
    // Disallow every other use.
    return nullptr;
  };

  // TODO: In the future we want to track more than just a unique kernel.
  SmallPtrSet<Kernel, 2> PotentialKernels;
  OMPInformationCache::foreachUse(F, [&](const Use &U) {
    PotentialKernels.insert(GetUniqueKernelForUse(U));
  });

  Kernel K = nullptr;
  if (PotentialKernels.size() == 1)
    K = *PotentialKernels.begin();

  // Cache the result.
  UniqueKernelMap[&F] = K;

  return K;
}

bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
  OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
      OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];

  bool Changed = false;
  if (!KernelParallelRFI)
    return Changed;

  for (Function *F : SCC) {

    // Check if the function is used in a __kmpc_parallel_51 call at
    // all.
1699 bool UnknownUse = false; 1700 bool KernelParallelUse = false; 1701 unsigned NumDirectCalls = 0; 1702 1703 SmallVector<Use *, 2> ToBeReplacedStateMachineUses; 1704 OMPInformationCache::foreachUse(*F, [&](Use &U) { 1705 if (auto *CB = dyn_cast<CallBase>(U.getUser())) 1706 if (CB->isCallee(&U)) { 1707 ++NumDirectCalls; 1708 return; 1709 } 1710 1711 if (isa<ICmpInst>(U.getUser())) { 1712 ToBeReplacedStateMachineUses.push_back(&U); 1713 return; 1714 } 1715 1716 // Find wrapper functions that represent parallel kernels. 1717 CallInst *CI = 1718 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI); 1719 const unsigned int WrapperFunctionArgNo = 6; 1720 if (!KernelParallelUse && CI && 1721 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) { 1722 KernelParallelUse = true; 1723 ToBeReplacedStateMachineUses.push_back(&U); 1724 return; 1725 } 1726 UnknownUse = true; 1727 }); 1728 1729 // Do not emit a remark if we haven't seen a __kmpc_parallel_51 1730 // use. 1731 if (!KernelParallelUse) 1732 continue; 1733 1734 { 1735 auto Remark = [&](OptimizationRemarkAnalysis ORA) { 1736 return ORA << "Found a parallel region that is called in a target " 1737 "region but not part of a combined target construct nor " 1738 "nested inside a target construct without intermediate " 1739 "code. This can lead to excessive register usage for " 1740 "unrelated target regions in the same translation unit " 1741 "due to spurious call edges assumed by ptxas."; 1742 }; 1743 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPParallelRegionInNonSPMD", 1744 Remark); 1745 } 1746 1747 // If this ever hits, we should investigate. 1748 // TODO: Checking the number of uses is not a necessary restriction and 1749 // should be lifted. 1750 if (UnknownUse || NumDirectCalls != 1 || 1751 ToBeReplacedStateMachineUses.size() != 2) { 1752 { 1753 auto Remark = [&](OptimizationRemarkAnalysis ORA) { 1754 return ORA << "Parallel region is used in " 1755 << (UnknownUse ? "unknown" : "unexpected") 1756 << " ways; will not attempt to rewrite the state machine."; 1757 }; 1758 emitRemark<OptimizationRemarkAnalysis>( 1759 F, "OpenMPParallelRegionInNonSPMD", Remark); 1760 } 1761 continue; 1762 } 1763 1764 // Even if we have __kmpc_parallel_51 calls, we (for now) give 1765 // up if the function is not called from a unique kernel. 1766 Kernel K = getUniqueKernelFor(*F); 1767 if (!K) { 1768 { 1769 auto Remark = [&](OptimizationRemarkAnalysis ORA) { 1770 return ORA << "Parallel region is not known to be called from a " 1771 "unique single target region, maybe the surrounding " 1772 "function has external linkage?; will not attempt to " 1773 "rewrite the state machine use."; 1774 }; 1775 emitRemark<OptimizationRemarkAnalysis>( 1776 F, "OpenMPParallelRegionInMultipleKernesl", Remark); 1777 } 1778 continue; 1779 } 1780 1781 // We now know F is a parallel body function called only from the kernel K. 1782 // We also identified the state machine uses in which we replace the 1783 // function pointer by a new global symbol for identification purposes. This 1784 // ensures only direct calls to the function are left. 1785 1786 { 1787 auto RemarkParalleRegion = [&](OptimizationRemarkAnalysis ORA) { 1788 return ORA << "Specialize parallel region that is only reached from a " 1789 "single target region to avoid spurious call edges and " 1790 "excessive register usage in other target regions. 
" 1791 "(parallel region ID: " 1792 << ore::NV("OpenMPParallelRegion", F->getName()) 1793 << ", kernel ID: " 1794 << ore::NV("OpenMPTargetRegion", K->getName()) << ")"; 1795 }; 1796 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPParallelRegionInNonSPMD", 1797 RemarkParalleRegion); 1798 auto RemarkKernel = [&](OptimizationRemarkAnalysis ORA) { 1799 return ORA << "Target region containing the parallel region that is " 1800 "specialized. (parallel region ID: " 1801 << ore::NV("OpenMPParallelRegion", F->getName()) 1802 << ", kernel ID: " 1803 << ore::NV("OpenMPTargetRegion", K->getName()) << ")"; 1804 }; 1805 emitRemark<OptimizationRemarkAnalysis>(K, "OpenMPParallelRegionInNonSPMD", 1806 RemarkKernel); 1807 } 1808 1809 Module &M = *F->getParent(); 1810 Type *Int8Ty = Type::getInt8Ty(M.getContext()); 1811 1812 auto *ID = new GlobalVariable( 1813 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage, 1814 UndefValue::get(Int8Ty), F->getName() + ".ID"); 1815 1816 for (Use *U : ToBeReplacedStateMachineUses) 1817 U->set(ConstantExpr::getBitCast(ID, U->get()->getType())); 1818 1819 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine; 1820 1821 Changed = true; 1822 } 1823 1824 return Changed; 1825 } 1826 1827 /// Abstract Attribute for tracking ICV values. 1828 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> { 1829 using Base = StateWrapper<BooleanState, AbstractAttribute>; 1830 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {} 1831 1832 void initialize(Attributor &A) override { 1833 Function *F = getAnchorScope(); 1834 if (!F || !A.isFunctionIPOAmendable(*F)) 1835 indicatePessimisticFixpoint(); 1836 } 1837 1838 /// Returns true if value is assumed to be tracked. 1839 bool isAssumedTracked() const { return getAssumed(); } 1840 1841 /// Returns true if value is known to be tracked. 1842 bool isKnownTracked() const { return getAssumed(); } 1843 1844 /// Create an abstract attribute biew for the position \p IRP. 1845 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A); 1846 1847 /// Return the value with which \p I can be replaced for specific \p ICV. 1848 virtual Optional<Value *> getReplacementValue(InternalControlVar ICV, 1849 const Instruction *I, 1850 Attributor &A) const { 1851 return None; 1852 } 1853 1854 /// Return an assumed unique ICV value if a single candidate is found. If 1855 /// there cannot be one, return a nullptr. If it is not clear yet, return the 1856 /// Optional::NoneType. 1857 virtual Optional<Value *> 1858 getUniqueReplacementValue(InternalControlVar ICV) const = 0; 1859 1860 // Currently only nthreads is being tracked. 1861 // this array will only grow with time. 1862 InternalControlVar TrackableICVs[1] = {ICV_nthreads}; 1863 1864 /// See AbstractAttribute::getName() 1865 const std::string getName() const override { return "AAICVTracker"; } 1866 1867 /// See AbstractAttribute::getIdAddr() 1868 const char *getIdAddr() const override { return &ID; } 1869 1870 /// This function should return true if the type of the \p AA is AAICVTracker 1871 static bool classof(const AbstractAttribute *AA) { 1872 return (AA->getIdAddr() == &ID); 1873 } 1874 1875 static const char ID; 1876 }; 1877 1878 struct AAICVTrackerFunction : public AAICVTracker { 1879 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A) 1880 : AAICVTracker(IRP, A) {} 1881 1882 // FIXME: come up with better string. 1883 const std::string getAsStr() const override { return "ICVTrackerFunction"; } 1884 1885 // FIXME: come up with some stats. 
1886   void trackStatistics() const override {}
1887
1888   /// We don't manifest anything for this AA.
1889   ChangeStatus manifest(Attributor &A) override {
1890     return ChangeStatus::UNCHANGED;
1891   }
1892
1893   // Map of ICVs to their values at specific program points.
1894   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
1895                   InternalControlVar::ICV___last>
1896       ICVReplacementValuesMap;
1897
1898   ChangeStatus updateImpl(Attributor &A) override {
1899     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
1900
1901     Function *F = getAnchorScope();
1902
1903     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
1904
1905     for (InternalControlVar ICV : TrackableICVs) {
1906       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
1907
1908       auto &ValuesMap = ICVReplacementValuesMap[ICV];
1909       auto TrackValues = [&](Use &U, Function &) {
1910         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
1911         if (!CI)
1912           return false;
1913
1914         // FIXME: handle setters with more than one argument.
1915         // Track new value.
1916         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
1917           HasChanged = ChangeStatus::CHANGED;
1918
1919         return false;
1920       };
1921
1922       auto CallCheck = [&](Instruction &I) {
1923         Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
1924         if (ReplVal.hasValue() &&
1925             ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
1926           HasChanged = ChangeStatus::CHANGED;
1927
1928         return true;
1929       };
1930
1931       // Track all changes of an ICV.
1932       SetterRFI.foreachUse(TrackValues, F);
1933
1934       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
1935                                 /* CheckBBLivenessOnly */ true);
1936
1937       // TODO: Figure out a way to avoid adding an entry in
1938       // ICVReplacementValuesMap.
1939       Instruction *Entry = &F->getEntryBlock().front();
1940       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
1941         ValuesMap.insert(std::make_pair(Entry, nullptr));
1942     }
1943
1944     return HasChanged;
1945   }
1946
1947   /// Helper to check if \p I is a call and get the value for it if it is
1948   /// unique.
1949   Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
1950                                     InternalControlVar &ICV) const {
1951
1952     const auto *CB = dyn_cast<CallBase>(I);
1953     if (!CB || CB->hasFnAttr("no_openmp") ||
1954         CB->hasFnAttr("no_openmp_routines"))
1955       return None;
1956
1957     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
1958     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
1959     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
1960     Function *CalledFunction = CB->getCalledFunction();
1961
1962     // Indirect call, assume the ICV changes.
1963     if (CalledFunction == nullptr)
1964       return nullptr;
1965     if (CalledFunction == GetterRFI.Declaration)
1966       return None;
1967     if (CalledFunction == SetterRFI.Declaration) {
1968       if (ICVReplacementValuesMap[ICV].count(I))
1969         return ICVReplacementValuesMap[ICV].lookup(I);
1970
1971       return nullptr;
1972     }
1973
1974     // Since we don't know, assume it changes the ICV.
1975     if (CalledFunction->isDeclaration())
1976       return nullptr;
1977
1978     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
1979         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
1980
1981     if (ICVTrackingAA.isAssumedTracked())
1982       return ICVTrackingAA.getUniqueReplacementValue(ICV);
1983
1984     // If we don't know, assume it changes.
1985     return nullptr;
1986   }
1987
1988   // We don't check unique value for a function, so return None.
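  // (Per-instruction replacement values are exposed through getReplacementValue
  // below, which consults ICVReplacementValuesMap directly.)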
1989 Optional<Value *> 1990 getUniqueReplacementValue(InternalControlVar ICV) const override { 1991 return None; 1992 } 1993 1994 /// Return the value with which \p I can be replaced for specific \p ICV. 1995 Optional<Value *> getReplacementValue(InternalControlVar ICV, 1996 const Instruction *I, 1997 Attributor &A) const override { 1998 const auto &ValuesMap = ICVReplacementValuesMap[ICV]; 1999 if (ValuesMap.count(I)) 2000 return ValuesMap.lookup(I); 2001 2002 SmallVector<const Instruction *, 16> Worklist; 2003 SmallPtrSet<const Instruction *, 16> Visited; 2004 Worklist.push_back(I); 2005 2006 Optional<Value *> ReplVal; 2007 2008 while (!Worklist.empty()) { 2009 const Instruction *CurrInst = Worklist.pop_back_val(); 2010 if (!Visited.insert(CurrInst).second) 2011 continue; 2012 2013 const BasicBlock *CurrBB = CurrInst->getParent(); 2014 2015 // Go up and look for all potential setters/calls that might change the 2016 // ICV. 2017 while ((CurrInst = CurrInst->getPrevNode())) { 2018 if (ValuesMap.count(CurrInst)) { 2019 Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst); 2020 // Unknown value, track new. 2021 if (!ReplVal.hasValue()) { 2022 ReplVal = NewReplVal; 2023 break; 2024 } 2025 2026 // If we found a new value, we can't know the icv value anymore. 2027 if (NewReplVal.hasValue()) 2028 if (ReplVal != NewReplVal) 2029 return nullptr; 2030 2031 break; 2032 } 2033 2034 Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV); 2035 if (!NewReplVal.hasValue()) 2036 continue; 2037 2038 // Unknown value, track new. 2039 if (!ReplVal.hasValue()) { 2040 ReplVal = NewReplVal; 2041 break; 2042 } 2043 2044 // if (NewReplVal.hasValue()) 2045 // We found a new value, we can't know the icv value anymore. 2046 if (ReplVal != NewReplVal) 2047 return nullptr; 2048 } 2049 2050 // If we are in the same BB and we have a value, we are done. 2051 if (CurrBB == I->getParent() && ReplVal.hasValue()) 2052 return ReplVal; 2053 2054 // Go through all predecessors and add terminators for analysis. 2055 for (const BasicBlock *Pred : predecessors(CurrBB)) 2056 if (const Instruction *Terminator = Pred->getTerminator()) 2057 Worklist.push_back(Terminator); 2058 } 2059 2060 return ReplVal; 2061 } 2062 }; 2063 2064 struct AAICVTrackerFunctionReturned : AAICVTracker { 2065 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A) 2066 : AAICVTracker(IRP, A) {} 2067 2068 // FIXME: come up with better string. 2069 const std::string getAsStr() const override { 2070 return "ICVTrackerFunctionReturned"; 2071 } 2072 2073 // FIXME: come up with some stats. 2074 void trackStatistics() const override {} 2075 2076 /// We don't manifest anything for this AA. 2077 ChangeStatus manifest(Attributor &A) override { 2078 return ChangeStatus::UNCHANGED; 2079 } 2080 2081 // Map of ICV to their values at specific program point. 2082 EnumeratedArray<Optional<Value *>, InternalControlVar, 2083 InternalControlVar::ICV___last> 2084 ICVReplacementValuesMap; 2085 2086 /// Return the value with which \p I can be replaced for specific \p ICV. 
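  /// For the function-return position this is the value (if any) that is
  /// identical at every return instruction, as computed by updateImpl below.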
2087 Optional<Value *> 2088 getUniqueReplacementValue(InternalControlVar ICV) const override { 2089 return ICVReplacementValuesMap[ICV]; 2090 } 2091 2092 ChangeStatus updateImpl(Attributor &A) override { 2093 ChangeStatus Changed = ChangeStatus::UNCHANGED; 2094 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2095 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 2096 2097 if (!ICVTrackingAA.isAssumedTracked()) 2098 return indicatePessimisticFixpoint(); 2099 2100 for (InternalControlVar ICV : TrackableICVs) { 2101 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; 2102 Optional<Value *> UniqueICVValue; 2103 2104 auto CheckReturnInst = [&](Instruction &I) { 2105 Optional<Value *> NewReplVal = 2106 ICVTrackingAA.getReplacementValue(ICV, &I, A); 2107 2108 // If we found a second ICV value there is no unique returned value. 2109 if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal) 2110 return false; 2111 2112 UniqueICVValue = NewReplVal; 2113 2114 return true; 2115 }; 2116 2117 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret}, 2118 /* CheckBBLivenessOnly */ true)) 2119 UniqueICVValue = nullptr; 2120 2121 if (UniqueICVValue == ReplVal) 2122 continue; 2123 2124 ReplVal = UniqueICVValue; 2125 Changed = ChangeStatus::CHANGED; 2126 } 2127 2128 return Changed; 2129 } 2130 }; 2131 2132 struct AAICVTrackerCallSite : AAICVTracker { 2133 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A) 2134 : AAICVTracker(IRP, A) {} 2135 2136 void initialize(Attributor &A) override { 2137 Function *F = getAnchorScope(); 2138 if (!F || !A.isFunctionIPOAmendable(*F)) 2139 indicatePessimisticFixpoint(); 2140 2141 // We only initialize this AA for getters, so we need to know which ICV it 2142 // gets. 2143 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2144 for (InternalControlVar ICV : TrackableICVs) { 2145 auto ICVInfo = OMPInfoCache.ICVs[ICV]; 2146 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter]; 2147 if (Getter.Declaration == getAssociatedFunction()) { 2148 AssociatedICV = ICVInfo.Kind; 2149 return; 2150 } 2151 } 2152 2153 /// Unknown ICV. 2154 indicatePessimisticFixpoint(); 2155 } 2156 2157 ChangeStatus manifest(Attributor &A) override { 2158 if (!ReplVal.hasValue() || !ReplVal.getValue()) 2159 return ChangeStatus::UNCHANGED; 2160 2161 A.changeValueAfterManifest(*getCtxI(), **ReplVal); 2162 A.deleteAfterManifest(*getCtxI()); 2163 2164 return ChangeStatus::CHANGED; 2165 } 2166 2167 // FIXME: come up with better string. 2168 const std::string getAsStr() const override { return "ICVTrackerCallSite"; } 2169 2170 // FIXME: come up with some stats. 2171 void trackStatistics() const override {} 2172 2173 InternalControlVar AssociatedICV; 2174 Optional<Value *> ReplVal; 2175 2176 ChangeStatus updateImpl(Attributor &A) override { 2177 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2178 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 2179 2180 // We don't have any information, so we assume it changes the ICV. 2181 if (!ICVTrackingAA.isAssumedTracked()) 2182 return indicatePessimisticFixpoint(); 2183 2184 Optional<Value *> NewReplVal = 2185 ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A); 2186 2187 if (ReplVal == NewReplVal) 2188 return ChangeStatus::UNCHANGED; 2189 2190 ReplVal = NewReplVal; 2191 return ChangeStatus::CHANGED; 2192 } 2193 2194 // Return the value with which associated value can be replaced for specific 2195 // \p ICV. 
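  // For a call site this is simply the cached ReplVal computed in updateImpl.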
2196 Optional<Value *> 2197 getUniqueReplacementValue(InternalControlVar ICV) const override { 2198 return ReplVal; 2199 } 2200 }; 2201 2202 struct AAICVTrackerCallSiteReturned : AAICVTracker { 2203 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A) 2204 : AAICVTracker(IRP, A) {} 2205 2206 // FIXME: come up with better string. 2207 const std::string getAsStr() const override { 2208 return "ICVTrackerCallSiteReturned"; 2209 } 2210 2211 // FIXME: come up with some stats. 2212 void trackStatistics() const override {} 2213 2214 /// We don't manifest anything for this AA. 2215 ChangeStatus manifest(Attributor &A) override { 2216 return ChangeStatus::UNCHANGED; 2217 } 2218 2219 // Map of ICV to their values at specific program point. 2220 EnumeratedArray<Optional<Value *>, InternalControlVar, 2221 InternalControlVar::ICV___last> 2222 ICVReplacementValuesMap; 2223 2224 /// Return the value with which associated value can be replaced for specific 2225 /// \p ICV. 2226 Optional<Value *> 2227 getUniqueReplacementValue(InternalControlVar ICV) const override { 2228 return ICVReplacementValuesMap[ICV]; 2229 } 2230 2231 ChangeStatus updateImpl(Attributor &A) override { 2232 ChangeStatus Changed = ChangeStatus::UNCHANGED; 2233 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2234 *this, IRPosition::returned(*getAssociatedFunction()), 2235 DepClassTy::REQUIRED); 2236 2237 // We don't have any information, so we assume it changes the ICV. 2238 if (!ICVTrackingAA.isAssumedTracked()) 2239 return indicatePessimisticFixpoint(); 2240 2241 for (InternalControlVar ICV : TrackableICVs) { 2242 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; 2243 Optional<Value *> NewReplVal = 2244 ICVTrackingAA.getUniqueReplacementValue(ICV); 2245 2246 if (ReplVal == NewReplVal) 2247 continue; 2248 2249 ReplVal = NewReplVal; 2250 Changed = ChangeStatus::CHANGED; 2251 } 2252 return Changed; 2253 } 2254 }; 2255 2256 struct AAExecutionDomainFunction : public AAExecutionDomain { 2257 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A) 2258 : AAExecutionDomain(IRP, A) {} 2259 2260 const std::string getAsStr() const override { 2261 return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) + 2262 "/" + std::to_string(NumBBs) + " BBs thread 0 only."; 2263 } 2264 2265 /// See AbstractAttribute::trackStatistics(). 2266 void trackStatistics() const override {} 2267 2268 void initialize(Attributor &A) override { 2269 Function *F = getAnchorScope(); 2270 for (const auto &BB : *F) 2271 SingleThreadedBBs.insert(&BB); 2272 NumBBs = SingleThreadedBBs.size(); 2273 } 2274 2275 ChangeStatus manifest(Attributor &A) override { 2276 LLVM_DEBUG({ 2277 for (const BasicBlock *BB : SingleThreadedBBs) 2278 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " " 2279 << BB->getName() << " is executed by a single thread.\n"; 2280 }); 2281 return ChangeStatus::UNCHANGED; 2282 } 2283 2284 ChangeStatus updateImpl(Attributor &A) override; 2285 2286 /// Check if an instruction is executed by a single thread. 2287 bool isExecutedByInitialThreadOnly(const Instruction &I) const override { 2288 return isExecutedByInitialThreadOnly(*I.getParent()); 2289 } 2290 2291 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override { 2292 return SingleThreadedBBs.contains(&BB); 2293 } 2294 2295 /// Set of basic blocks that are executed by a single thread. 2296 DenseSet<const BasicBlock *> SingleThreadedBBs; 2297 2298 /// Total number of basic blocks in this function. 
2299   long unsigned NumBBs;
2300 };
2301
2302 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2303   Function *F = getAnchorScope();
2304   ReversePostOrderTraversal<Function *> RPOT(F);
2305   auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2306
2307   bool AllCallSitesKnown;
2308   auto PredForCallSite = [&](AbstractCallSite ACS) {
2309     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2310         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2311         DepClassTy::REQUIRED);
2312     return ExecutionDomainAA.isExecutedByInitialThreadOnly(
2313         *ACS.getInstruction());
2314   };
2315
2316   if (!A.checkForAllCallSites(PredForCallSite, *this,
2317                               /* RequiresAllCallSites */ true,
2318                               AllCallSitesKnown))
2319     SingleThreadedBBs.erase(&F->getEntryBlock());
2320
2321   // Check if the edge into the successor block compares the thread-id
2322   // intrinsic to a constant zero.
2323   // TODO: Use AAValueSimplify to simplify and propagate constants.
2324   // TODO: Check more than a single use for thread IDs.
2325   auto IsSingleThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2326     if (!Edge || !Edge->isConditional())
2327       return false;
2328     if (Edge->getSuccessor(0) != SuccessorBB)
2329       return false;
2330
2331     auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2332     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2333       return false;
2334
2335     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2336     if (!C || !C->isZero())
2337       return false;
2338
2339     if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2340       if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2341         return true;
2342     if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2343       if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2344         return true;
2345
2346     return false;
2347   };
2348
2349   // Merge all the predecessor states into the current basic block. A basic
2350   // block is executed by a single thread if all of its predecessors are.
2351   auto MergePredecessorStates = [&](BasicBlock *BB) {
2352     if (pred_begin(BB) == pred_end(BB))
2353       return SingleThreadedBBs.contains(BB);
2354
2355     bool IsSingleThreaded = true;
2356     for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
2357          PredBB != PredEndBB; ++PredBB) {
2358       if (!IsSingleThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
2359                               BB))
2360         IsSingleThreaded &= SingleThreadedBBs.contains(*PredBB);
2361     }
2362
2363     return IsSingleThreaded;
2364   };
2365
2366   for (auto *BB : RPOT) {
2367     if (!MergePredecessorStates(BB))
2368       SingleThreadedBBs.erase(BB);
2369   }
2370
2371   return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2372              ?
ChangeStatus::UNCHANGED 2373 : ChangeStatus::CHANGED; 2374 } 2375 2376 } // namespace 2377 2378 const char AAICVTracker::ID = 0; 2379 const char AAExecutionDomain::ID = 0; 2380 2381 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP, 2382 Attributor &A) { 2383 AAICVTracker *AA = nullptr; 2384 switch (IRP.getPositionKind()) { 2385 case IRPosition::IRP_INVALID: 2386 case IRPosition::IRP_FLOAT: 2387 case IRPosition::IRP_ARGUMENT: 2388 case IRPosition::IRP_CALL_SITE_ARGUMENT: 2389 llvm_unreachable("ICVTracker can only be created for function position!"); 2390 case IRPosition::IRP_RETURNED: 2391 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A); 2392 break; 2393 case IRPosition::IRP_CALL_SITE_RETURNED: 2394 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A); 2395 break; 2396 case IRPosition::IRP_CALL_SITE: 2397 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A); 2398 break; 2399 case IRPosition::IRP_FUNCTION: 2400 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A); 2401 break; 2402 } 2403 2404 return *AA; 2405 } 2406 2407 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP, 2408 Attributor &A) { 2409 AAExecutionDomainFunction *AA = nullptr; 2410 switch (IRP.getPositionKind()) { 2411 case IRPosition::IRP_INVALID: 2412 case IRPosition::IRP_FLOAT: 2413 case IRPosition::IRP_ARGUMENT: 2414 case IRPosition::IRP_CALL_SITE_ARGUMENT: 2415 case IRPosition::IRP_RETURNED: 2416 case IRPosition::IRP_CALL_SITE_RETURNED: 2417 case IRPosition::IRP_CALL_SITE: 2418 llvm_unreachable( 2419 "AAExecutionDomain can only be created for function position!"); 2420 case IRPosition::IRP_FUNCTION: 2421 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A); 2422 break; 2423 } 2424 2425 return *AA; 2426 } 2427 2428 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { 2429 if (!containsOpenMP(M, OMPInModule)) 2430 return PreservedAnalyses::all(); 2431 2432 if (DisableOpenMPOptimizations) 2433 return PreservedAnalyses::all(); 2434 2435 // Look at every function definition in the Module. 2436 SmallVector<Function *, 16> SCC; 2437 for (Function &Fn : M) 2438 if (!Fn.isDeclaration()) 2439 SCC.push_back(&Fn); 2440 2441 if (SCC.empty()) 2442 return PreservedAnalyses::all(); 2443 2444 FunctionAnalysisManager &FAM = 2445 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 2446 2447 AnalysisGetter AG(FAM); 2448 2449 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & { 2450 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 2451 }; 2452 2453 BumpPtrAllocator Allocator; 2454 CallGraphUpdater CGUpdater; 2455 2456 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 2457 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, 2458 OMPInModule.getKernels()); 2459 2460 Attributor A(Functions, InfoCache, CGUpdater); 2461 2462 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 2463 bool Changed = OMPOpt.run(true); 2464 if (Changed) 2465 return PreservedAnalyses::none(); 2466 2467 return PreservedAnalyses::all(); 2468 } 2469 2470 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, 2471 CGSCCAnalysisManager &AM, 2472 LazyCallGraph &CG, 2473 CGSCCUpdateResult &UR) { 2474 if (!containsOpenMP(*C.begin()->getFunction().getParent(), OMPInModule)) 2475 return PreservedAnalyses::all(); 2476 2477 if (DisableOpenMPOptimizations) 2478 return PreservedAnalyses::all(); 2479 2480 SmallVector<Function *, 16> SCC; 2481 // If there are kernels in the module, we have to run on all SCC's. 
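  // (Kernels can reach functions that never call the OpenMP runtime directly,
  // so the per-function runtime-call check below would miss them.)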
2482 bool SCCIsInteresting = !OMPInModule.getKernels().empty(); 2483 for (LazyCallGraph::Node &N : C) { 2484 Function *Fn = &N.getFunction(); 2485 SCC.push_back(Fn); 2486 2487 // Do we already know that the SCC contains kernels, 2488 // or that OpenMP functions are called from this SCC? 2489 if (SCCIsInteresting) 2490 continue; 2491 // If not, let's check that. 2492 SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn); 2493 } 2494 2495 if (!SCCIsInteresting || SCC.empty()) 2496 return PreservedAnalyses::all(); 2497 2498 FunctionAnalysisManager &FAM = 2499 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); 2500 2501 AnalysisGetter AG(FAM); 2502 2503 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & { 2504 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 2505 }; 2506 2507 BumpPtrAllocator Allocator; 2508 CallGraphUpdater CGUpdater; 2509 CGUpdater.initialize(CG, C, AM, UR); 2510 2511 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 2512 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator, 2513 /*CGSCC*/ Functions, OMPInModule.getKernels()); 2514 2515 Attributor A(Functions, InfoCache, CGUpdater, nullptr, false); 2516 2517 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 2518 bool Changed = OMPOpt.run(false); 2519 if (Changed) 2520 return PreservedAnalyses::none(); 2521 2522 return PreservedAnalyses::all(); 2523 } 2524 2525 namespace { 2526 2527 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass { 2528 CallGraphUpdater CGUpdater; 2529 OpenMPInModule OMPInModule; 2530 static char ID; 2531 2532 OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) { 2533 initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry()); 2534 } 2535 2536 void getAnalysisUsage(AnalysisUsage &AU) const override { 2537 CallGraphSCCPass::getAnalysisUsage(AU); 2538 } 2539 2540 bool doInitialization(CallGraph &CG) override { 2541 // Disable the pass if there is no OpenMP (runtime call) in the module. 2542 containsOpenMP(CG.getModule(), OMPInModule); 2543 return false; 2544 } 2545 2546 bool runOnSCC(CallGraphSCC &CGSCC) override { 2547 if (!containsOpenMP(CGSCC.getCallGraph().getModule(), OMPInModule)) 2548 return false; 2549 if (DisableOpenMPOptimizations || skipSCC(CGSCC)) 2550 return false; 2551 2552 SmallVector<Function *, 16> SCC; 2553 // If there are kernels in the module, we have to run on all SCC's. 2554 bool SCCIsInteresting = !OMPInModule.getKernels().empty(); 2555 for (CallGraphNode *CGN : CGSCC) { 2556 Function *Fn = CGN->getFunction(); 2557 if (!Fn || Fn->isDeclaration()) 2558 continue; 2559 SCC.push_back(Fn); 2560 2561 // Do we already know that the SCC contains kernels, 2562 // or that OpenMP functions are called from this SCC? 2563 if (SCCIsInteresting) 2564 continue; 2565 // If not, let's check that. 
2566 SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn); 2567 } 2568 2569 if (!SCCIsInteresting || SCC.empty()) 2570 return false; 2571 2572 CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); 2573 CGUpdater.initialize(CG, CGSCC); 2574 2575 // Maintain a map of functions to avoid rebuilding the ORE 2576 DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap; 2577 auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & { 2578 std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F]; 2579 if (!ORE) 2580 ORE = std::make_unique<OptimizationRemarkEmitter>(F); 2581 return *ORE; 2582 }; 2583 2584 AnalysisGetter AG; 2585 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 2586 BumpPtrAllocator Allocator; 2587 OMPInformationCache InfoCache( 2588 *(Functions.back()->getParent()), AG, Allocator, 2589 /*CGSCC*/ Functions, OMPInModule.getKernels()); 2590 2591 Attributor A(Functions, InfoCache, CGUpdater, nullptr, false); 2592 2593 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 2594 return OMPOpt.run(false); 2595 } 2596 2597 bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); } 2598 }; 2599 2600 } // end anonymous namespace 2601 2602 void OpenMPInModule::identifyKernels(Module &M) { 2603 2604 NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations"); 2605 if (!MD) 2606 return; 2607 2608 for (auto *Op : MD->operands()) { 2609 if (Op->getNumOperands() < 2) 2610 continue; 2611 MDString *KindID = dyn_cast<MDString>(Op->getOperand(1)); 2612 if (!KindID || KindID->getString() != "kernel") 2613 continue; 2614 2615 Function *KernelFn = 2616 mdconst::dyn_extract_or_null<Function>(Op->getOperand(0)); 2617 if (!KernelFn) 2618 continue; 2619 2620 ++NumOpenMPTargetRegionKernels; 2621 2622 Kernels.insert(KernelFn); 2623 } 2624 } 2625 2626 bool llvm::omp::containsOpenMP(Module &M, OpenMPInModule &OMPInModule) { 2627 if (OMPInModule.isKnown()) 2628 return OMPInModule; 2629 2630 auto RecordFunctionsContainingUsesOf = [&](Function *F) { 2631 for (User *U : F->users()) 2632 if (auto *I = dyn_cast<Instruction>(U)) 2633 OMPInModule.FuncsWithOMPRuntimeCalls.insert(I->getFunction()); 2634 }; 2635 2636 // MSVC doesn't like long if-else chains for some reason and instead just 2637 // issues an error. Work around it.. 2638 do { 2639 #define OMP_RTL(_Enum, _Name, ...) \ 2640 if (Function *F = M.getFunction(_Name)) { \ 2641 RecordFunctionsContainingUsesOf(F); \ 2642 OMPInModule = true; \ 2643 } 2644 #include "llvm/Frontend/OpenMP/OMPKinds.def" 2645 } while (false); 2646 2647 // Identify kernels once. TODO: We should split the OMPInformationCache into a 2648 // module and an SCC part. The kernel information, among other things, could 2649 // go into the module part. 2650 if (OMPInModule.isKnown() && OMPInModule) { 2651 OMPInModule.identifyKernels(M); 2652 return true; 2653 } 2654 2655 return OMPInModule = false; 2656 } 2657 2658 char OpenMPOptCGSCCLegacyPass::ID = 0; 2659 2660 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", 2661 "OpenMP specific optimizations", false, false) 2662 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) 2663 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", 2664 "OpenMP specific optimizations", false, false) 2665 2666 Pass *llvm::createOpenMPOptCGSCCLegacyPass() { 2667 return new OpenMPOptCGSCCLegacyPass(); 2668 } 2669
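// Usage note (a sketch, not normative): the legacy pass registered above can
// be exercised with the legacy pass manager as, e.g.,
//   opt -openmp-opt-cgscc -S input.ll
// The new-pass-manager entry points OpenMPOptPass and OpenMPOptCGSCCPass are
// scheduled through the pass builder pipelines rather than this registration.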