//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpenMP specific optimizations:
//
// - Deduplication of runtime calls, e.g., omp_get_thread_num.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/OpenMPOpt.h"

#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"

using namespace llvm;
using namespace omp;

#define DEBUG_TYPE "openmp-opt"

static cl::opt<bool> DisableOpenMPOptimizations(
    "openmp-opt-disable", cl::ZeroOrMore,
    cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
    cl::init(false));

static cl::opt<bool> EnableParallelRegionMerging(
    "openmp-opt-enable-merging", cl::ZeroOrMore,
    cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
    cl::init(false));

static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
                                    cl::Hidden);
static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
                                        cl::init(false), cl::Hidden);

static cl::opt<bool> HideMemoryTransferLatency(
    "openmp-hide-memory-transfer-latency",
    cl::desc("[WIP] Tries to hide the latency of host to device memory"
             " transfers"),
    cl::Hidden, cl::init(false));

STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");
STATISTIC(NumOpenMPTargetRegionKernels,
          "Number of OpenMP target region entry points (=kernels) identified");
STATISTIC(
    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
STATISTIC(NumOpenMPParallelRegionsMerged,
          "Number of OpenMP parallel regions merged");

#if !defined(NDEBUG)
static constexpr auto TAG = "[" DEBUG_TYPE "]";
#endif

namespace {

struct AAExecutionDomain
    : public StateWrapper<BooleanState, AbstractAttribute> {
  using Base = StateWrapper<BooleanState, AbstractAttribute>;
  AAExecutionDomain(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

  /// Create an abstract attribute view for the position \p IRP.
  static AAExecutionDomain &createForPosition(const IRPosition &IRP,
                                              Attributor &A);

  /// See AbstractAttribute::getName().
  const std::string getName() const override { return "AAExecutionDomain"; }

  /// See AbstractAttribute::getIdAddr().
  const char *getIdAddr() const override { return &ID; }

  /// Check if an instruction is executed by a single thread.
  virtual bool isSingleThreadExecution(const Instruction &) const = 0;

  virtual bool isSingleThreadExecution(const BasicBlock &) const = 0;

  /// This function should return true if the type of the \p AA is
  /// AAExecutionDomain.
  static bool classof(const AbstractAttribute *AA) {
    return (AA->getIdAddr() == &ID);
  }

  /// Unique ID (due to the unique address)
  static const char ID;
};

struct AAICVTracker;

/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
/// Attributor runs.
struct OMPInformationCache : public InformationCache {
  OMPInformationCache(Module &M, AnalysisGetter &AG,
                      BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
                      SmallPtrSetImpl<Kernel> &Kernels)
      : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
        Kernels(Kernels) {

    OMPBuilder.initialize();
    initializeRuntimeFunctions();
    initializeInternalControlVars();
  }

  /// Generic information that describes an internal control variable.
  struct InternalControlVarInfo {
    /// The kind, as described by InternalControlVar enum.
    InternalControlVar Kind;

    /// The name of the ICV.
    StringRef Name;

    /// Environment variable associated with this ICV.
    StringRef EnvVarName;

    /// Initial value kind.
    ICVInitValue InitKind;

    /// Initial value.
    ConstantInt *InitValue;

    /// Setter RTL function associated with this ICV.
    RuntimeFunction Setter;

    /// Getter RTL function associated with this ICV.
    RuntimeFunction Getter;

    /// RTL Function corresponding to the override clause of this ICV
    RuntimeFunction Clause;
  };

  /// Generic information that describes a runtime function
  struct RuntimeFunctionInfo {

    /// The kind, as described by the RuntimeFunction enum.
    RuntimeFunction Kind;

    /// The name of the function.
    StringRef Name;

    /// Flag to indicate a variadic function.
    bool IsVarArg;

    /// The return type of the function.
    Type *ReturnType;

    /// The argument types of the function.
    SmallVector<Type *, 8> ArgumentTypes;

    /// The declaration if available.
    Function *Declaration = nullptr;

    /// Uses of this runtime function per function containing the use.
    using UseVector = SmallVector<Use *, 16>;

    /// Clear UsesMap for runtime function.
    void clearUsesMap() { UsesMap.clear(); }

    /// Boolean conversion that is true if the runtime function was found.
    operator bool() const { return Declaration; }

    /// Return the vector of uses in function \p F.
    UseVector &getOrCreateUseVector(Function *F) {
      std::shared_ptr<UseVector> &UV = UsesMap[F];
      if (!UV)
        UV = std::make_shared<UseVector>();
      return *UV;
    }

    /// Return the vector of uses in function \p F or `nullptr` if there are
    /// none.
    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();
      return nullptr;
    }

    /// Return how many functions contain uses of this runtime function.
    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    /// Return the number of arguments (or the minimal number for variadic
    /// functions).
    size_t getNumArgs() const { return ArgumentTypes.size(); }

    /// Run the callback \p CB on each use and forget the use if the result is
    /// true. The callback will be fed the function in which the use was
    /// encountered as second argument.
    void foreachUse(SmallVectorImpl<Function *> &SCC,
                    function_ref<bool(Use &, Function &)> CB) {
      for (Function *F : SCC)
        foreachUse(CB, F);
    }

    /// Run the callback \p CB on each use within the function \p F and forget
    /// the use if the result is true.
    void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
      SmallVector<unsigned, 8> ToBeDeleted;
      ToBeDeleted.clear();

      unsigned Idx = 0;
      UseVector &UV = getOrCreateUseVector(F);

      for (Use *U : UV) {
        if (CB(*U, *F))
          ToBeDeleted.push_back(Idx);
        ++Idx;
      }

      // Remove the to-be-deleted indices in reverse order as prior
      // modifications will not modify the smaller indices.
      while (!ToBeDeleted.empty()) {
        unsigned Idx = ToBeDeleted.pop_back_val();
        UV[Idx] = UV.back();
        UV.pop_back();
      }
    }

  private:
    /// Map from functions to all uses of this runtime function contained in
    /// them.
    DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
  };

  /// An OpenMP-IR-Builder instance
  OpenMPIRBuilder OMPBuilder;

  /// Map from runtime function kind to the runtime function description.
  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
                  RuntimeFunction::OMPRTL___last>
      RFIs;

  /// Map from ICV kind to the ICV description.
  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVs;

  /// Helper to initialize all internal control variable information for those
  /// defined in OMPKinds.def.
  void initializeInternalControlVars() {
#define ICV_RT_SET(_Name, RTL) \
  { \
    auto &ICV = ICVs[_Name]; \
    ICV.Setter = RTL; \
  }
#define ICV_RT_GET(Name, RTL) \
  { \
    auto &ICV = ICVs[Name]; \
    ICV.Getter = RTL; \
  }
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
  { \
    auto &ICV = ICVs[Enum]; \
    ICV.Name = _Name; \
    ICV.Kind = Enum; \
    ICV.InitKind = Init; \
    ICV.EnvVarName = _EnvVarName; \
    switch (ICV.InitKind) { \
    case ICV_IMPLEMENTATION_DEFINED: \
      ICV.InitValue = nullptr; \
      break; \
    case ICV_ZERO: \
      ICV.InitValue = ConstantInt::get( \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
      break; \
    case ICV_FALSE: \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
      break; \
    case ICV_LAST: \
      break; \
    } \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  /// Returns true if the function declaration \p F matches the runtime
  /// function types, that is, return type \p RTFRetType, and argument types
  /// \p RTFArgTypes.
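  /// For example, omp_get_thread_num is expected to have return type Int32 and
  /// an empty argument list; a declaration with a different prototype is not
  /// treated as the runtime function.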
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
                                  SmallVector<Type *, 8> &RTFArgTypes) {
    // TODO: We should output information to the user (under debug output
    //       and via remarks).

    if (!F)
      return false;
    if (F->getReturnType() != RTFRetType)
      return false;
    if (F->arg_size() != RTFArgTypes.size())
      return false;

    auto RTFTyIt = RTFArgTypes.begin();
    for (Argument &Arg : F->args()) {
      if (Arg.getType() != *RTFTyIt)
        return false;

      ++RTFTyIt;
    }

    return true;
  }

  // Helper to collect all uses of the declaration in the UsesMap.
  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
    unsigned NumUses = 0;
    if (!RFI.Declaration)
      return NumUses;
    OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);

    if (CollectStats) {
      NumOpenMPRuntimeFunctionsIdentified += 1;
      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
    }

    // TODO: We directly convert uses into proper calls and unknown uses.
    for (Use &U : RFI.Declaration->uses()) {
      if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
        if (ModuleSlice.count(UserI->getFunction())) {
          RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
          ++NumUses;
        }
      } else {
        RFI.getOrCreateUseVector(nullptr).push_back(&U);
        ++NumUses;
      }
    }
    return NumUses;
  }

  // Helper function to recollect uses of a runtime function.
  void recollectUsesForFunction(RuntimeFunction RTF) {
    auto &RFI = RFIs[RTF];
    RFI.clearUsesMap();
    collectUses(RFI, /*CollectStats*/ false);
  }

  // Helper function to recollect uses of all runtime functions.
  void recollectUses() {
    for (int Idx = 0; Idx < RFIs.size(); ++Idx)
      recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
  }

  /// Helper to initialize all runtime function information for those defined
  /// in OpenMPKinds.def.
  void initializeRuntimeFunctions() {
    Module &M = *((*ModuleSlice.begin())->getParent());

    // Helper macros for handling __VA_ARGS__ in OMP_RTL
#define OMP_TYPE(VarName, ...) \
  Type *VarName = OMPBuilder.VarName; \
  (void)VarName;

#define OMP_ARRAY_TYPE(VarName, ...) \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
  (void)VarName##Ty; \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
  (void)VarName##PtrTy;

#define OMP_FUNCTION_TYPE(VarName, ...) \
  FunctionType *VarName = OMPBuilder.VarName; \
  (void)VarName; \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
  (void)VarName##Ptr;

#define OMP_STRUCT_TYPE(VarName, ...) \
  StructType *VarName = OMPBuilder.VarName; \
  (void)VarName; \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
  (void)VarName##Ptr;

#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
  { \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
    Function *F = M.getFunction(_Name); \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
      auto &RFI = RFIs[_Enum]; \
      RFI.Kind = _Enum; \
      RFI.Name = _Name; \
      RFI.IsVarArg = _IsVarArg; \
      RFI.ReturnType = OMPBuilder._ReturnType; \
      RFI.ArgumentTypes = std::move(ArgsTypes); \
      RFI.Declaration = F; \
      unsigned NumUses = collectUses(RFI); \
      (void)NumUses; \
      LLVM_DEBUG({ \
        dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
               << " found\n"; \
        if (RFI.Declaration) \
          dbgs() << TAG << "-> got " << NumUses << " uses in " \
                 << RFI.getNumFunctionsWithUses() \
                 << " different functions.\n"; \
      }); \
    } \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"

    // TODO: We should attach the attributes defined in OMPKinds.def.
  }

  /// Collection of known kernels (\see Kernel) in the module.
  SmallPtrSetImpl<Kernel> &Kernels;
};

/// Used to map the values physically (in the IR) stored in an offload
/// array, to a vector in memory.
struct OffloadArray {
  /// Physical array (in the IR).
  AllocaInst *Array = nullptr;
  /// Mapped values.
  SmallVector<Value *, 8> StoredValues;
  /// Last stores made in the offload array.
  SmallVector<StoreInst *, 8> LastAccesses;

  OffloadArray() = default;

  /// Initializes the OffloadArray with the values stored in \p Array before
  /// instruction \p Before is reached. Returns false if the initialization
  /// fails.
  /// This MUST be used immediately after the construction of the object.
  bool initialize(AllocaInst &Array, Instruction &Before) {
    if (!Array.getAllocatedType()->isArrayTy())
      return false;

    if (!getValues(Array, Before))
      return false;

    this->Array = &Array;
    return true;
  }

  static const unsigned DeviceIDArgNum = 1;
  static const unsigned BasePtrsArgNum = 3;
  static const unsigned PtrsArgNum = 4;
  static const unsigned SizesArgNum = 5;

private:
  /// Traverses the BasicBlock where \p Array is, collecting the stores made to
  /// \p Array, leaving StoredValues with the values stored before the
  /// instruction \p Before is reached.
  bool getValues(AllocaInst &Array, Instruction &Before) {
    // Initialize container.
    const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
    StoredValues.assign(NumValues, nullptr);
    LastAccesses.assign(NumValues, nullptr);

    // TODO: This assumes the instruction \p Before is in the same
    //       BasicBlock as Array. Make it general, for any control flow graph.
    BasicBlock *BB = Array.getParent();
    if (BB != Before.getParent())
      return false;

    const DataLayout &DL = Array.getModule()->getDataLayout();
    const unsigned int PointerSize = DL.getPointerSize();

    for (Instruction &I : *BB) {
      if (&I == &Before)
        break;

      if (!isa<StoreInst>(&I))
        continue;

      auto *S = cast<StoreInst>(&I);
      int64_t Offset = -1;
      auto *Dst =
          GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
      if (Dst == &Array) {
        int64_t Idx = Offset / PointerSize;
        StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
        LastAccesses[Idx] = S;
      }
    }

    return isFilled();
  }

  /// Returns true if all values in StoredValues and
  /// LastAccesses are not nullptrs.
  bool isFilled() {
    const unsigned NumValues = StoredValues.size();
    for (unsigned I = 0; I < NumValues; ++I) {
      if (!StoredValues[I] || !LastAccesses[I])
        return false;
    }

    return true;
  }
};

struct OpenMPOpt {

  using OptimizationRemarkGetter =
      function_ref<OptimizationRemarkEmitter &(Function *)>;

  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
            OptimizationRemarkGetter OREGetter,
            OMPInformationCache &OMPInfoCache, Attributor &A)
      : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
        OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}

  /// Check if any remarks are enabled for openmp-opt
  bool remarksEnabled() {
    auto &Ctx = M.getContext();
    return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
  }

  /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
  bool run(bool IsModulePass) {
    if (SCC.empty())
      return false;

    bool Changed = false;

    LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
                      << " functions in a slice with "
                      << OMPInfoCache.ModuleSlice.size() << " functions\n");

    if (IsModulePass) {
      Changed |= runAttributor();

      if (remarksEnabled())
        analysisGlobalization();
    } else {
      if (PrintICVValues)
        printICVs();
      if (PrintOpenMPKernels)
        printKernels();

      Changed |= rewriteDeviceCodeStateMachine();

      Changed |= runAttributor();

      // Recollect uses, in case Attributor deleted any.
      OMPInfoCache.recollectUses();

      Changed |= deleteParallelRegions();
      if (HideMemoryTransferLatency)
        Changed |= hideMemTransfersLatency();
      Changed |= deduplicateRuntimeCalls();
      if (EnableParallelRegionMerging) {
        if (mergeParallelRegions()) {
          deduplicateRuntimeCalls();
          Changed = true;
        }
      }
    }

    return Changed;
  }

  /// Print initial ICV values for testing.
  /// FIXME: This should be done from the Attributor once it is added.
  void printICVs() const {
    InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
                                 ICV_proc_bind};

    for (Function *F : OMPInfoCache.ModuleSlice) {
      for (auto ICV : ICVs) {
        auto ICVInfo = OMPInfoCache.ICVs[ICV];
        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
          return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
                     << " Value: "
                     << (ICVInfo.InitValue
                             ? toString(ICVInfo.InitValue->getValue(), 10, true)
                             : "IMPLEMENTATION_DEFINED");
        };

        emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
      }
    }
  }

  /// Print OpenMP GPU kernels for testing.
  void printKernels() const {
    for (Function *F : SCC) {
      if (!OMPInfoCache.Kernels.count(F))
        continue;

      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA << "OpenMP GPU kernel "
                   << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
      };

      emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
    }
  }

  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
  /// given it has to be the callee or a nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(U.getUser());
    if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
        (!RFI || CI->getCalledFunction() == RFI->Declaration))
      return CI;
    return nullptr;
  }

  /// Return the call if \p V is a regular call. If \p RFI is given it has to
  /// be the callee or a nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(&V);
    if (CI && !CI->hasOperandBundles() &&
        (!RFI || CI->getCalledFunction() == RFI->Declaration))
      return CI;
    return nullptr;
  }

private:
  /// Merge parallel regions when it is safe.
  bool mergeParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;
    const unsigned CallbackFirstArgOperand = 3;
    using InsertPointTy = OpenMPIRBuilder::InsertPointTy;

    // Check if there are any __kmpc_fork_call calls to merge.
    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    // Unmergable calls that prevent merging a parallel region.
    OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
    };

    bool Changed = false;
    LoopInfo *LI = nullptr;
    DominatorTree *DT = nullptr;

    SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;

    BasicBlock *StartBB = nullptr, *EndBB = nullptr;
    auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                         BasicBlock &ContinuationIP) {
      BasicBlock *CGStartBB = CodeGenIP.getBlock();
      BasicBlock *CGEndBB =
          SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
      assert(StartBB != nullptr && "StartBB should not be null");
      CGStartBB->getTerminator()->setSuccessor(0, StartBB);
      assert(EndBB != nullptr && "EndBB should not be null");
      EndBB->getTerminator()->setSuccessor(0, CGEndBB);
    };

    auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
                      Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
      ReplacementValue = &Inner;
      return CodeGenIP;
    };

    auto FiniCB = [&](InsertPointTy CodeGenIP) {};

    /// Create a sequential execution region within a merged parallel region,
    /// encapsulated in a master construct with a barrier for synchronization.
    auto CreateSequentialRegion = [&](Function *OuterFn,
                                      BasicBlock *OuterPredBB,
                                      Instruction *SeqStartI,
                                      Instruction *SeqEndI) {
      // Isolate the instructions of the sequential region to a separate
      // block.
      BasicBlock *ParentBB = SeqStartI->getParent();
      BasicBlock *SeqEndBB =
          SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
      BasicBlock *SeqAfterBB =
          SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
      BasicBlock *SeqStartBB =
          SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");

      assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
             "Expected a different CFG");
      const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
      ParentBB->getTerminator()->eraseFromParent();

      auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                           BasicBlock &ContinuationIP) {
        BasicBlock *CGStartBB = CodeGenIP.getBlock();
        BasicBlock *CGEndBB =
            SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
        assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
        CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
        assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
        SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
      };
      auto FiniCB = [&](InsertPointTy CodeGenIP) {};

      // Find outputs from the sequential region to outside users and
      // broadcast their values to them.
      for (Instruction &I : *SeqStartBB) {
        SmallPtrSet<Instruction *, 4> OutsideUsers;
        for (User *Usr : I.users()) {
          Instruction &UsrI = *cast<Instruction>(Usr);
          // Ignore outputs to lifetime intrinsics, code extraction for the
          // merged parallel region will fix them.
          if (UsrI.isLifetimeStartOrEnd())
            continue;

          if (UsrI.getParent() != SeqStartBB)
            OutsideUsers.insert(&UsrI);
        }

        if (OutsideUsers.empty())
          continue;

        // Emit an alloca in the outer region to store the broadcasted
        // value.
        const DataLayout &DL = M.getDataLayout();
        AllocaInst *AllocaI = new AllocaInst(
            I.getType(), DL.getAllocaAddrSpace(), nullptr,
            I.getName() + ".seq.output.alloc", &OuterFn->front().front());

        // Emit a store instruction in the sequential BB to update the
        // value.
        new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());

        // Emit a load instruction and replace the use of the output value
        // with it.
        for (Instruction *UsrI : OutsideUsers) {
          LoadInst *LoadI = new LoadInst(
              I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
          UsrI->replaceUsesOfWith(&I, LoadI);
        }
      }

      OpenMPIRBuilder::LocationDescription Loc(
          InsertPointTy(ParentBB, ParentBB->end()), DL);
      InsertPointTy SeqAfterIP =
          OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);

      OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);

      BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());

      LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
                        << "\n");
    };

    // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
    // contained in BB and only separated by instructions that can be
    // redundantly executed in parallel. The block BB is split before the first
    // call (in MergableCIs) and after the last so the entire region we merge
    // into a single parallel region is contained in a single basic block
    // without any other instructions. We use the OpenMPIRBuilder to outline
    // that block and call the resulting function via __kmpc_fork_call.
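    //
    // As a rough, illustrative sketch (the names below are made up, not taken
    // from real IR), the merge turns
    //
    //   call void @__kmpc_fork_call(..., @.omp_outlined.0, ...)
    //   %tmp = add i32 %a, %b   ; in-between code
    //   call void @__kmpc_fork_call(..., @.omp_outlined.1, ...)
    //
    // into a single fork of one merged outlined function that calls
    // @.omp_outlined.0 and @.omp_outlined.1 directly, wraps the in-between
    // code in a master construct followed by a barrier (see
    // CreateSequentialRegion above), and inserts an explicit barrier between
    // the former regions to preserve the implicit fork-join barrier semantics.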
    auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
      // TODO: Change the interface to allow single CIs expanded, e.g., to
      //       include an outer loop.
      assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");

      auto Remark = [&](OptimizationRemark OR) {
        OR << "Parallel region at "
           << ore::NV("OpenMPParallelMergeFront",
                      MergableCIs.front()->getDebugLoc())
           << " merged with parallel regions at ";
        for (auto *CI : llvm::drop_begin(MergableCIs)) {
          OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
          if (CI != MergableCIs.back())
            OR << ", ";
        }
        return OR;
      };

      emitRemark<OptimizationRemark>(MergableCIs.front(),
                                     "OpenMPParallelRegionMerging", Remark);

      Function *OriginalFn = BB->getParent();
      LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
                        << " parallel regions in " << OriginalFn->getName()
                        << "\n");

      // Isolate the calls to merge in a separate block.
      EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
      BasicBlock *AfterBB =
          SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
      StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
                           "omp.par.merged");

      assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
      const DebugLoc DL = BB->getTerminator()->getDebugLoc();
      BB->getTerminator()->eraseFromParent();

      // Create sequential regions for sequential instructions that are
      // in-between mergable parallel regions.
      for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
           It != End; ++It) {
        Instruction *ForkCI = *It;
        Instruction *NextForkCI = *(It + 1);

        // Continue if there are no in-between instructions.
        if (ForkCI->getNextNode() == NextForkCI)
          continue;

        CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
                               NextForkCI->getPrevNode());
      }

      OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
                                               DL);
      IRBuilder<>::InsertPoint AllocaIP(
          &OriginalFn->getEntryBlock(),
          OriginalFn->getEntryBlock().getFirstInsertionPt());
      // Create the merged parallel region with default proc binding, to
      // avoid overriding binding settings, and without explicit cancellation.
      InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
          Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
          OMP_PROC_BIND_default, /* IsCancellable */ false);
      BranchInst::Create(AfterBB, AfterIP.getBlock());

      // Perform the actual outlining.
      OMPInfoCache.OMPBuilder.finalize(OriginalFn,
                                       /* AllowExtractorSinking */ true);

      Function *OutlinedFn = MergableCIs.front()->getCaller();

      // Replace the __kmpc_fork_call calls with direct calls to the outlined
      // callbacks.
      SmallVector<Value *, 8> Args;
      for (auto *CI : MergableCIs) {
        Value *Callee =
            CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
        FunctionType *FT =
            cast<FunctionType>(Callee->getType()->getPointerElementType());
        Args.clear();
        Args.push_back(OutlinedFn->getArg(0));
        Args.push_back(OutlinedFn->getArg(1));
        for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
             U < E; ++U)
          Args.push_back(CI->getArgOperand(U));

        CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
        if (CI->getDebugLoc())
          NewCI->setDebugLoc(CI->getDebugLoc());

        // Forward parameter attributes from the callback to the callee.
        for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
             U < E; ++U)
          for (const Attribute &A : CI->getAttributes().getParamAttributes(U))
            NewCI->addParamAttr(
                U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);

        // Emit an explicit barrier to replace the implicit fork-join barrier.
        if (CI != MergableCIs.back()) {
          // TODO: Remove barrier if the merged parallel region includes the
          //       'nowait' clause.
          OMPInfoCache.OMPBuilder.createBarrier(
              InsertPointTy(NewCI->getParent(),
                            NewCI->getNextNode()->getIterator()),
              OMPD_parallel);
        }

        auto Remark = [&](OptimizationRemark OR) {
          return OR << "Parallel region at "
                    << ore::NV("OpenMPParallelMerge", CI->getDebugLoc())
                    << " merged with "
                    << ore::NV("OpenMPParallelMergeFront",
                               MergableCIs.front()->getDebugLoc());
        };
        if (CI != MergableCIs.front())
          emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionMerging",
                                         Remark);

        CI->eraseFromParent();
      }

      assert(OutlinedFn != OriginalFn && "Outlining failed");
      CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
      CGUpdater.reanalyzeFunction(*OriginalFn);

      NumOpenMPParallelRegionsMerged += MergableCIs.size();

      return true;
    };

    // Helper function that identifies sequences of
    // __kmpc_fork_call uses in a basic block.
    auto DetectPRsCB = [&](Use &U, Function &F) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      BB2PRMap[CI->getParent()].insert(CI);

      return false;
    };

    BB2PRMap.clear();
    RFI.foreachUse(SCC, DetectPRsCB);
    SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
    // Find mergable parallel regions within a basic block that are
    // safe to merge, that is any in-between instructions can safely
    // execute in parallel after merging.
    // TODO: support merging across basic-blocks.
    for (auto &It : BB2PRMap) {
      auto &CIs = It.getSecond();
      if (CIs.size() < 2)
        continue;

      BasicBlock *BB = It.getFirst();
      SmallVector<CallInst *, 4> MergableCIs;

      /// Returns true if the instruction is mergable, false otherwise.
      /// A terminator instruction is unmergable by definition since merging
      /// works within a BB. Instructions before the mergable region are
      /// mergable if they are not calls to OpenMP runtime functions that may
      /// set different execution parameters for subsequent parallel regions.
      /// Instructions in-between parallel regions are mergable if they are not
      /// calls to any non-intrinsic function since that may call a non-mergable
      /// OpenMP runtime function.
      auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
        // We do not merge across BBs, hence return false (unmergable) if the
        // instruction is a terminator.
        if (I.isTerminator())
          return false;

        if (!isa<CallInst>(&I))
          return true;

        CallInst *CI = cast<CallInst>(&I);
        if (IsBeforeMergableRegion) {
          Function *CalledFunction = CI->getCalledFunction();
          if (!CalledFunction)
            return false;
          // Return false (unmergable) if the call before the parallel
          // region calls an explicit affinity (proc_bind) or number of
          // threads (num_threads) compiler-generated function. Those settings
          // may be incompatible with following parallel regions.
          // TODO: ICV tracking to detect compatibility.
          for (const auto &RFI : UnmergableCallsInfo) {
            if (CalledFunction == RFI.Declaration)
              return false;
          }
        } else {
          // Return false (unmergable) if there is a call instruction
          // in-between parallel regions when it is not an intrinsic. It
          // may call an unmergable OpenMP runtime function in its callpath.
          // TODO: Keep track of possible OpenMP calls in the callpath.
          if (!isa<IntrinsicInst>(CI))
            return false;
        }

        return true;
      };
      // Find maximal number of parallel region CIs that are safe to merge.
      for (auto It = BB->begin(), End = BB->end(); It != End;) {
        Instruction &I = *It;
        ++It;

        if (CIs.count(&I)) {
          MergableCIs.push_back(cast<CallInst>(&I));
          continue;
        }

        // Continue expanding if the instruction is mergable.
        if (IsMergable(I, MergableCIs.empty()))
          continue;

        // Forward the instruction iterator to skip the next parallel region
        // since there is an unmergable instruction which can affect it.
        for (; It != End; ++It) {
          Instruction &SkipI = *It;
          if (CIs.count(&SkipI)) {
            LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
                              << " due to " << I << "\n");
            ++It;
            break;
          }
        }

        // Store mergable regions found.
        if (MergableCIs.size() > 1) {
          MergableCIsVector.push_back(MergableCIs);
          LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
                            << " parallel regions in block " << BB->getName()
                            << " of function " << BB->getParent()->getName()
                            << "\n";);
        }

        MergableCIs.clear();
      }

      if (!MergableCIsVector.empty()) {
        Changed = true;

        for (auto &MergableCIs : MergableCIsVector)
          Merge(MergableCIs, BB);
        MergableCIsVector.clear();
      }
    }

    if (Changed) {
      /// Re-collect uses for fork calls, emitted barrier calls, and
      /// any emitted master/end_master calls.
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
    }

    return Changed;
  }

  /// Try to delete parallel regions if possible.
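  ///
  /// Illustrative sketch (the names are hypothetical, not from real IR): a
  /// fork call such as
  ///
  ///   call void (...) @__kmpc_fork_call(%struct.ident_t* @loc, i32 0,
  ///                                     void (...)* @.omp_outlined.readonly)
  ///
  /// whose outlined callback only reads memory (readonly) and is guaranteed
  /// to return (willreturn) has no observable effect, so the call is erased.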
  bool deleteParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    bool Changed = false;
    auto DeleteCallCB = [&](Use &U, Function &) {
      CallInst *CI = getCallIfRegularCall(U);
      if (!CI)
        return false;
      auto *Fn = dyn_cast<Function>(
          CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
      if (!Fn)
        return false;
      if (!Fn->onlyReadsMemory())
        return false;
      if (!Fn->hasFnAttribute(Attribute::WillReturn))
        return false;

      LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
                        << CI->getCaller()->getName() << "\n");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Parallel region in "
                  << ore::NV("OpenMPParallelDelete", CI->getCaller()->getName())
                  << " deleted";
      };
      emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionDeletion",
                                     Remark);

      CGUpdater.removeCallSite(*CI);
      CI->eraseFromParent();
      Changed = true;
      ++NumOpenMPParallelRegionsDeleted;
      return true;
    };

    RFI.foreachUse(SCC, DeleteCallCB);

    return Changed;
  }

  /// Try to eliminate runtime calls by reusing existing ones.
  bool deduplicateRuntimeCalls() {
    bool Changed = false;

    RuntimeFunction DeduplicableRuntimeCallIDs[] = {
        OMPRTL_omp_get_num_threads,
        OMPRTL_omp_in_parallel,
        OMPRTL_omp_get_cancellation,
        OMPRTL_omp_get_thread_limit,
        OMPRTL_omp_get_supported_active_levels,
        OMPRTL_omp_get_level,
        OMPRTL_omp_get_ancestor_thread_num,
        OMPRTL_omp_get_team_size,
        OMPRTL_omp_get_active_level,
        OMPRTL_omp_in_final,
        OMPRTL_omp_get_proc_bind,
        OMPRTL_omp_get_num_places,
        OMPRTL_omp_get_num_procs,
        OMPRTL_omp_get_place_num,
        OMPRTL_omp_get_partition_num_places,
        OMPRTL_omp_get_partition_place_nums};

    // Global-tid is handled separately.
    SmallSetVector<Value *, 16> GTIdArgs;
    collectGlobalThreadIdArguments(GTIdArgs);
    LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
                      << " global thread ID arguments\n");

    for (Function *F : SCC) {
      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);

      // __kmpc_global_thread_num is special as we can replace it with an
      // argument in enough cases to make it worth trying.
      Value *GTIdArg = nullptr;
      for (Argument &Arg : F->args())
        if (GTIdArgs.count(&Arg)) {
          GTIdArg = &Arg;
          break;
        }
      Changed |= deduplicateRuntimeCalls(
          *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
    }

    return Changed;
  }

  /// Tries to hide the latency of runtime calls that involve host to
  /// device memory transfers by splitting them into their "issue" and "wait"
  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
  /// moved downwards as much as possible. The "issue" issues the memory
  /// transfer asynchronously, returning a handle. The "wait" waits on the
  /// returned handle for the memory transfer to finish.
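  ///
  /// Illustrative sketch of the rewrite (value names are hypothetical):
  ///
  ///   call void @__tgt_target_data_begin_mapper(...)
  ///   ; code without side effects
  ///
  /// becomes
  ///
  ///   %handle = alloca %struct.__tgt_async_info
  ///   call void @__tgt_target_data_begin_mapper_issue(..., %handle)
  ///   ; code the asynchronous transfer can now overlap with
  ///   call void @__tgt_target_data_begin_mapper_wait(i64 %device_id, %handle)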
  bool hideMemTransfersLatency() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
    bool Changed = false;
    auto SplitMemTransfers = [&](Use &U, Function &Decl) {
      auto *RTCall = getCallIfRegularCall(U, &RFI);
      if (!RTCall)
        return false;

      OffloadArray OffloadArrays[3];
      if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
        return false;

      LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));

      // TODO: Check if can be moved upwards.
      bool WasSplit = false;
      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
      if (WaitMovementPoint)
        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);

      Changed |= WasSplit;
      return WasSplit;
    };
    RFI.foreachUse(SCC, SplitMemTransfers);

    return Changed;
  }

  void analysisGlobalization() {
    RuntimeFunction GlobalizationRuntimeIDs[] = {
        OMPRTL___kmpc_data_sharing_coalesced_push_stack,
        OMPRTL___kmpc_data_sharing_push_stack};

    for (const auto GlobalizationCallID : GlobalizationRuntimeIDs) {
      auto &RFI = OMPInfoCache.RFIs[GlobalizationCallID];

      auto CheckGlobalization = [&](Use &U, Function &Decl) {
        if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
          auto Remark = [&](OptimizationRemarkAnalysis ORA) {
            return ORA
                   << "Found thread data sharing on the GPU. "
                   << "Expect degraded performance due to data globalization.";
          };
          emitRemark<OptimizationRemarkAnalysis>(CI, "OpenMPGlobalization",
                                                 Remark);
        }

        return false;
      };

      RFI.foreachUse(SCC, CheckGlobalization);
    }
  }

  /// Maps the values stored in the offload arrays passed as arguments to
  /// \p RuntimeCall into the offload arrays in \p OAs.
  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
                                MutableArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "Need space for three offload arrays!");

    // A runtime call that involves memory offloading looks something like:
    // call void @__tgt_target_data_begin_mapper(arg0, arg1,
    //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
    // ...)
    // So, the idea is to access the allocas that allocate space for these
    // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
    // Therefore:
    // i8** %offload_baseptrs.
    Value *BasePtrsArg =
        RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
    // i8** %offload_ptrs.
    Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
    // i64* %offload_sizes.
    Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);

    // Get values stored in **offload_baseptrs.
    auto *V = getUnderlyingObject(BasePtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *BasePtrsArray = cast<AllocaInst>(V);
    if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
      return false;

    // Get values stored in **offload_ptrs.
    V = getUnderlyingObject(PtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *PtrsArray = cast<AllocaInst>(V);
    if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
      return false;

    // Get values stored in **offload_sizes.
    V = getUnderlyingObject(SizesArg);
    // If it's a [constant] global array don't analyze it.
    if (isa<GlobalValue>(V))
      return isa<Constant>(V);
    if (!isa<AllocaInst>(V))
      return false;

    auto *SizesArray = cast<AllocaInst>(V);
    if (!OAs[2].initialize(*SizesArray, RuntimeCall))
      return false;

    return true;
  }

  /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
  /// For now this is a way to test that the function getValuesInOffloadArrays
  /// is working properly.
  /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
  void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "There are three offload arrays to debug!");

    LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
    std::string ValuesStr;
    raw_string_ostream Printer(ValuesStr);
    std::string Separator = " --- ";

    for (auto *BP : OAs[0].StoredValues) {
      BP->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
    ValuesStr.clear();

    for (auto *P : OAs[1].StoredValues) {
      P->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
    ValuesStr.clear();

    for (auto *S : OAs[2].StoredValues) {
      S->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
  }

  /// Returns the instruction where the "wait" counterpart of \p RuntimeCall
  /// can be moved. Returns nullptr if the movement is not possible, or not
  /// worth it.
  Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
    // FIXME: This traverses only the BasicBlock where RuntimeCall is.
    //        Make it traverse the CFG.

    Instruction *CurrentI = &RuntimeCall;
    bool IsWorthIt = false;
    while ((CurrentI = CurrentI->getNextNode())) {

      // TODO: Once we detect the regions to be offloaded we should use the
      //       alias analysis manager to check if CurrentI may modify one of
      //       the offloaded regions.
      if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
        if (IsWorthIt)
          return CurrentI;

        return nullptr;
      }

      // FIXME: For now, moving the "wait" over anything without side effects
      //        is considered worth it.
      IsWorthIt = true;
    }

    // Return end of BasicBlock.
    return RuntimeCall.getParent()->getTerminator();
  }

  /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
                               Instruction &WaitMovementPoint) {
    // Create a stack-allocated handle (__tgt_async_info) at the beginning of
    // the function. It stores information about the async transfer, allowing
    // us to wait on it later.
    auto &IRBuilder = OMPInfoCache.OMPBuilder;
    auto *F = RuntimeCall.getCaller();
    Instruction *FirstInst = &(F->getEntryBlock().front());
    AllocaInst *Handle = new AllocaInst(
        IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);

    // Add "issue" runtime call declaration:
    // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
    //   i8**, i8**, i64*, i64*)
    FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_issue);

    // Change RuntimeCall call site for its asynchronous version.
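    // The "issue" call takes the original argument list with the handle
    // appended; the original synchronous call is removed afterwards.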
    SmallVector<Value *, 16> Args;
    for (auto &Arg : RuntimeCall.args())
      Args.push_back(Arg.get());
    Args.push_back(Handle);

    CallInst *IssueCallsite =
        CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
    RuntimeCall.eraseFromParent();

    // Add "wait" runtime call declaration:
    // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
    FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_wait);

    Value *WaitParams[2] = {
        IssueCallsite->getArgOperand(
            OffloadArray::DeviceIDArgNum), // device_id.
        Handle                             // handle to wait on.
    };
    CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);

    return true;
  }

  static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
                                    bool GlobalOnly, bool &SingleChoice) {
    if (CurrentIdent == NextIdent)
      return CurrentIdent;

    // TODO: Figure out how to actually combine multiple debug locations. For
    //       now we just keep an existing one if there is a single choice.
    if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
      SingleChoice = !CurrentIdent;
      return NextIdent;
    }
    return nullptr;
  }

  /// Return a `struct ident_t*` value that represents the ones used in the
  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
  /// return value we create one from scratch. We also do not yet combine
  /// information, e.g., the source locations, see combinedIdentStruct.
  Value *
  getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
                                 Function &F, bool GlobalOnly) {
    bool SingleChoice = true;
    Value *Ident = nullptr;
    auto CombineIdentStruct = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || &F != &Caller)
        return false;
      Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
                                  /* GlobalOnly */ true, SingleChoice);
      return false;
    };
    RFI.foreachUse(SCC, CombineIdentStruct);

    if (!Ident || !SingleChoice) {
      // The IRBuilder uses the insertion block to get to the module, this is
      // unfortunate but we work around it for now.
      if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
        OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
            &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
      // TODO: Use the debug locations of the calls instead.
      Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
      Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
    }
    return Ident;
  }

  /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
  /// \p ReplVal if given.
  bool deduplicateRuntimeCalls(Function &F,
                               OMPInformationCache::RuntimeFunctionInfo &RFI,
                               Value *ReplVal = nullptr) {
    auto *UV = RFI.getUseVector(F);
    if (!UV || UV->size() + (ReplVal != nullptr) < 2)
      return false;

    LLVM_DEBUG(
        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value\n" : "\n") << "\n");

    assert((!ReplVal || (isa<Argument>(ReplVal) &&
                         cast<Argument>(ReplVal)->getParent() == &F)) &&
           "Unexpected replacement value!");

    // TODO: Use dominance to find a good position instead.
    auto CanBeMoved = [this](CallBase &CB) {
      unsigned NumArgs = CB.getNumArgOperands();
      if (NumArgs == 0)
        return true;
      if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
        return false;
      for (unsigned u = 1; u < NumArgs; ++u)
        if (isa<Instruction>(CB.getArgOperand(u)))
          return false;
      return true;
    };

    if (!ReplVal) {
      for (Use *U : *UV)
        if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
          if (!CanBeMoved(*CI))
            continue;

          auto Remark = [&](OptimizationRemark OR) {
            return OR << "OpenMP runtime call "
                      << ore::NV("OpenMPOptRuntime", RFI.Name)
                      << " moved to beginning of OpenMP region";
          };
          emitRemark<OptimizationRemark>(&F, "OpenMPRuntimeCodeMotion", Remark);

          CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
          ReplVal = CI;
          break;
        }
      if (!ReplVal)
        return false;
    }

    // If we use a call as a replacement value we need to make sure the ident
    // is valid at the new location. For now we just pick a global one, either
    // existing and used by one of the calls, or created from scratch.
    if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
      if (CI->getNumArgOperands() > 0 &&
          CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
        Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
                                                      /* GlobalOnly */ true);
        CI->setArgOperand(0, Ident);
      }
    }

    bool Changed = false;
    auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || CI == ReplVal || &F != &Caller)
        return false;
      assert(CI->getCaller() == &F && "Unexpected call!");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "OpenMP runtime call "
                  << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated";
      };
      emitRemark<OptimizationRemark>(&F, "OpenMPRuntimeDeduplicated", Remark);

      CGUpdater.removeCallSite(*CI);
      CI->replaceAllUsesWith(ReplVal);
      CI->eraseFromParent();
      ++NumOpenMPRuntimeCallsDeduplicated;
      Changed = true;
      return true;
    };
    RFI.foreachUse(SCC, ReplaceAndDeleteCB);

    return Changed;
  }

  /// Collect arguments that represent the global thread id in \p GTIdArgs.
  void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
    // TODO: Below we basically perform a fixpoint iteration with a pessimistic
    //       initialization. We could define an AbstractAttribute instead and
    //       run the Attributor here once it can be run as an SCC pass.

    // Helper to check the argument \p ArgNo at all call sites of \p F for
    // a GTId.
    auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
      if (!F.hasLocalLinkage())
        return false;
      for (Use &U : F.uses()) {
        if (CallInst *CI = getCallIfRegularCall(U)) {
          Value *ArgOp = CI->getArgOperand(ArgNo);
          if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
              getCallIfRegularCall(
                  *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
            continue;
        }
        return false;
      }
      return true;
    };

    // Helper to identify uses of a GTId as GTId arguments.
    auto AddUserArgs = [&](Value &GTId) {
      for (Use &U : GTId.uses())
        if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
          if (CI->isArgOperand(&U))
            if (Function *Callee = CI->getCalledFunction())
              if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
                GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
    };

    // The argument users of __kmpc_global_thread_num calls are GTIds.
    OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];

    GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
      if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
        AddUserArgs(*CI);
      return false;
    });

    // Transitively search for more arguments by looking at the users of the
    // ones we know already. During the search the GTIdArgs vector is extended
    // so we cannot cache the size nor can we use a range based for.
    for (unsigned u = 0; u < GTIdArgs.size(); ++u)
      AddUserArgs(*GTIdArgs[u]);
  }

  /// Kernel (=GPU) optimizations and utility functions
  ///
  ///{{

  /// Check if \p F is a kernel, hence entry point for target offloading.
  bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }

  /// Cache to remember the unique kernel for a function.
  DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;

  /// Find the unique kernel that will execute \p F, if any.
  Kernel getUniqueKernelFor(Function &F);

  /// Find the unique kernel that will execute \p I, if any.
  Kernel getUniqueKernelFor(Instruction &I) {
    return getUniqueKernelFor(*I.getFunction());
  }

  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
  /// the cases we can avoid taking the address of a function.
  bool rewriteDeviceCodeStateMachine();

  ///
  ///}}

  /// Emit a remark generically
  ///
  /// This template function can be used to generically emit a remark. The
  /// RemarkKind should be one of the following:
  ///   - OptimizationRemark to indicate a successful optimization attempt
  ///   - OptimizationRemarkMissed to report a failed optimization attempt
  ///   - OptimizationRemarkAnalysis to provide additional information about an
  ///     optimization attempt
  ///
  /// The remark is built using a callback function provided by the caller that
  /// takes a RemarkKind as input and returns a RemarkKind.
  template <typename RemarkKind, typename RemarkCallBack>
  void emitRemark(Instruction *I, StringRef RemarkName,
                  RemarkCallBack &&RemarkCB) const {
    Function *F = I->getParent()->getParent();
    auto &ORE = OREGetter(F);

    ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
  }

  /// Emit a remark on a function.
  template <typename RemarkKind, typename RemarkCallBack>
  void emitRemark(Function *F, StringRef RemarkName,
                  RemarkCallBack &&RemarkCB) const {
    auto &ORE = OREGetter(F);

    ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
  }

  /// The underlying module.
  Module &M;

  /// The SCC we are operating on.
  SmallVectorImpl<Function *> &SCC;

  /// Callback to update the call graph, the first argument is a removed call,
  /// the second an optional replacement call.
  CallGraphUpdater &CGUpdater;

  /// Callback to get an OptimizationRemarkEmitter from a Function *
  OptimizationRemarkGetter OREGetter;

  /// OpenMP-specific information cache. Also used for Attributor runs.
  OMPInformationCache &OMPInfoCache;

  /// Attributor instance.
  Attributor &A;

  /// Helper function to run Attributor on SCC.
  bool runAttributor() {
    if (SCC.empty())
      return false;

    registerAAs();

    ChangeStatus Changed = A.run();

    LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
                      << " functions, result: " << Changed << ".\n");

    return Changed == ChangeStatus::CHANGED;
  }

  /// Populate the Attributor with abstract attribute opportunities in the
  /// function.
  void registerAAs() {
    if (SCC.empty())
      return;

    // Create CallSite AA for all Getters.
    for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
      auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];

      auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];

      auto CreateAA = [&](Use &U, Function &Caller) {
        CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
        if (!CI)
          return false;

        auto &CB = cast<CallBase>(*CI);

        IRPosition CBPos = IRPosition::callsite_function(CB);
        A.getOrCreateAAFor<AAICVTracker>(CBPos);
        return false;
      };

      GetterRFI.foreachUse(SCC, CreateAA);
    }

    for (auto &F : M) {
      if (!F.isDeclaration())
        A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
    }
  }
};

Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
  if (!OMPInfoCache.ModuleSlice.count(&F))
    return nullptr;

  // Use a scope to keep the lifetime of the CachedKernel short.
  {
    Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
    if (CachedKernel)
      return *CachedKernel;

    // TODO: We should use an AA to create an (optimistic and callback
    //       call-aware) call graph. For now we stick to simple patterns that
    //       are less powerful, basically the worst fixpoint.
    if (isKernel(F)) {
      CachedKernel = Kernel(&F);
      return *CachedKernel;
    }

    CachedKernel = nullptr;
    if (!F.hasLocalLinkage()) {

      // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA
               << "[OMP100] Potentially unknown OpenMP target region caller";
      };
      emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);

      return nullptr;
    }
  }

  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
    if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
      // Allow use in equality comparisons.
1681 if (Cmp->isEquality()) 1682 return getUniqueKernelFor(*Cmp); 1683 return nullptr; 1684 } 1685 if (auto *CB = dyn_cast<CallBase>(U.getUser())) { 1686 // Allow direct calls. 1687 if (CB->isCallee(&U)) 1688 return getUniqueKernelFor(*CB); 1689 1690 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI = 1691 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51]; 1692 // Allow the use in __kmpc_parallel_51 calls. 1693 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI)) 1694 return getUniqueKernelFor(*CB); 1695 return nullptr; 1696 } 1697 // Disallow every other use. 1698 return nullptr; 1699 }; 1700 1701 // TODO: In the future we want to track more than just a unique kernel. 1702 SmallPtrSet<Kernel, 2> PotentialKernels; 1703 OMPInformationCache::foreachUse(F, [&](const Use &U) { 1704 PotentialKernels.insert(GetUniqueKernelForUse(U)); 1705 }); 1706 1707 Kernel K = nullptr; 1708 if (PotentialKernels.size() == 1) 1709 K = *PotentialKernels.begin(); 1710 1711 // Cache the result. 1712 UniqueKernelMap[&F] = K; 1713 1714 return K; 1715 } 1716 1717 bool OpenMPOpt::rewriteDeviceCodeStateMachine() { 1718 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI = 1719 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51]; 1720 1721 bool Changed = false; 1722 if (!KernelParallelRFI) 1723 return Changed; 1724 1725 for (Function *F : SCC) { 1726 1727 // Check if the function is a use in a __kmpc_parallel_51 call at 1728 // all. 1729 bool UnknownUse = false; 1730 bool KernelParallelUse = false; 1731 unsigned NumDirectCalls = 0; 1732 1733 SmallVector<Use *, 2> ToBeReplacedStateMachineUses; 1734 OMPInformationCache::foreachUse(*F, [&](Use &U) { 1735 if (auto *CB = dyn_cast<CallBase>(U.getUser())) 1736 if (CB->isCallee(&U)) { 1737 ++NumDirectCalls; 1738 return; 1739 } 1740 1741 if (isa<ICmpInst>(U.getUser())) { 1742 ToBeReplacedStateMachineUses.push_back(&U); 1743 return; 1744 } 1745 1746 // Find wrapper functions that represent parallel kernels. 1747 CallInst *CI = 1748 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI); 1749 const unsigned int WrapperFunctionArgNo = 6; 1750 if (!KernelParallelUse && CI && 1751 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) { 1752 KernelParallelUse = true; 1753 ToBeReplacedStateMachineUses.push_back(&U); 1754 return; 1755 } 1756 UnknownUse = true; 1757 }); 1758 1759 // Do not emit a remark if we haven't seen a __kmpc_parallel_51 1760 // use. 1761 if (!KernelParallelUse) 1762 continue; 1763 1764 { 1765 auto Remark = [&](OptimizationRemarkAnalysis ORA) { 1766 return ORA << "Found a parallel region that is called in a target " 1767 "region but not part of a combined target construct nor " 1768 "nested inside a target construct without intermediate " 1769 "code. This can lead to excessive register usage for " 1770 "unrelated target regions in the same translation unit " 1771 "due to spurious call edges assumed by ptxas."; 1772 }; 1773 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPParallelRegionInNonSPMD", 1774 Remark); 1775 } 1776 1777 // If this ever hits, we should investigate. 1778 // TODO: Checking the number of uses is not a necessary restriction and 1779 // should be lifted. 1780 if (UnknownUse || NumDirectCalls != 1 || 1781 ToBeReplacedStateMachineUses.size() != 2) { 1782 { 1783 auto Remark = [&](OptimizationRemarkAnalysis ORA) { 1784 return ORA << "Parallel region is used in " 1785 << (UnknownUse ? 
"unknown" : "unexpected") 1786 << " ways; will not attempt to rewrite the state machine."; 1787 }; 1788 emitRemark<OptimizationRemarkAnalysis>( 1789 F, "OpenMPParallelRegionInNonSPMD", Remark); 1790 } 1791 continue; 1792 } 1793 1794 // Even if we have __kmpc_parallel_51 calls, we (for now) give 1795 // up if the function is not called from a unique kernel. 1796 Kernel K = getUniqueKernelFor(*F); 1797 if (!K) { 1798 { 1799 auto Remark = [&](OptimizationRemarkAnalysis ORA) { 1800 return ORA << "Parallel region is not known to be called from a " 1801 "unique single target region, maybe the surrounding " 1802 "function has external linkage?; will not attempt to " 1803 "rewrite the state machine use."; 1804 }; 1805 emitRemark<OptimizationRemarkAnalysis>( 1806 F, "OpenMPParallelRegionInMultipleKernesl", Remark); 1807 } 1808 continue; 1809 } 1810 1811 // We now know F is a parallel body function called only from the kernel K. 1812 // We also identified the state machine uses in which we replace the 1813 // function pointer by a new global symbol for identification purposes. This 1814 // ensures only direct calls to the function are left. 1815 1816 { 1817 auto RemarkParalleRegion = [&](OptimizationRemarkAnalysis ORA) { 1818 return ORA << "Specialize parallel region that is only reached from a " 1819 "single target region to avoid spurious call edges and " 1820 "excessive register usage in other target regions. " 1821 "(parallel region ID: " 1822 << ore::NV("OpenMPParallelRegion", F->getName()) 1823 << ", kernel ID: " 1824 << ore::NV("OpenMPTargetRegion", K->getName()) << ")"; 1825 }; 1826 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPParallelRegionInNonSPMD", 1827 RemarkParalleRegion); 1828 auto RemarkKernel = [&](OptimizationRemarkAnalysis ORA) { 1829 return ORA << "Target region containing the parallel region that is " 1830 "specialized. (parallel region ID: " 1831 << ore::NV("OpenMPParallelRegion", F->getName()) 1832 << ", kernel ID: " 1833 << ore::NV("OpenMPTargetRegion", K->getName()) << ")"; 1834 }; 1835 emitRemark<OptimizationRemarkAnalysis>(K, "OpenMPParallelRegionInNonSPMD", 1836 RemarkKernel); 1837 } 1838 1839 Module &M = *F->getParent(); 1840 Type *Int8Ty = Type::getInt8Ty(M.getContext()); 1841 1842 auto *ID = new GlobalVariable( 1843 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage, 1844 UndefValue::get(Int8Ty), F->getName() + ".ID"); 1845 1846 for (Use *U : ToBeReplacedStateMachineUses) 1847 U->set(ConstantExpr::getBitCast(ID, U->get()->getType())); 1848 1849 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine; 1850 1851 Changed = true; 1852 } 1853 1854 return Changed; 1855 } 1856 1857 /// Abstract Attribute for tracking ICV values. 1858 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> { 1859 using Base = StateWrapper<BooleanState, AbstractAttribute>; 1860 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {} 1861 1862 void initialize(Attributor &A) override { 1863 Function *F = getAnchorScope(); 1864 if (!F || !A.isFunctionIPOAmendable(*F)) 1865 indicatePessimisticFixpoint(); 1866 } 1867 1868 /// Returns true if value is assumed to be tracked. 1869 bool isAssumedTracked() const { return getAssumed(); } 1870 1871 /// Returns true if value is known to be tracked. 1872 bool isKnownTracked() const { return getAssumed(); } 1873 1874 /// Create an abstract attribute biew for the position \p IRP. 
1875 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
1876
1877 /// Return the value with which \p I can be replaced for specific \p ICV.
1878 virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
1879 const Instruction *I,
1880 Attributor &A) const {
1881 return None;
1882 }
1883
1884 /// Return an assumed unique ICV value if a single candidate is found. If
1885 /// there cannot be one, return nullptr. If it is not clear yet, return the
1886 /// Optional::NoneType.
1887 virtual Optional<Value *>
1888 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
1889
1890 // Currently only nthreads is being tracked.
1891 // This array will only grow with time.
1892 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
1893
1894 /// See AbstractAttribute::getName()
1895 const std::string getName() const override { return "AAICVTracker"; }
1896
1897 /// See AbstractAttribute::getIdAddr()
1898 const char *getIdAddr() const override { return &ID; }
1899
1900 /// This function should return true if the type of the \p AA is AAICVTracker.
1901 static bool classof(const AbstractAttribute *AA) {
1902 return (AA->getIdAddr() == &ID);
1903 }
1904
1905 static const char ID;
1906 };
1907
1908 struct AAICVTrackerFunction : public AAICVTracker {
1909 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
1910 : AAICVTracker(IRP, A) {}
1911
1912 // FIXME: come up with a better string.
1913 const std::string getAsStr() const override { return "ICVTrackerFunction"; }
1914
1915 // FIXME: come up with some stats.
1916 void trackStatistics() const override {}
1917
1918 /// We don't manifest anything for this AA.
1919 ChangeStatus manifest(Attributor &A) override {
1920 return ChangeStatus::UNCHANGED;
1921 }
1922
1923 // Map of ICVs to their values at specific program points.
1924 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
1925 InternalControlVar::ICV___last>
1926 ICVReplacementValuesMap;
1927
1928 ChangeStatus updateImpl(Attributor &A) override {
1929 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
1930
1931 Function *F = getAnchorScope();
1932
1933 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
1934
1935 for (InternalControlVar ICV : TrackableICVs) {
1936 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
1937
1938 auto &ValuesMap = ICVReplacementValuesMap[ICV];
1939 auto TrackValues = [&](Use &U, Function &) {
1940 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
1941 if (!CI)
1942 return false;
1943
1944 // FIXME: handle setters with more than one argument.
1945 /// Track new value.
1946 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
1947 HasChanged = ChangeStatus::CHANGED;
1948
1949 return false;
1950 };
1951
1952 auto CallCheck = [&](Instruction &I) {
1953 Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
1954 if (ReplVal.hasValue() &&
1955 ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
1956 HasChanged = ChangeStatus::CHANGED;
1957
1958 return true;
1959 };
1960
1961 // Track all changes of an ICV.
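// For the nthreads ICV, for example, this records every
// omp_set_num_threads(N) call site together with its argument N, as well as
// any call that may modify the ICV, so that getReplacementValue() can later
// look up the last value written before a given program point.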
1962 SetterRFI.foreachUse(TrackValues, F);
1963
1964 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
1965 /* CheckBBLivenessOnly */ true);
1966
1967 /// TODO: Figure out a way to avoid adding an entry to
1968 /// ICVReplacementValuesMap.
1969 Instruction *Entry = &F->getEntryBlock().front();
1970 if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
1971 ValuesMap.insert(std::make_pair(Entry, nullptr));
1972 }
1973
1974 return HasChanged;
1975 }
1976
1977 /// Helper to check if \p I is a call and get the value for it if it is
1978 /// unique.
1979 Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
1980 InternalControlVar &ICV) const {
1981
1982 const auto *CB = dyn_cast<CallBase>(I);
1983 if (!CB || CB->hasFnAttr("no_openmp") ||
1984 CB->hasFnAttr("no_openmp_routines"))
1985 return None;
1986
1987 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
1988 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
1989 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
1990 Function *CalledFunction = CB->getCalledFunction();
1991
1992 // Indirect call, assume ICV changes.
1993 if (CalledFunction == nullptr)
1994 return nullptr;
1995 if (CalledFunction == GetterRFI.Declaration)
1996 return None;
1997 if (CalledFunction == SetterRFI.Declaration) {
1998 if (ICVReplacementValuesMap[ICV].count(I))
1999 return ICVReplacementValuesMap[ICV].lookup(I);
2000
2001 return nullptr;
2002 }
2003
2004 // Since we don't know, assume it changes the ICV.
2005 if (CalledFunction->isDeclaration())
2006 return nullptr;
2007
2008 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2009 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2010
2011 if (ICVTrackingAA.isAssumedTracked())
2012 return ICVTrackingAA.getUniqueReplacementValue(ICV);
2013
2014 // If we don't know, assume it changes.
2015 return nullptr;
2016 }
2017
2018 // We don't check for a unique value for a function, so return None.
2019 Optional<Value *>
2020 getUniqueReplacementValue(InternalControlVar ICV) const override {
2021 return None;
2022 }
2023
2024 /// Return the value with which \p I can be replaced for specific \p ICV.
2025 Optional<Value *> getReplacementValue(InternalControlVar ICV,
2026 const Instruction *I,
2027 Attributor &A) const override {
2028 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2029 if (ValuesMap.count(I))
2030 return ValuesMap.lookup(I);
2031
2032 SmallVector<const Instruction *, 16> Worklist;
2033 SmallPtrSet<const Instruction *, 16> Visited;
2034 Worklist.push_back(I);
2035
2036 Optional<Value *> ReplVal;
2037
2038 while (!Worklist.empty()) {
2039 const Instruction *CurrInst = Worklist.pop_back_val();
2040 if (!Visited.insert(CurrInst).second)
2041 continue;
2042
2043 const BasicBlock *CurrBB = CurrInst->getParent();
2044
2045 // Go up and look for all potential setters/calls that might change the
2046 // ICV.
2047 while ((CurrInst = CurrInst->getPrevNode())) {
2048 if (ValuesMap.count(CurrInst)) {
2049 Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2050 // Unknown value, track new.
2051 if (!ReplVal.hasValue()) {
2052 ReplVal = NewReplVal;
2053 break;
2054 }
2055
2056 // If we found a new value, we can't know the ICV value anymore.
2057 if (NewReplVal.hasValue()) 2058 if (ReplVal != NewReplVal) 2059 return nullptr; 2060 2061 break; 2062 } 2063 2064 Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV); 2065 if (!NewReplVal.hasValue()) 2066 continue; 2067 2068 // Unknown value, track new. 2069 if (!ReplVal.hasValue()) { 2070 ReplVal = NewReplVal; 2071 break; 2072 } 2073 2074 // if (NewReplVal.hasValue()) 2075 // We found a new value, we can't know the icv value anymore. 2076 if (ReplVal != NewReplVal) 2077 return nullptr; 2078 } 2079 2080 // If we are in the same BB and we have a value, we are done. 2081 if (CurrBB == I->getParent() && ReplVal.hasValue()) 2082 return ReplVal; 2083 2084 // Go through all predecessors and add terminators for analysis. 2085 for (const BasicBlock *Pred : predecessors(CurrBB)) 2086 if (const Instruction *Terminator = Pred->getTerminator()) 2087 Worklist.push_back(Terminator); 2088 } 2089 2090 return ReplVal; 2091 } 2092 }; 2093 2094 struct AAICVTrackerFunctionReturned : AAICVTracker { 2095 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A) 2096 : AAICVTracker(IRP, A) {} 2097 2098 // FIXME: come up with better string. 2099 const std::string getAsStr() const override { 2100 return "ICVTrackerFunctionReturned"; 2101 } 2102 2103 // FIXME: come up with some stats. 2104 void trackStatistics() const override {} 2105 2106 /// We don't manifest anything for this AA. 2107 ChangeStatus manifest(Attributor &A) override { 2108 return ChangeStatus::UNCHANGED; 2109 } 2110 2111 // Map of ICV to their values at specific program point. 2112 EnumeratedArray<Optional<Value *>, InternalControlVar, 2113 InternalControlVar::ICV___last> 2114 ICVReplacementValuesMap; 2115 2116 /// Return the value with which \p I can be replaced for specific \p ICV. 2117 Optional<Value *> 2118 getUniqueReplacementValue(InternalControlVar ICV) const override { 2119 return ICVReplacementValuesMap[ICV]; 2120 } 2121 2122 ChangeStatus updateImpl(Attributor &A) override { 2123 ChangeStatus Changed = ChangeStatus::UNCHANGED; 2124 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2125 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 2126 2127 if (!ICVTrackingAA.isAssumedTracked()) 2128 return indicatePessimisticFixpoint(); 2129 2130 for (InternalControlVar ICV : TrackableICVs) { 2131 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; 2132 Optional<Value *> UniqueICVValue; 2133 2134 auto CheckReturnInst = [&](Instruction &I) { 2135 Optional<Value *> NewReplVal = 2136 ICVTrackingAA.getReplacementValue(ICV, &I, A); 2137 2138 // If we found a second ICV value there is no unique returned value. 
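// E.g., if one return is reached after omp_set_num_threads(2) and another
// after omp_set_num_threads(4), no single replacement value can be reported
// (illustrative; the relevant setter depends on the tracked ICV).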
2139 if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal) 2140 return false; 2141 2142 UniqueICVValue = NewReplVal; 2143 2144 return true; 2145 }; 2146 2147 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret}, 2148 /* CheckBBLivenessOnly */ true)) 2149 UniqueICVValue = nullptr; 2150 2151 if (UniqueICVValue == ReplVal) 2152 continue; 2153 2154 ReplVal = UniqueICVValue; 2155 Changed = ChangeStatus::CHANGED; 2156 } 2157 2158 return Changed; 2159 } 2160 }; 2161 2162 struct AAICVTrackerCallSite : AAICVTracker { 2163 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A) 2164 : AAICVTracker(IRP, A) {} 2165 2166 void initialize(Attributor &A) override { 2167 Function *F = getAnchorScope(); 2168 if (!F || !A.isFunctionIPOAmendable(*F)) 2169 indicatePessimisticFixpoint(); 2170 2171 // We only initialize this AA for getters, so we need to know which ICV it 2172 // gets. 2173 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2174 for (InternalControlVar ICV : TrackableICVs) { 2175 auto ICVInfo = OMPInfoCache.ICVs[ICV]; 2176 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter]; 2177 if (Getter.Declaration == getAssociatedFunction()) { 2178 AssociatedICV = ICVInfo.Kind; 2179 return; 2180 } 2181 } 2182 2183 /// Unknown ICV. 2184 indicatePessimisticFixpoint(); 2185 } 2186 2187 ChangeStatus manifest(Attributor &A) override { 2188 if (!ReplVal.hasValue() || !ReplVal.getValue()) 2189 return ChangeStatus::UNCHANGED; 2190 2191 A.changeValueAfterManifest(*getCtxI(), **ReplVal); 2192 A.deleteAfterManifest(*getCtxI()); 2193 2194 return ChangeStatus::CHANGED; 2195 } 2196 2197 // FIXME: come up with better string. 2198 const std::string getAsStr() const override { return "ICVTrackerCallSite"; } 2199 2200 // FIXME: come up with some stats. 2201 void trackStatistics() const override {} 2202 2203 InternalControlVar AssociatedICV; 2204 Optional<Value *> ReplVal; 2205 2206 ChangeStatus updateImpl(Attributor &A) override { 2207 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2208 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 2209 2210 // We don't have any information, so we assume it changes the ICV. 2211 if (!ICVTrackingAA.isAssumedTracked()) 2212 return indicatePessimisticFixpoint(); 2213 2214 Optional<Value *> NewReplVal = 2215 ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A); 2216 2217 if (ReplVal == NewReplVal) 2218 return ChangeStatus::UNCHANGED; 2219 2220 ReplVal = NewReplVal; 2221 return ChangeStatus::CHANGED; 2222 } 2223 2224 // Return the value with which associated value can be replaced for specific 2225 // \p ICV. 2226 Optional<Value *> 2227 getUniqueReplacementValue(InternalControlVar ICV) const override { 2228 return ReplVal; 2229 } 2230 }; 2231 2232 struct AAICVTrackerCallSiteReturned : AAICVTracker { 2233 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A) 2234 : AAICVTracker(IRP, A) {} 2235 2236 // FIXME: come up with better string. 2237 const std::string getAsStr() const override { 2238 return "ICVTrackerCallSiteReturned"; 2239 } 2240 2241 // FIXME: come up with some stats. 2242 void trackStatistics() const override {} 2243 2244 /// We don't manifest anything for this AA. 2245 ChangeStatus manifest(Attributor &A) override { 2246 return ChangeStatus::UNCHANGED; 2247 } 2248 2249 // Map of ICV to their values at specific program point. 
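// The array is indexed directly by the InternalControlVar enumerators; an
// entry that is still None means no value is known for that ICV yet.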
2250 EnumeratedArray<Optional<Value *>, InternalControlVar, 2251 InternalControlVar::ICV___last> 2252 ICVReplacementValuesMap; 2253 2254 /// Return the value with which associated value can be replaced for specific 2255 /// \p ICV. 2256 Optional<Value *> 2257 getUniqueReplacementValue(InternalControlVar ICV) const override { 2258 return ICVReplacementValuesMap[ICV]; 2259 } 2260 2261 ChangeStatus updateImpl(Attributor &A) override { 2262 ChangeStatus Changed = ChangeStatus::UNCHANGED; 2263 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2264 *this, IRPosition::returned(*getAssociatedFunction()), 2265 DepClassTy::REQUIRED); 2266 2267 // We don't have any information, so we assume it changes the ICV. 2268 if (!ICVTrackingAA.isAssumedTracked()) 2269 return indicatePessimisticFixpoint(); 2270 2271 for (InternalControlVar ICV : TrackableICVs) { 2272 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; 2273 Optional<Value *> NewReplVal = 2274 ICVTrackingAA.getUniqueReplacementValue(ICV); 2275 2276 if (ReplVal == NewReplVal) 2277 continue; 2278 2279 ReplVal = NewReplVal; 2280 Changed = ChangeStatus::CHANGED; 2281 } 2282 return Changed; 2283 } 2284 }; 2285 2286 struct AAExecutionDomainFunction : public AAExecutionDomain { 2287 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A) 2288 : AAExecutionDomain(IRP, A) {} 2289 2290 const std::string getAsStr() const override { 2291 return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) + 2292 "/" + std::to_string(NumBBs) + " BBs thread 0 only."; 2293 } 2294 2295 /// See AbstractAttribute::trackStatistics(). 2296 void trackStatistics() const override {} 2297 2298 void initialize(Attributor &A) override { 2299 Function *F = getAnchorScope(); 2300 for (const auto &BB : *F) 2301 SingleThreadedBBs.insert(&BB); 2302 NumBBs = SingleThreadedBBs.size(); 2303 } 2304 2305 ChangeStatus manifest(Attributor &A) override { 2306 LLVM_DEBUG({ 2307 for (const BasicBlock *BB : SingleThreadedBBs) 2308 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " " 2309 << BB->getName() << " is executed by a single thread.\n"; 2310 }); 2311 return ChangeStatus::UNCHANGED; 2312 } 2313 2314 ChangeStatus updateImpl(Attributor &A) override; 2315 2316 /// Check if an instruction is executed by a single thread. 2317 bool isSingleThreadExecution(const Instruction &I) const override { 2318 return isSingleThreadExecution(*I.getParent()); 2319 } 2320 2321 bool isSingleThreadExecution(const BasicBlock &BB) const override { 2322 return SingleThreadedBBs.contains(&BB); 2323 } 2324 2325 /// Set of basic blocks that are executed by a single thread. 2326 DenseSet<const BasicBlock *> SingleThreadedBBs; 2327 2328 /// Total number of basic blocks in this function. 
2329 long unsigned NumBBs;
2330 };
2331
2332 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2333 Function *F = getAnchorScope();
2334 ReversePostOrderTraversal<Function *> RPOT(F);
2335 auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2336
2337 bool AllCallSitesKnown;
2338 auto PredForCallSite = [&](AbstractCallSite ACS) {
2339 const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2340 *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2341 DepClassTy::REQUIRED);
2342 return ExecutionDomainAA.isSingleThreadExecution(*ACS.getInstruction());
2343 };
2344
2345 if (!A.checkForAllCallSites(PredForCallSite, *this,
2346 /* RequiresAllCallSites */ true,
2347 AllCallSitesKnown))
2348 SingleThreadedBBs.erase(&F->getEntryBlock());
2349
2350 // Check if the edge into the successor block compares a thread-id function
2351 // to a constant zero.
2352 // TODO: Use AAValueSimplify to simplify and propagate constants.
2353 // TODO: Check more than a single use for thread IDs.
2354 auto IsSingleThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2355 if (!Edge || !Edge->isConditional())
2356 return false;
2357 if (Edge->getSuccessor(0) != SuccessorBB)
2358 return false;
2359
2360 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2361 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2362 return false;
2363
2364 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2365 if (!C || !C->isZero())
2366 return false;
2367
2368 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2369 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2370 return true;
2371 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2372 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2373 return true;
2374
2375 return false;
2376 };
2377
2378 // Merge all the predecessor states into the current basic block. A basic
2379 // block is executed by a single thread if all of its predecessors are.
2380 auto MergePredecessorStates = [&](BasicBlock *BB) {
2381 if (pred_begin(BB) == pred_end(BB))
2382 return SingleThreadedBBs.contains(BB);
2383
2384 bool IsSingleThreaded = true;
2385 for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
2386 PredBB != PredEndBB; ++PredBB) {
2387 if (!IsSingleThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
2388 BB))
2389 IsSingleThreaded &= SingleThreadedBBs.contains(*PredBB);
2390 }
2391
2392 return IsSingleThreaded;
2393 };
2394
2395 for (auto *BB : RPOT) {
2396 if (!MergePredecessorStates(BB))
2397 SingleThreadedBBs.erase(BB);
2398 }
2399
2400 return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2401 ?
ChangeStatus::UNCHANGED 2402 : ChangeStatus::CHANGED; 2403 } 2404 2405 } // namespace 2406 2407 const char AAICVTracker::ID = 0; 2408 const char AAExecutionDomain::ID = 0; 2409 2410 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP, 2411 Attributor &A) { 2412 AAICVTracker *AA = nullptr; 2413 switch (IRP.getPositionKind()) { 2414 case IRPosition::IRP_INVALID: 2415 case IRPosition::IRP_FLOAT: 2416 case IRPosition::IRP_ARGUMENT: 2417 case IRPosition::IRP_CALL_SITE_ARGUMENT: 2418 llvm_unreachable("ICVTracker can only be created for function position!"); 2419 case IRPosition::IRP_RETURNED: 2420 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A); 2421 break; 2422 case IRPosition::IRP_CALL_SITE_RETURNED: 2423 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A); 2424 break; 2425 case IRPosition::IRP_CALL_SITE: 2426 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A); 2427 break; 2428 case IRPosition::IRP_FUNCTION: 2429 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A); 2430 break; 2431 } 2432 2433 return *AA; 2434 } 2435 2436 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP, 2437 Attributor &A) { 2438 AAExecutionDomainFunction *AA = nullptr; 2439 switch (IRP.getPositionKind()) { 2440 case IRPosition::IRP_INVALID: 2441 case IRPosition::IRP_FLOAT: 2442 case IRPosition::IRP_ARGUMENT: 2443 case IRPosition::IRP_CALL_SITE_ARGUMENT: 2444 case IRPosition::IRP_RETURNED: 2445 case IRPosition::IRP_CALL_SITE_RETURNED: 2446 case IRPosition::IRP_CALL_SITE: 2447 llvm_unreachable( 2448 "AAExecutionDomain can only be created for function position!"); 2449 case IRPosition::IRP_FUNCTION: 2450 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A); 2451 break; 2452 } 2453 2454 return *AA; 2455 } 2456 2457 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { 2458 if (!containsOpenMP(M, OMPInModule)) 2459 return PreservedAnalyses::all(); 2460 2461 if (DisableOpenMPOptimizations) 2462 return PreservedAnalyses::all(); 2463 2464 // Look at every function definition in the Module. 2465 SmallVector<Function *, 16> SCC; 2466 for (Function &Fn : M) 2467 if (!Fn.isDeclaration()) 2468 SCC.push_back(&Fn); 2469 2470 if (SCC.empty()) 2471 return PreservedAnalyses::all(); 2472 2473 FunctionAnalysisManager &FAM = 2474 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 2475 2476 AnalysisGetter AG(FAM); 2477 2478 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & { 2479 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 2480 }; 2481 2482 BumpPtrAllocator Allocator; 2483 CallGraphUpdater CGUpdater; 2484 2485 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 2486 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, 2487 OMPInModule.getKernels()); 2488 2489 Attributor A(Functions, InfoCache, CGUpdater); 2490 2491 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 2492 bool Changed = OMPOpt.run(true); 2493 if (Changed) 2494 return PreservedAnalyses::none(); 2495 2496 return PreservedAnalyses::all(); 2497 } 2498 2499 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, 2500 CGSCCAnalysisManager &AM, 2501 LazyCallGraph &CG, 2502 CGSCCUpdateResult &UR) { 2503 if (!containsOpenMP(*C.begin()->getFunction().getParent(), OMPInModule)) 2504 return PreservedAnalyses::all(); 2505 2506 if (DisableOpenMPOptimizations) 2507 return PreservedAnalyses::all(); 2508 2509 SmallVector<Function *, 16> SCC; 2510 // If there are kernels in the module, we have to run on all SCC's. 
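// (Device-side rewrites such as rewriteDeviceCodeStateMachine() may need to
// visit functions that do not call the OpenMP runtime themselves, e.g.
// parallel region bodies, so the mere presence of kernels makes every SCC
// interesting.)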
2511 bool SCCIsInteresting = !OMPInModule.getKernels().empty(); 2512 for (LazyCallGraph::Node &N : C) { 2513 Function *Fn = &N.getFunction(); 2514 SCC.push_back(Fn); 2515 2516 // Do we already know that the SCC contains kernels, 2517 // or that OpenMP functions are called from this SCC? 2518 if (SCCIsInteresting) 2519 continue; 2520 // If not, let's check that. 2521 SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn); 2522 } 2523 2524 if (!SCCIsInteresting || SCC.empty()) 2525 return PreservedAnalyses::all(); 2526 2527 FunctionAnalysisManager &FAM = 2528 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); 2529 2530 AnalysisGetter AG(FAM); 2531 2532 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & { 2533 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 2534 }; 2535 2536 BumpPtrAllocator Allocator; 2537 CallGraphUpdater CGUpdater; 2538 CGUpdater.initialize(CG, C, AM, UR); 2539 2540 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 2541 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator, 2542 /*CGSCC*/ Functions, OMPInModule.getKernels()); 2543 2544 Attributor A(Functions, InfoCache, CGUpdater, nullptr, false); 2545 2546 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 2547 bool Changed = OMPOpt.run(false); 2548 if (Changed) 2549 return PreservedAnalyses::none(); 2550 2551 return PreservedAnalyses::all(); 2552 } 2553 2554 namespace { 2555 2556 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass { 2557 CallGraphUpdater CGUpdater; 2558 OpenMPInModule OMPInModule; 2559 static char ID; 2560 2561 OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) { 2562 initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry()); 2563 } 2564 2565 void getAnalysisUsage(AnalysisUsage &AU) const override { 2566 CallGraphSCCPass::getAnalysisUsage(AU); 2567 } 2568 2569 bool doInitialization(CallGraph &CG) override { 2570 // Disable the pass if there is no OpenMP (runtime call) in the module. 2571 containsOpenMP(CG.getModule(), OMPInModule); 2572 return false; 2573 } 2574 2575 bool runOnSCC(CallGraphSCC &CGSCC) override { 2576 if (!containsOpenMP(CGSCC.getCallGraph().getModule(), OMPInModule)) 2577 return false; 2578 if (DisableOpenMPOptimizations || skipSCC(CGSCC)) 2579 return false; 2580 2581 SmallVector<Function *, 16> SCC; 2582 // If there are kernels in the module, we have to run on all SCC's. 2583 bool SCCIsInteresting = !OMPInModule.getKernels().empty(); 2584 for (CallGraphNode *CGN : CGSCC) { 2585 Function *Fn = CGN->getFunction(); 2586 if (!Fn || Fn->isDeclaration()) 2587 continue; 2588 SCC.push_back(Fn); 2589 2590 // Do we already know that the SCC contains kernels, 2591 // or that OpenMP functions are called from this SCC? 2592 if (SCCIsInteresting) 2593 continue; 2594 // If not, let's check that. 
2595 SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn); 2596 } 2597 2598 if (!SCCIsInteresting || SCC.empty()) 2599 return false; 2600 2601 CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); 2602 CGUpdater.initialize(CG, CGSCC); 2603 2604 // Maintain a map of functions to avoid rebuilding the ORE 2605 DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap; 2606 auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & { 2607 std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F]; 2608 if (!ORE) 2609 ORE = std::make_unique<OptimizationRemarkEmitter>(F); 2610 return *ORE; 2611 }; 2612 2613 AnalysisGetter AG; 2614 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 2615 BumpPtrAllocator Allocator; 2616 OMPInformationCache InfoCache( 2617 *(Functions.back()->getParent()), AG, Allocator, 2618 /*CGSCC*/ Functions, OMPInModule.getKernels()); 2619 2620 Attributor A(Functions, InfoCache, CGUpdater, nullptr, false); 2621 2622 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 2623 return OMPOpt.run(false); 2624 } 2625 2626 bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); } 2627 }; 2628 2629 } // end anonymous namespace 2630 2631 void OpenMPInModule::identifyKernels(Module &M) { 2632 2633 NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations"); 2634 if (!MD) 2635 return; 2636 2637 for (auto *Op : MD->operands()) { 2638 if (Op->getNumOperands() < 2) 2639 continue; 2640 MDString *KindID = dyn_cast<MDString>(Op->getOperand(1)); 2641 if (!KindID || KindID->getString() != "kernel") 2642 continue; 2643 2644 Function *KernelFn = 2645 mdconst::dyn_extract_or_null<Function>(Op->getOperand(0)); 2646 if (!KernelFn) 2647 continue; 2648 2649 ++NumOpenMPTargetRegionKernels; 2650 2651 Kernels.insert(KernelFn); 2652 } 2653 } 2654 2655 bool llvm::omp::containsOpenMP(Module &M, OpenMPInModule &OMPInModule) { 2656 if (OMPInModule.isKnown()) 2657 return OMPInModule; 2658 2659 auto RecordFunctionsContainingUsesOf = [&](Function *F) { 2660 for (User *U : F->users()) 2661 if (auto *I = dyn_cast<Instruction>(U)) 2662 OMPInModule.FuncsWithOMPRuntimeCalls.insert(I->getFunction()); 2663 }; 2664 2665 // MSVC doesn't like long if-else chains for some reason and instead just 2666 // issues an error. Work around it.. 2667 do { 2668 #define OMP_RTL(_Enum, _Name, ...) \ 2669 if (Function *F = M.getFunction(_Name)) { \ 2670 RecordFunctionsContainingUsesOf(F); \ 2671 OMPInModule = true; \ 2672 } 2673 #include "llvm/Frontend/OpenMP/OMPKinds.def" 2674 } while (false); 2675 2676 // Identify kernels once. TODO: We should split the OMPInformationCache into a 2677 // module and an SCC part. The kernel information, among other things, could 2678 // go into the module part. 2679 if (OMPInModule.isKnown() && OMPInModule) { 2680 OMPInModule.identifyKernels(M); 2681 return true; 2682 } 2683 2684 return OMPInModule = false; 2685 } 2686 2687 char OpenMPOptCGSCCLegacyPass::ID = 0; 2688 2689 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", 2690 "OpenMP specific optimizations", false, false) 2691 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) 2692 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", 2693 "OpenMP specific optimizations", false, false) 2694 2695 Pass *llvm::createOpenMPOptCGSCCLegacyPass() { 2696 return new OpenMPOptCGSCCLegacyPass(); 2697 } 2698
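// For reference, the "nvvm.annotations" entries matched by identifyKernels()
// above look roughly as follows in textual IR (the kernel name here is
// illustrative only):
//
//   !nvvm.annotations = !{!0}
//   !0 = !{void ()* @__omp_offloading_example_kernel, !"kernel", i32 1}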