//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpenMP specific optimizations:
//
// - Deduplication of runtime calls, e.g., omp_get_thread_num.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/OpenMPOpt.h"

#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"

using namespace llvm;
using namespace omp;

#define DEBUG_TYPE "openmp-opt"

static cl::opt<bool> DisableOpenMPOptimizations(
    "openmp-opt-disable", cl::ZeroOrMore,
    cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
    cl::init(false));

static cl::opt<bool> EnableParallelRegionMerging(
    "openmp-opt-enable-merging", cl::ZeroOrMore,
    cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
    cl::init(false));

static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
                                    cl::Hidden);
static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
                                        cl::init(false), cl::Hidden);

static cl::opt<bool> HideMemoryTransferLatency(
    "openmp-hide-memory-transfer-latency",
    cl::desc("[WIP] Tries to hide the latency of host to device memory"
             " transfers"),
    cl::Hidden, cl::init(false));

STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");
STATISTIC(NumOpenMPTargetRegionKernels,
          "Number of OpenMP target region entry points (=kernels) identified");
STATISTIC(
    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
STATISTIC(NumOpenMPParallelRegionsMerged,
          "Number of OpenMP parallel regions merged");

#if !defined(NDEBUG)
static constexpr auto TAG = "[" DEBUG_TYPE "]";
#endif

namespace {

struct AAExecutionDomain
    : public StateWrapper<BooleanState, AbstractAttribute> {
  using Base = StateWrapper<BooleanState, AbstractAttribute>;
  AAExecutionDomain(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

  /// Create an abstract attribute view for the position \p IRP.
  static AAExecutionDomain &createForPosition(const IRPosition &IRP,
                                              Attributor &A);

  /// See AbstractAttribute::getName().
  const std::string getName() const override { return "AAExecutionDomain"; }

  /// See AbstractAttribute::getIdAddr().
  const char *getIdAddr() const override { return &ID; }

  /// Check if an instruction is executed by a single thread.
  virtual bool isSingleThreadExecution(const Instruction &) const = 0;

  virtual bool isSingleThreadExecution(const BasicBlock &) const = 0;

  /// This function should return true if the type of the \p AA is
  /// AAExecutionDomain.
  static bool classof(const AbstractAttribute *AA) {
    return (AA->getIdAddr() == &ID);
  }

  /// Unique ID (due to the unique address)
  static const char ID;
};

struct AAICVTracker;

/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
/// Attributor runs.
struct OMPInformationCache : public InformationCache {
  OMPInformationCache(Module &M, AnalysisGetter &AG,
                      BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
                      SmallPtrSetImpl<Kernel> &Kernels)
      : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
        Kernels(Kernels) {

    OMPBuilder.initialize();
    initializeRuntimeFunctions();
    initializeInternalControlVars();
  }

  /// Generic information that describes an internal control variable.
  struct InternalControlVarInfo {
    /// The kind, as described by InternalControlVar enum.
    InternalControlVar Kind;

    /// The name of the ICV.
    StringRef Name;

    /// Environment variable associated with this ICV.
    StringRef EnvVarName;

    /// Initial value kind.
    ICVInitValue InitKind;

    /// Initial value.
    ConstantInt *InitValue;

    /// Setter RTL function associated with this ICV.
    RuntimeFunction Setter;

    /// Getter RTL function associated with this ICV.
    RuntimeFunction Getter;

    /// RTL Function corresponding to the override clause of this ICV
    RuntimeFunction Clause;
  };

  /// Generic information that describes a runtime function
  struct RuntimeFunctionInfo {

    /// The kind, as described by the RuntimeFunction enum.
    RuntimeFunction Kind;

    /// The name of the function.
    StringRef Name;

    /// Flag to indicate a variadic function.
    bool IsVarArg;

    /// The return type of the function.
    Type *ReturnType;

    /// The argument types of the function.
    SmallVector<Type *, 8> ArgumentTypes;

    /// The declaration if available.
    Function *Declaration = nullptr;

    /// Uses of this runtime function per function containing the use.
    using UseVector = SmallVector<Use *, 16>;

    /// Clear UsesMap for runtime function.
    void clearUsesMap() { UsesMap.clear(); }

    /// Boolean conversion that is true if the runtime function was found.
    operator bool() const { return Declaration; }

    /// Return the vector of uses in function \p F.
    UseVector &getOrCreateUseVector(Function *F) {
      std::shared_ptr<UseVector> &UV = UsesMap[F];
      if (!UV)
        UV = std::make_shared<UseVector>();
      return *UV;
    }

    /// Return the vector of uses in function \p F or `nullptr` if there are
    /// none.
    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();
      return nullptr;
    }

    /// Return how many functions contain uses of this runtime function.
    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    /// Return the number of arguments (or the minimal number for variadic
    /// functions).
    size_t getNumArgs() const { return ArgumentTypes.size(); }

    /// Run the callback \p CB on each use and forget the use if the result is
    /// true. The callback will be fed the function in which the use was
    /// encountered as second argument.
    void foreachUse(SmallVectorImpl<Function *> &SCC,
                    function_ref<bool(Use &, Function &)> CB) {
      for (Function *F : SCC)
        foreachUse(CB, F);
    }

    /// Run the callback \p CB on each use within the function \p F and forget
    /// the use if the result is true.
    void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
      SmallVector<unsigned, 8> ToBeDeleted;
      ToBeDeleted.clear();

      unsigned Idx = 0;
      UseVector &UV = getOrCreateUseVector(F);

      for (Use *U : UV) {
        if (CB(*U, *F))
          ToBeDeleted.push_back(Idx);
        ++Idx;
      }

      // Remove the to-be-deleted indices in reverse order as prior
      // modifications will not modify the smaller indices.
      while (!ToBeDeleted.empty()) {
        unsigned Idx = ToBeDeleted.pop_back_val();
        UV[Idx] = UV.back();
        UV.pop_back();
      }
    }

  private:
    /// Map from functions to all uses of this runtime function contained in
    /// them.
    DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
  };

  /// An OpenMP-IR-Builder instance
  OpenMPIRBuilder OMPBuilder;

  /// Map from runtime function kind to the runtime function description.
  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
                  RuntimeFunction::OMPRTL___last>
      RFIs;

  /// Map from ICV kind to the ICV description.
  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVs;

  /// Helper to initialize all internal control variable information for those
  /// defined in OMPKinds.def.
  void initializeInternalControlVars() {
#define ICV_RT_SET(_Name, RTL) \
  { \
    auto &ICV = ICVs[_Name]; \
    ICV.Setter = RTL; \
  }
#define ICV_RT_GET(Name, RTL) \
  { \
    auto &ICV = ICVs[Name]; \
    ICV.Getter = RTL; \
  }
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
  { \
    auto &ICV = ICVs[Enum]; \
    ICV.Name = _Name; \
    ICV.Kind = Enum; \
    ICV.InitKind = Init; \
    ICV.EnvVarName = _EnvVarName; \
    switch (ICV.InitKind) { \
    case ICV_IMPLEMENTATION_DEFINED: \
      ICV.InitValue = nullptr; \
      break; \
    case ICV_ZERO: \
      ICV.InitValue = ConstantInt::get( \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
      break; \
    case ICV_FALSE: \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
      break; \
    case ICV_LAST: \
      break; \
    } \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  /// Returns true if the function declaration \p F matches the runtime
  /// function types, that is, return type \p RTFRetType, and argument types
  /// \p RTFArgTypes.
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
                                  SmallVector<Type *, 8> &RTFArgTypes) {
    // TODO: We should output information to the user (under debug output
    //       and via remarks).

    if (!F)
      return false;
    if (F->getReturnType() != RTFRetType)
      return false;
    if (F->arg_size() != RTFArgTypes.size())
      return false;

    auto RTFTyIt = RTFArgTypes.begin();
    for (Argument &Arg : F->args()) {
      if (Arg.getType() != *RTFTyIt)
        return false;

      ++RTFTyIt;
    }

    return true;
  }

  // Helper to collect all uses of the declaration in the UsesMap.
  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
    unsigned NumUses = 0;
    if (!RFI.Declaration)
      return NumUses;
    OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);

    if (CollectStats) {
      NumOpenMPRuntimeFunctionsIdentified += 1;
      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
    }

    // TODO: We directly convert uses into proper calls and unknown uses.
    for (Use &U : RFI.Declaration->uses()) {
      if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
        if (ModuleSlice.count(UserI->getFunction())) {
          RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
          ++NumUses;
        }
      } else {
        RFI.getOrCreateUseVector(nullptr).push_back(&U);
        ++NumUses;
      }
    }
    return NumUses;
  }

  // Helper function to recollect uses of a runtime function.
  void recollectUsesForFunction(RuntimeFunction RTF) {
    auto &RFI = RFIs[RTF];
    RFI.clearUsesMap();
    collectUses(RFI, /*CollectStats*/ false);
  }

  // Helper function to recollect uses of all runtime functions.
  void recollectUses() {
    for (int Idx = 0; Idx < RFIs.size(); ++Idx)
      recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
  }

  /// Helper to initialize all runtime function information for those defined
  /// in OpenMPKinds.def.
  void initializeRuntimeFunctions() {
    Module &M = *((*ModuleSlice.begin())->getParent());

    // Helper macros for handling __VA_ARGS__ in OMP_RTL
#define OMP_TYPE(VarName, ...) \
  Type *VarName = OMPBuilder.VarName; \
  (void)VarName;

#define OMP_ARRAY_TYPE(VarName, ...) \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
  (void)VarName##Ty; \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
  (void)VarName##PtrTy;

#define OMP_FUNCTION_TYPE(VarName, ...) \
  FunctionType *VarName = OMPBuilder.VarName; \
  (void)VarName; \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
  (void)VarName##Ptr;

#define OMP_STRUCT_TYPE(VarName, ...) \
  StructType *VarName = OMPBuilder.VarName; \
  (void)VarName; \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
  (void)VarName##Ptr;

#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
  { \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
    Function *F = M.getFunction(_Name); \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
      auto &RFI = RFIs[_Enum]; \
      RFI.Kind = _Enum; \
      RFI.Name = _Name; \
      RFI.IsVarArg = _IsVarArg; \
      RFI.ReturnType = OMPBuilder._ReturnType; \
      RFI.ArgumentTypes = std::move(ArgsTypes); \
      RFI.Declaration = F; \
      unsigned NumUses = collectUses(RFI); \
      (void)NumUses; \
      LLVM_DEBUG({ \
"" : " not") \ 409 << " found\n"; \ 410 if (RFI.Declaration) \ 411 dbgs() << TAG << "-> got " << NumUses << " uses in " \ 412 << RFI.getNumFunctionsWithUses() \ 413 << " different functions.\n"; \ 414 }); \ 415 } \ 416 } 417 #include "llvm/Frontend/OpenMP/OMPKinds.def" 418 419 // TODO: We should attach the attributes defined in OMPKinds.def. 420 } 421 422 /// Collection of known kernels (\see Kernel) in the module. 423 SmallPtrSetImpl<Kernel> &Kernels; 424 }; 425 426 /// Used to map the values physically (in the IR) stored in an offload 427 /// array, to a vector in memory. 428 struct OffloadArray { 429 /// Physical array (in the IR). 430 AllocaInst *Array = nullptr; 431 /// Mapped values. 432 SmallVector<Value *, 8> StoredValues; 433 /// Last stores made in the offload array. 434 SmallVector<StoreInst *, 8> LastAccesses; 435 436 OffloadArray() = default; 437 438 /// Initializes the OffloadArray with the values stored in \p Array before 439 /// instruction \p Before is reached. Returns false if the initialization 440 /// fails. 441 /// This MUST be used immediately after the construction of the object. 442 bool initialize(AllocaInst &Array, Instruction &Before) { 443 if (!Array.getAllocatedType()->isArrayTy()) 444 return false; 445 446 if (!getValues(Array, Before)) 447 return false; 448 449 this->Array = &Array; 450 return true; 451 } 452 453 static const unsigned DeviceIDArgNum = 1; 454 static const unsigned BasePtrsArgNum = 3; 455 static const unsigned PtrsArgNum = 4; 456 static const unsigned SizesArgNum = 5; 457 458 private: 459 /// Traverses the BasicBlock where \p Array is, collecting the stores made to 460 /// \p Array, leaving StoredValues with the values stored before the 461 /// instruction \p Before is reached. 462 bool getValues(AllocaInst &Array, Instruction &Before) { 463 // Initialize container. 464 const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements(); 465 StoredValues.assign(NumValues, nullptr); 466 LastAccesses.assign(NumValues, nullptr); 467 468 // TODO: This assumes the instruction \p Before is in the same 469 // BasicBlock as Array. Make it general, for any control flow graph. 470 BasicBlock *BB = Array.getParent(); 471 if (BB != Before.getParent()) 472 return false; 473 474 const DataLayout &DL = Array.getModule()->getDataLayout(); 475 const unsigned int PointerSize = DL.getPointerSize(); 476 477 for (Instruction &I : *BB) { 478 if (&I == &Before) 479 break; 480 481 if (!isa<StoreInst>(&I)) 482 continue; 483 484 auto *S = cast<StoreInst>(&I); 485 int64_t Offset = -1; 486 auto *Dst = 487 GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL); 488 if (Dst == &Array) { 489 int64_t Idx = Offset / PointerSize; 490 StoredValues[Idx] = getUnderlyingObject(S->getValueOperand()); 491 LastAccesses[Idx] = S; 492 } 493 } 494 495 return isFilled(); 496 } 497 498 /// Returns true if all values in StoredValues and 499 /// LastAccesses are not nullptrs. 
  bool isFilled() {
    const unsigned NumValues = StoredValues.size();
    for (unsigned I = 0; I < NumValues; ++I) {
      if (!StoredValues[I] || !LastAccesses[I])
        return false;
    }

    return true;
  }
};

struct OpenMPOpt {

  using OptimizationRemarkGetter =
      function_ref<OptimizationRemarkEmitter &(Function *)>;

  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
            OptimizationRemarkGetter OREGetter,
            OMPInformationCache &OMPInfoCache, Attributor &A)
      : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
        OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}

  /// Check if any remarks are enabled for openmp-opt
  bool remarksEnabled() {
    auto &Ctx = M.getContext();
    return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
  }

  /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
  bool run(bool IsModulePass) {
    if (SCC.empty())
      return false;

    bool Changed = false;

    LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
                      << " functions in a slice with "
                      << OMPInfoCache.ModuleSlice.size() << " functions\n");

    if (IsModulePass) {
      Changed |= runAttributor();

      if (remarksEnabled())
        analysisGlobalization();
    } else {
      if (PrintICVValues)
        printICVs();
      if (PrintOpenMPKernels)
        printKernels();

      Changed |= rewriteDeviceCodeStateMachine();

      Changed |= runAttributor();

      // Recollect uses, in case Attributor deleted any.
      OMPInfoCache.recollectUses();

      Changed |= deleteParallelRegions();
      if (HideMemoryTransferLatency)
        Changed |= hideMemTransfersLatency();
      Changed |= deduplicateRuntimeCalls();
      if (EnableParallelRegionMerging) {
        if (mergeParallelRegions()) {
          deduplicateRuntimeCalls();
          Changed = true;
        }
      }
    }

    return Changed;
  }

  /// Print initial ICV values for testing.
  /// FIXME: This should be done from the Attributor once it is added.
  void printICVs() const {
    InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
                                 ICV_proc_bind};

    for (Function *F : OMPInfoCache.ModuleSlice) {
      for (auto ICV : ICVs) {
        auto ICVInfo = OMPInfoCache.ICVs[ICV];
        auto Remark = [&](OptimizationRemark OR) {
          return OR << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
                    << " Value: "
                    << (ICVInfo.InitValue
                            ? ICVInfo.InitValue->getValue().toString(10, true)
                            : "IMPLEMENTATION_DEFINED");
        };

        emitRemarkOnFunction(F, "OpenMPICVTracker", Remark);
      }
    }
  }

  /// Print OpenMP GPU kernels for testing.
  void printKernels() const {
    for (Function *F : SCC) {
      if (!OMPInfoCache.Kernels.count(F))
        continue;

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "OpenMP GPU kernel "
                  << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
      };

      emitRemarkOnFunction(F, "OpenMPGPU", Remark);
    }
  }

  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
  /// given it has to be the callee or a nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(U.getUser());
    if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
        (!RFI || CI->getCalledFunction() == RFI->Declaration))
      return CI;
    return nullptr;
  }

  /// Return the call if \p V is a regular call. If \p RFI is given it has to
  /// be the callee or a nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(&V);
    if (CI && !CI->hasOperandBundles() &&
        (!RFI || CI->getCalledFunction() == RFI->Declaration))
      return CI;
    return nullptr;
  }

private:
  /// Merge parallel regions when it is safe.
  bool mergeParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;
    const unsigned CallbackFirstArgOperand = 3;
    using InsertPointTy = OpenMPIRBuilder::InsertPointTy;

    // Check if there are any __kmpc_fork_call calls to merge.
    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    // Unmergable calls that prevent merging a parallel region.
    OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
    };

    bool Changed = false;
    LoopInfo *LI = nullptr;
    DominatorTree *DT = nullptr;

    SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;

    BasicBlock *StartBB = nullptr, *EndBB = nullptr;
    auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                         BasicBlock &ContinuationIP) {
      BasicBlock *CGStartBB = CodeGenIP.getBlock();
      BasicBlock *CGEndBB =
          SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
      assert(StartBB != nullptr && "StartBB should not be null");
      CGStartBB->getTerminator()->setSuccessor(0, StartBB);
      assert(EndBB != nullptr && "EndBB should not be null");
      EndBB->getTerminator()->setSuccessor(0, CGEndBB);
    };

    auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
                      Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
      ReplacementValue = &Inner;
      return CodeGenIP;
    };

    auto FiniCB = [&](InsertPointTy CodeGenIP) {};

    /// Create a sequential execution region within a merged parallel region,
    /// encapsulated in a master construct with a barrier for synchronization.
    auto CreateSequentialRegion = [&](Function *OuterFn,
                                      BasicBlock *OuterPredBB,
                                      Instruction *SeqStartI,
                                      Instruction *SeqEndI) {
      // Isolate the instructions of the sequential region to a separate
      // block.
      BasicBlock *ParentBB = SeqStartI->getParent();
      BasicBlock *SeqEndBB =
          SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
      BasicBlock *SeqAfterBB =
          SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
      BasicBlock *SeqStartBB =
          SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");

      assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
             "Expected a different CFG");
      const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
      ParentBB->getTerminator()->eraseFromParent();

      auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                           BasicBlock &ContinuationIP) {
        BasicBlock *CGStartBB = CodeGenIP.getBlock();
        BasicBlock *CGEndBB =
            SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
        assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
        CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
        assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
        SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
      };
      auto FiniCB = [&](InsertPointTy CodeGenIP) {};

      // Find outputs from the sequential region to outside users and
      // broadcast their values to them.
      for (Instruction &I : *SeqStartBB) {
        SmallPtrSet<Instruction *, 4> OutsideUsers;
        for (User *Usr : I.users()) {
          Instruction &UsrI = *cast<Instruction>(Usr);
          // Ignore outputs to LT intrinsics, code extraction for the merged
          // parallel region will fix them.
          if (UsrI.isLifetimeStartOrEnd())
            continue;

          if (UsrI.getParent() != SeqStartBB)
            OutsideUsers.insert(&UsrI);
        }

        if (OutsideUsers.empty())
          continue;

        // Emit an alloca in the outer region to store the broadcasted
        // value.
        const DataLayout &DL = M.getDataLayout();
        AllocaInst *AllocaI = new AllocaInst(
            I.getType(), DL.getAllocaAddrSpace(), nullptr,
            I.getName() + ".seq.output.alloc", &OuterFn->front().front());

        // Emit a store instruction in the sequential BB to update the
        // value.
        new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());

        // Emit a load instruction and replace the use of the output value
        // with it.
        for (Instruction *UsrI : OutsideUsers) {
          LoadInst *LoadI = new LoadInst(
              I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
          UsrI->replaceUsesOfWith(&I, LoadI);
        }
      }

      OpenMPIRBuilder::LocationDescription Loc(
          InsertPointTy(ParentBB, ParentBB->end()), DL);
      InsertPointTy SeqAfterIP =
          OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);

      OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);

      BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());

      LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
                        << "\n");
    };

    // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
    // contained in BB and only separated by instructions that can be
    // redundantly executed in parallel. The block BB is split before the first
    // call (in MergableCIs) and after the last so the entire region we merge
    // into a single parallel region is contained in a single basic block
    // without any other instructions. We use the OpenMPIRBuilder to outline
    // that block and call the resulting function via __kmpc_fork_call.
    auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
      // TODO: Change the interface to allow single CIs expanded, e.g., to
      // include an outer loop.
      assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");

      auto Remark = [&](OptimizationRemark OR) {
        OR << "Parallel region at "
           << ore::NV("OpenMPParallelMergeFront",
                      MergableCIs.front()->getDebugLoc())
           << " merged with parallel regions at ";
        for (auto *CI : llvm::drop_begin(MergableCIs)) {
          OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
          if (CI != MergableCIs.back())
            OR << ", ";
        }
        return OR;
      };

      emitRemark<OptimizationRemark>(MergableCIs.front(),
                                     "OpenMPParallelRegionMerging", Remark);

      Function *OriginalFn = BB->getParent();
      LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
                        << " parallel regions in " << OriginalFn->getName()
                        << "\n");

      // Isolate the calls to merge in a separate block.
      EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
      BasicBlock *AfterBB =
          SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
      StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
                           "omp.par.merged");

      assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
      const DebugLoc DL = BB->getTerminator()->getDebugLoc();
      BB->getTerminator()->eraseFromParent();

      // Create sequential regions for sequential instructions that are
      // in-between mergable parallel regions.
      for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
           It != End; ++It) {
        Instruction *ForkCI = *It;
        Instruction *NextForkCI = *(It + 1);

        // Continue if there are no in-between instructions.
        if (ForkCI->getNextNode() == NextForkCI)
          continue;

        CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
                               NextForkCI->getPrevNode());
      }

      OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
                                               DL);
      IRBuilder<>::InsertPoint AllocaIP(
          &OriginalFn->getEntryBlock(),
          OriginalFn->getEntryBlock().getFirstInsertionPt());
      // Create the merged parallel region with default proc binding, to
      // avoid overriding binding settings, and without explicit cancellation.
      InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
          Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
          OMP_PROC_BIND_default, /* IsCancellable */ false);
      BranchInst::Create(AfterBB, AfterIP.getBlock());

      // Perform the actual outlining.
      OMPInfoCache.OMPBuilder.finalize(OriginalFn,
                                       /* AllowExtractorSinking */ true);

      Function *OutlinedFn = MergableCIs.front()->getCaller();

      // Replace the __kmpc_fork_call calls with direct calls to the outlined
      // callbacks.
      SmallVector<Value *, 8> Args;
      for (auto *CI : MergableCIs) {
        Value *Callee =
            CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
        FunctionType *FT =
            cast<FunctionType>(Callee->getType()->getPointerElementType());
        Args.clear();
        Args.push_back(OutlinedFn->getArg(0));
        Args.push_back(OutlinedFn->getArg(1));
        for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
             U < E; ++U)
          Args.push_back(CI->getArgOperand(U));

        CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
        if (CI->getDebugLoc())
          NewCI->setDebugLoc(CI->getDebugLoc());

        // Forward parameter attributes from the callback to the callee.
        for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
             U < E; ++U)
          for (const Attribute &A : CI->getAttributes().getParamAttributes(U))
            NewCI->addParamAttr(
                U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);

        // Emit an explicit barrier to replace the implicit fork-join barrier.
        if (CI != MergableCIs.back()) {
          // TODO: Remove barrier if the merged parallel region includes the
          // 'nowait' clause.
          OMPInfoCache.OMPBuilder.createBarrier(
              InsertPointTy(NewCI->getParent(),
                            NewCI->getNextNode()->getIterator()),
              OMPD_parallel);
        }

        auto Remark = [&](OptimizationRemark OR) {
          return OR << "Parallel region at "
                    << ore::NV("OpenMPParallelMerge", CI->getDebugLoc())
                    << " merged with "
                    << ore::NV("OpenMPParallelMergeFront",
                               MergableCIs.front()->getDebugLoc());
        };
        if (CI != MergableCIs.front())
          emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionMerging",
                                         Remark);

        CI->eraseFromParent();
      }

      assert(OutlinedFn != OriginalFn && "Outlining failed");
      CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
      CGUpdater.reanalyzeFunction(*OriginalFn);

      NumOpenMPParallelRegionsMerged += MergableCIs.size();

      return true;
    };

    // Helper function that identifies sequences of
    // __kmpc_fork_call uses in a basic block.
    auto DetectPRsCB = [&](Use &U, Function &F) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      BB2PRMap[CI->getParent()].insert(CI);

      return false;
    };

    BB2PRMap.clear();
    RFI.foreachUse(SCC, DetectPRsCB);
    SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
    // Find mergable parallel regions within a basic block that are
    // safe to merge, that is, any in-between instructions can safely
    // execute in parallel after merging.
    // TODO: support merging across basic-blocks.
    for (auto &It : BB2PRMap) {
      auto &CIs = It.getSecond();
      if (CIs.size() < 2)
        continue;

      BasicBlock *BB = It.getFirst();
      SmallVector<CallInst *, 4> MergableCIs;

      /// Returns true if the instruction is mergable, false otherwise.
      /// A terminator instruction is unmergable by definition since merging
      /// works within a BB. Instructions before the mergable region are
      /// mergable if they are not calls to OpenMP runtime functions that may
      /// set different execution parameters for subsequent parallel regions.
      /// Instructions in-between parallel regions are mergable if they are not
      /// calls to any non-intrinsic function since that may call a
      /// non-mergable OpenMP runtime function.
      auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
        // We do not merge across BBs, hence return false (unmergable) if the
        // instruction is a terminator.
        if (I.isTerminator())
          return false;

        if (!isa<CallInst>(&I))
          return true;

        CallInst *CI = cast<CallInst>(&I);
        if (IsBeforeMergableRegion) {
          Function *CalledFunction = CI->getCalledFunction();
          if (!CalledFunction)
            return false;
          // Return false (unmergable) if the call before the parallel
          // region calls an explicit affinity (proc_bind) or number of
          // threads (num_threads) compiler-generated function. Those settings
          // may be incompatible with following parallel regions.
          // TODO: ICV tracking to detect compatibility.
          for (const auto &RFI : UnmergableCallsInfo) {
            if (CalledFunction == RFI.Declaration)
              return false;
          }
        } else {
          // Return false (unmergable) if there is a call instruction
          // in-between parallel regions when it is not an intrinsic. It
          // may call an unmergable OpenMP runtime function in its callpath.
          // TODO: Keep track of possible OpenMP calls in the callpath.
          if (!isa<IntrinsicInst>(CI))
            return false;
        }

        return true;
      };
      // Find maximal number of parallel region CIs that are safe to merge.
      for (auto It = BB->begin(), End = BB->end(); It != End;) {
        Instruction &I = *It;
        ++It;

        if (CIs.count(&I)) {
          MergableCIs.push_back(cast<CallInst>(&I));
          continue;
        }

        // Continue expanding if the instruction is mergable.
        if (IsMergable(I, MergableCIs.empty()))
          continue;

        // Forward the instruction iterator to skip the next parallel region
        // since there is an unmergable instruction which can affect it.
        for (; It != End; ++It) {
          Instruction &SkipI = *It;
          if (CIs.count(&SkipI)) {
            LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
                              << " due to " << I << "\n");
            ++It;
            break;
          }
        }

        // Store mergable regions found.
        if (MergableCIs.size() > 1) {
          MergableCIsVector.push_back(MergableCIs);
          LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
                            << " parallel regions in block " << BB->getName()
                            << " of function " << BB->getParent()->getName()
                            << "\n";);
        }

        MergableCIs.clear();
      }

      if (!MergableCIsVector.empty()) {
        Changed = true;

        for (auto &MergableCIs : MergableCIsVector)
          Merge(MergableCIs, BB);
        MergableCIsVector.clear();
      }
    }

    if (Changed) {
      /// Re-collect uses for fork calls, emitted barrier calls, and
      /// any emitted master/end_master calls.
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
    }

    return Changed;
  }

  /// Try to delete parallel regions if possible.
  bool deleteParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    bool Changed = false;
    auto DeleteCallCB = [&](Use &U, Function &) {
      CallInst *CI = getCallIfRegularCall(U);
      if (!CI)
        return false;
      auto *Fn = dyn_cast<Function>(
          CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
      if (!Fn)
        return false;
      if (!Fn->onlyReadsMemory())
        return false;
      if (!Fn->hasFnAttribute(Attribute::WillReturn))
        return false;

      LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
                        << CI->getCaller()->getName() << "\n");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Parallel region in "
                  << ore::NV("OpenMPParallelDelete", CI->getCaller()->getName())
                  << " deleted";
      };
      emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionDeletion",
                                     Remark);

      CGUpdater.removeCallSite(*CI);
      CI->eraseFromParent();
      Changed = true;
      ++NumOpenMPParallelRegionsDeleted;
      return true;
    };

    RFI.foreachUse(SCC, DeleteCallCB);

    return Changed;
  }

  /// Try to eliminate runtime calls by reusing existing ones.
  bool deduplicateRuntimeCalls() {
    bool Changed = false;

    RuntimeFunction DeduplicableRuntimeCallIDs[] = {
        OMPRTL_omp_get_num_threads,
        OMPRTL_omp_in_parallel,
        OMPRTL_omp_get_cancellation,
        OMPRTL_omp_get_thread_limit,
        OMPRTL_omp_get_supported_active_levels,
        OMPRTL_omp_get_level,
        OMPRTL_omp_get_ancestor_thread_num,
        OMPRTL_omp_get_team_size,
        OMPRTL_omp_get_active_level,
        OMPRTL_omp_in_final,
        OMPRTL_omp_get_proc_bind,
        OMPRTL_omp_get_num_places,
        OMPRTL_omp_get_num_procs,
        OMPRTL_omp_get_place_num,
        OMPRTL_omp_get_partition_num_places,
        OMPRTL_omp_get_partition_place_nums};

    // Global-tid is handled separately.
    SmallSetVector<Value *, 16> GTIdArgs;
    collectGlobalThreadIdArguments(GTIdArgs);
    LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
                      << " global thread ID arguments\n");

    for (Function *F : SCC) {
      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);

      // __kmpc_global_thread_num is special as we can replace it with an
      // argument in enough cases to make it worth trying.
      Value *GTIdArg = nullptr;
      for (Argument &Arg : F->args())
        if (GTIdArgs.count(&Arg)) {
          GTIdArg = &Arg;
          break;
        }
      Changed |= deduplicateRuntimeCalls(
          *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
    }

    return Changed;
  }

  /// Tries to hide the latency of runtime calls that involve host to
  /// device memory transfers by splitting them into their "issue" and "wait"
  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
  /// moved downwards as much as possible. The "issue" issues the memory
  /// transfer asynchronously, returning a handle. The "wait" waits on the
  /// returned handle for the memory transfer to finish.
  bool hideMemTransfersLatency() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
    bool Changed = false;
    auto SplitMemTransfers = [&](Use &U, Function &Decl) {
      auto *RTCall = getCallIfRegularCall(U, &RFI);
      if (!RTCall)
        return false;

      OffloadArray OffloadArrays[3];
      if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
        return false;

      LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));

      // TODO: Check if can be moved upwards.
      bool WasSplit = false;
      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
      if (WaitMovementPoint)
        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);

      Changed |= WasSplit;
      return WasSplit;
    };
    RFI.foreachUse(SCC, SplitMemTransfers);

    return Changed;
  }

  void analysisGlobalization() {
    RuntimeFunction GlobalizationRuntimeIDs[] = {
        OMPRTL___kmpc_data_sharing_coalesced_push_stack,
        OMPRTL___kmpc_data_sharing_push_stack};

    for (const auto GlobalizationCallID : GlobalizationRuntimeIDs) {
      auto &RFI = OMPInfoCache.RFIs[GlobalizationCallID];

      auto CheckGlobalization = [&](Use &U, Function &Decl) {
        if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
          auto Remark = [&](OptimizationRemarkAnalysis ORA) {
            return ORA
                   << "Found thread data sharing on the GPU. "
                   << "Expect degraded performance due to data globalization.";
          };
          emitRemark<OptimizationRemarkAnalysis>(CI, "OpenMPGlobalization",
                                                 Remark);
        }

        return false;
      };

      RFI.foreachUse(SCC, CheckGlobalization);
    }
  }

  /// Maps the values stored in the offload arrays passed as arguments to
  /// \p RuntimeCall into the offload arrays in \p OAs.
  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
                                MutableArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "Need space for three offload arrays!");

    // A runtime call that involves memory offloading looks something like:
    // call void @__tgt_target_data_begin_mapper(arg0, arg1,
    //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
    // ...)
    // So, the idea is to access the allocas that allocate space for these
    // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
    // Therefore:
    // i8** %offload_baseptrs.
    Value *BasePtrsArg =
        RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
    // i8** %offload_ptrs.
    Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
    // i8** %offload_sizes.
    Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);

    // Get values stored in **offload_baseptrs.
    auto *V = getUnderlyingObject(BasePtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *BasePtrsArray = cast<AllocaInst>(V);
    if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
      return false;

    // Get values stored in **offload_ptrs.
    V = getUnderlyingObject(PtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *PtrsArray = cast<AllocaInst>(V);
    if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
      return false;

    // Get values stored in **offload_sizes.
    V = getUnderlyingObject(SizesArg);
    // If it's a [constant] global array don't analyze it.
    if (isa<GlobalValue>(V))
      return isa<Constant>(V);
    if (!isa<AllocaInst>(V))
      return false;

    auto *SizesArray = cast<AllocaInst>(V);
    if (!OAs[2].initialize(*SizesArray, RuntimeCall))
      return false;

    return true;
  }

  /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
  /// For now this is a way to test that the function getValuesInOffloadArrays
  /// is working properly.
  /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
  void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "There are three offload arrays to debug!");

    LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
    std::string ValuesStr;
    raw_string_ostream Printer(ValuesStr);
    std::string Separator = " --- ";

    for (auto *BP : OAs[0].StoredValues) {
      BP->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
    ValuesStr.clear();

    for (auto *P : OAs[1].StoredValues) {
      P->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
    ValuesStr.clear();

    for (auto *S : OAs[2].StoredValues) {
      S->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
  }

  /// Returns the instruction where the "wait" counterpart of \p RuntimeCall
  /// can be moved. Returns nullptr if the movement is not possible, or not
  /// worth it.
  Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
    // FIXME: This traverses only the BasicBlock where RuntimeCall is.
    //        Make it traverse the CFG.

    Instruction *CurrentI = &RuntimeCall;
    bool IsWorthIt = false;
    while ((CurrentI = CurrentI->getNextNode())) {

      // TODO: Once we detect the regions to be offloaded we should use the
      //       alias analysis manager to check if CurrentI may modify one of
      //       the offloaded regions.
      if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
        if (IsWorthIt)
          return CurrentI;

        return nullptr;
      }

      // FIXME: For now, moving it over anything without side effects is
      //        considered worth it.
      IsWorthIt = true;
    }

    // Return end of BasicBlock.
    return RuntimeCall.getParent()->getTerminator();
  }

  /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
                               Instruction &WaitMovementPoint) {
    // Create a stack-allocated handle (__tgt_async_info) at the beginning of
    // the function. Used for storing information of the async transfer,
    // allowing us to wait on it later.
    auto &IRBuilder = OMPInfoCache.OMPBuilder;
    auto *F = RuntimeCall.getCaller();
    Instruction *FirstInst = &(F->getEntryBlock().front());
    AllocaInst *Handle = new AllocaInst(
        IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);

    // Add "issue" runtime call declaration:
    // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
    //   i8**, i8**, i64*, i64*)
    FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_issue);

    // Change RuntimeCall call site for its asynchronous version.
    SmallVector<Value *, 16> Args;
    for (auto &Arg : RuntimeCall.args())
      Args.push_back(Arg.get());
    Args.push_back(Handle);

    CallInst *IssueCallsite =
        CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
    RuntimeCall.eraseFromParent();

    // Add "wait" runtime call declaration:
    // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
    FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_wait);

    Value *WaitParams[2] = {
        IssueCallsite->getArgOperand(
            OffloadArray::DeviceIDArgNum), // device_id.
        Handle                             // handle to wait on.
    };
    CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);

    return true;
  }

  static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
                                    bool GlobalOnly, bool &SingleChoice) {
    if (CurrentIdent == NextIdent)
      return CurrentIdent;

    // TODO: Figure out how to actually combine multiple debug locations. For
    //       now we just keep an existing one if there is a single choice.
    if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
      SingleChoice = !CurrentIdent;
      return NextIdent;
    }
    return nullptr;
  }

  /// Return a `struct ident_t*` value that represents the ones used in the
  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
  /// return value we create one from scratch. We also do not yet combine
  /// information, e.g., the source locations, see combinedIdentStruct.
  Value *
  getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
                                 Function &F, bool GlobalOnly) {
    bool SingleChoice = true;
    Value *Ident = nullptr;
    auto CombineIdentStruct = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || &F != &Caller)
        return false;
      Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
                                  /* GlobalOnly */ true, SingleChoice);
      return false;
    };
    RFI.foreachUse(SCC, CombineIdentStruct);

    if (!Ident || !SingleChoice) {
      // The IRBuilder uses the insertion block to get to the module, this is
      // unfortunate but we work around it for now.
      if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
        OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
            &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
      // TODO: Use the debug locations of the calls instead.
      Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
      Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
    }
    return Ident;
  }

  /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
  /// \p ReplVal if given.
  bool deduplicateRuntimeCalls(Function &F,
                               OMPInformationCache::RuntimeFunctionInfo &RFI,
                               Value *ReplVal = nullptr) {
    auto *UV = RFI.getUseVector(F);
    if (!UV || UV->size() + (ReplVal != nullptr) < 2)
      return false;

    LLVM_DEBUG(
        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
" with an existing value\n" : "\n") << "\n"); 1394 1395 assert((!ReplVal || (isa<Argument>(ReplVal) && 1396 cast<Argument>(ReplVal)->getParent() == &F)) && 1397 "Unexpected replacement value!"); 1398 1399 // TODO: Use dominance to find a good position instead. 1400 auto CanBeMoved = [this](CallBase &CB) { 1401 unsigned NumArgs = CB.getNumArgOperands(); 1402 if (NumArgs == 0) 1403 return true; 1404 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr) 1405 return false; 1406 for (unsigned u = 1; u < NumArgs; ++u) 1407 if (isa<Instruction>(CB.getArgOperand(u))) 1408 return false; 1409 return true; 1410 }; 1411 1412 if (!ReplVal) { 1413 for (Use *U : *UV) 1414 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) { 1415 if (!CanBeMoved(*CI)) 1416 continue; 1417 1418 auto Remark = [&](OptimizationRemark OR) { 1419 auto newLoc = &*F.getEntryBlock().getFirstInsertionPt(); 1420 return OR << "OpenMP runtime call " 1421 << ore::NV("OpenMPOptRuntime", RFI.Name) << " moved to " 1422 << ore::NV("OpenMPRuntimeMoves", newLoc->getDebugLoc()); 1423 }; 1424 emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeCodeMotion", Remark); 1425 1426 CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt()); 1427 ReplVal = CI; 1428 break; 1429 } 1430 if (!ReplVal) 1431 return false; 1432 } 1433 1434 // If we use a call as a replacement value we need to make sure the ident is 1435 // valid at the new location. For now we just pick a global one, either 1436 // existing and used by one of the calls, or created from scratch. 1437 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) { 1438 if (CI->getNumArgOperands() > 0 && 1439 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) { 1440 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F, 1441 /* GlobalOnly */ true); 1442 CI->setArgOperand(0, Ident); 1443 } 1444 } 1445 1446 bool Changed = false; 1447 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) { 1448 CallInst *CI = getCallIfRegularCall(U, &RFI); 1449 if (!CI || CI == ReplVal || &F != &Caller) 1450 return false; 1451 assert(CI->getCaller() == &F && "Unexpected call!"); 1452 1453 auto Remark = [&](OptimizationRemark OR) { 1454 return OR << "OpenMP runtime call " 1455 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated"; 1456 }; 1457 emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeDeduplicated", Remark); 1458 1459 CGUpdater.removeCallSite(*CI); 1460 CI->replaceAllUsesWith(ReplVal); 1461 CI->eraseFromParent(); 1462 ++NumOpenMPRuntimeCallsDeduplicated; 1463 Changed = true; 1464 return true; 1465 }; 1466 RFI.foreachUse(SCC, ReplaceAndDeleteCB); 1467 1468 return Changed; 1469 } 1470 1471 /// Collect arguments that represent the global thread id in \p GTIdArgs. 1472 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> >IdArgs) { 1473 // TODO: Below we basically perform a fixpoint iteration with a pessimistic 1474 // initialization. We could define an AbstractAttribute instead and 1475 // run the Attributor here once it can be run as an SCC pass. 1476 1477 // Helper to check the argument \p ArgNo at all call sites of \p F for 1478 // a GTId. 
    auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
      if (!F.hasLocalLinkage())
        return false;
      for (Use &U : F.uses()) {
        if (CallInst *CI = getCallIfRegularCall(U)) {
          Value *ArgOp = CI->getArgOperand(ArgNo);
          if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
              getCallIfRegularCall(
                  *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
            continue;
        }
        return false;
      }
      return true;
    };

    // Helper to identify uses of a GTId as GTId arguments.
    auto AddUserArgs = [&](Value &GTId) {
      for (Use &U : GTId.uses())
        if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
          if (CI->isArgOperand(&U))
            if (Function *Callee = CI->getCalledFunction())
              if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
                GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
    };

    // The argument users of __kmpc_global_thread_num calls are GTIds.
    OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];

    GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
      if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
        AddUserArgs(*CI);
      return false;
    });

    // Transitively search for more arguments by looking at the users of the
    // ones we know already. During the search the GTIdArgs vector is extended
    // so we cannot cache the size nor can we use a range based for.
    for (unsigned u = 0; u < GTIdArgs.size(); ++u)
      AddUserArgs(*GTIdArgs[u]);
  }

  /// Kernel (=GPU) optimizations and utility functions
  ///
  ///{{

  /// Check if \p F is a kernel, hence entry point for target offloading.
  bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }

  /// Cache to remember the unique kernel for a function.
  DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;

  /// Find the unique kernel that will execute \p F, if any.
  Kernel getUniqueKernelFor(Function &F);

  /// Find the unique kernel that will execute \p I, if any.
  Kernel getUniqueKernelFor(Instruction &I) {
    return getUniqueKernelFor(*I.getFunction());
  }

  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
  /// the cases we can avoid taking the address of a function.
  bool rewriteDeviceCodeStateMachine();

  ///
  ///}}

  /// Emit a remark generically
  ///
  /// This template function can be used to generically emit a remark. The
  /// RemarkKind should be one of the following:
  /// - OptimizationRemark to indicate a successful optimization attempt
  /// - OptimizationRemarkMissed to report a failed optimization attempt
  /// - OptimizationRemarkAnalysis to provide additional information about an
  ///   optimization attempt
  ///
  /// The remark is built using a callback function provided by the caller that
  /// takes a RemarkKind as input and returns a RemarkKind.
  template <typename RemarkKind,
            typename RemarkCallBack = function_ref<RemarkKind(RemarkKind &&)>>
  void emitRemark(Instruction *Inst, StringRef RemarkName,
                  RemarkCallBack &&RemarkCB) const {
    Function *F = Inst->getParent()->getParent();
    auto &ORE = OREGetter(F);

    ORE.emit(
        [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, Inst)); });
  }

  /// Emit a remark on a function. Since only OptimizationRemark supports
  /// this, it can't be made generic.
  void
  emitRemarkOnFunction(Function *F, StringRef RemarkName,
                       function_ref<OptimizationRemark(OptimizationRemark &&)>
                           &&RemarkCB) const {
    auto &ORE = OREGetter(F);

    ORE.emit([&]() {
      return RemarkCB(OptimizationRemark(DEBUG_TYPE, RemarkName, F));
    });
  }

  /// The underlying module.
  Module &M;

  /// The SCC we are operating on.
  SmallVectorImpl<Function *> &SCC;

  /// Callback to update the call graph, the first argument is a removed call,
  /// the second an optional replacement call.
  CallGraphUpdater &CGUpdater;

  /// Callback to get an OptimizationRemarkEmitter from a Function *
  OptimizationRemarkGetter OREGetter;

  /// OpenMP-specific information cache. Also used for Attributor runs.
  OMPInformationCache &OMPInfoCache;

  /// Attributor instance.
  Attributor &A;

  /// Helper function to run Attributor on SCC.
  bool runAttributor() {
    if (SCC.empty())
      return false;

    registerAAs();

    ChangeStatus Changed = A.run();

    LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
                      << " functions, result: " << Changed << ".\n");

    return Changed == ChangeStatus::CHANGED;
  }

  /// Populate the Attributor with abstract attribute opportunities in the
  /// function.
  void registerAAs() {
    if (SCC.empty())
      return;

    // Create CallSite AA for all Getters.
    for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
      auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];

      auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];

      auto CreateAA = [&](Use &U, Function &Caller) {
        CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
        if (!CI)
          return false;

        auto &CB = cast<CallBase>(*CI);

        IRPosition CBPos = IRPosition::callsite_function(CB);
        A.getOrCreateAAFor<AAICVTracker>(CBPos);
        return false;
      };

      GetterRFI.foreachUse(SCC, CreateAA);
    }

    for (auto &F : M) {
      if (!F.isDeclaration())
        A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
    }
  }
};

Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
  if (!OMPInfoCache.ModuleSlice.count(&F))
    return nullptr;

  // Use a scope to keep the lifetime of the CachedKernel short.
  {
    Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
    if (CachedKernel)
      return *CachedKernel;

    // TODO: We should use an AA to create an (optimistic and callback
    //       call-aware) call graph. For now we stick to simple patterns that
    //       are less powerful, basically the worst fixpoint.
    if (isKernel(F)) {
      CachedKernel = Kernel(&F);
      return *CachedKernel;
    }

    CachedKernel = nullptr;
    if (!F.hasLocalLinkage()) {

      // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
      auto Remark = [&](OptimizationRemark OR) {
        return OR << "[OMP100] Potentially unknown OpenMP target region caller";
      };
      emitRemarkOnFunction(&F, "OMP100", Remark);

      return nullptr;
    }
  }

  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
    if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
      // Allow use in equality comparisons.
1684 if (Cmp->isEquality())
1685 return getUniqueKernelFor(*Cmp);
1686 return nullptr;
1687 }
1688 if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
1689 // Allow direct calls.
1690 if (CB->isCallee(&U))
1691 return getUniqueKernelFor(*CB);
1692
1693 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1694 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1695 // Allow the use in __kmpc_parallel_51 calls.
1696 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
1697 return getUniqueKernelFor(*CB);
1698 return nullptr;
1699 }
1700 // Disallow every other use.
1701 return nullptr;
1702 };
1703
1704 // TODO: In the future we want to track more than just a unique kernel.
1705 SmallPtrSet<Kernel, 2> PotentialKernels;
1706 OMPInformationCache::foreachUse(F, [&](const Use &U) {
1707 PotentialKernels.insert(GetUniqueKernelForUse(U));
1708 });
1709
1710 Kernel K = nullptr;
1711 if (PotentialKernels.size() == 1)
1712 K = *PotentialKernels.begin();
1713
1714 // Cache the result.
1715 UniqueKernelMap[&F] = K;
1716
1717 return K;
1718 }
1719
1720 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1721 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1722 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1723
1724 bool Changed = false;
1725 if (!KernelParallelRFI)
1726 return Changed;
1727
1728 for (Function *F : SCC) {
1729
1730 // Check if the function is used in a __kmpc_parallel_51 call at
1731 // all.
1732 bool UnknownUse = false;
1733 bool KernelParallelUse = false;
1734 unsigned NumDirectCalls = 0;
1735
1736 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
1737 OMPInformationCache::foreachUse(*F, [&](Use &U) {
1738 if (auto *CB = dyn_cast<CallBase>(U.getUser()))
1739 if (CB->isCallee(&U)) {
1740 ++NumDirectCalls;
1741 return;
1742 }
1743
1744 if (isa<ICmpInst>(U.getUser())) {
1745 ToBeReplacedStateMachineUses.push_back(&U);
1746 return;
1747 }
1748
1749 // Find wrapper functions that represent parallel kernels.
1750 CallInst *CI =
1751 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
1752 const unsigned int WrapperFunctionArgNo = 6;
1753 if (!KernelParallelUse && CI &&
1754 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
1755 KernelParallelUse = true;
1756 ToBeReplacedStateMachineUses.push_back(&U);
1757 return;
1758 }
1759 UnknownUse = true;
1760 });
1761
1762 // Do not emit a remark if we haven't seen a __kmpc_parallel_51
1763 // use.
1764 if (!KernelParallelUse)
1765 continue;
1766
1767 {
1768 auto Remark = [&](OptimizationRemark OR) {
1769 return OR << "Found a parallel region that is called in a target "
1770 "region but not part of a combined target construct nor "
1771 "nested inside a target construct without intermediate "
1772 "code. This can lead to excessive register usage for "
1773 "unrelated target regions in the same translation unit "
1774 "due to spurious call edges assumed by ptxas.";
1775 };
1776 emitRemarkOnFunction(F, "OpenMPParallelRegionInNonSPMD", Remark);
1777 }
1778
1779 // If this ever hits, we should investigate.
1780 // TODO: Checking the number of uses is not a necessary restriction and
1781 // should be lifted.
1782 if (UnknownUse || NumDirectCalls != 1 ||
1783 ToBeReplacedStateMachineUses.size() != 2) {
1784 {
1785 auto Remark = [&](OptimizationRemark OR) {
1786 return OR << "Parallel region is used in "
1787 << (UnknownUse ? "unknown" : "unexpected")
1788 << " ways; will not attempt to rewrite the state machine.";
1789 };
1790 emitRemarkOnFunction(F, "OpenMPParallelRegionInNonSPMD", Remark);
1791 }
1792 continue;
1793 }
1794
1795 // Even if we have __kmpc_parallel_51 calls, we (for now) give
1796 // up if the function is not called from a unique kernel.
1797 Kernel K = getUniqueKernelFor(*F);
1798 if (!K) {
1799 {
1800 auto Remark = [&](OptimizationRemark OR) {
1801 return OR << "Parallel region is not known to be called from a "
1802 "unique single target region, maybe the surrounding "
1803 "function has external linkage?; will not attempt to "
1804 "rewrite the state machine use.";
1805 };
1806 emitRemarkOnFunction(F, "OpenMPParallelRegionInMultipleKernels",
1807 Remark);
1808 }
1809 continue;
1810 }
1811
1812 // We now know F is a parallel body function called only from the kernel K.
1813 // We also identified the state machine uses in which we replace the
1814 // function pointer by a new global symbol for identification purposes. This
1815 // ensures only direct calls to the function are left.
1816
1817 {
1818 auto RemarkParallelRegion = [&](OptimizationRemark OR) {
1819 return OR << "Specialize parallel region that is only reached from a "
1820 "single target region to avoid spurious call edges and "
1821 "excessive register usage in other target regions. "
1822 "(parallel region ID: "
1823 << ore::NV("OpenMPParallelRegion", F->getName())
1824 << ", kernel ID: "
1825 << ore::NV("OpenMPTargetRegion", K->getName()) << ")";
1826 };
1827 emitRemarkOnFunction(F, "OpenMPParallelRegionInNonSPMD",
1828 RemarkParallelRegion);
1829 auto RemarkKernel = [&](OptimizationRemark OR) {
1830 return OR << "Target region containing the parallel region that is "
1831 "specialized. (parallel region ID: "
1832 << ore::NV("OpenMPParallelRegion", F->getName())
1833 << ", kernel ID: "
1834 << ore::NV("OpenMPTargetRegion", K->getName()) << ")";
1835 };
1836 emitRemarkOnFunction(K, "OpenMPParallelRegionInNonSPMD", RemarkKernel);
1837 }
1838
1839 Module &M = *F->getParent();
1840 Type *Int8Ty = Type::getInt8Ty(M.getContext());
1841
1842 auto *ID = new GlobalVariable(
1843 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
1844 UndefValue::get(Int8Ty), F->getName() + ".ID");
1845
1846 for (Use *U : ToBeReplacedStateMachineUses)
1847 U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));
1848
1849 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
1850
1851 Changed = true;
1852 }
1853
1854 return Changed;
1855 }
1856
1857 /// Abstract Attribute for tracking ICV values.
1858 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
1859 using Base = StateWrapper<BooleanState, AbstractAttribute>;
1860 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
1861
1862 void initialize(Attributor &A) override {
1863 Function *F = getAnchorScope();
1864 if (!F || !A.isFunctionIPOAmendable(*F))
1865 indicatePessimisticFixpoint();
1866 }
1867
1868 /// Returns true if value is assumed to be tracked.
1869 bool isAssumedTracked() const { return getAssumed(); }
1870
1871 /// Returns true if value is known to be tracked.
1872 bool isKnownTracked() const { return getAssumed(); }
1873
1874 /// Create an abstract attribute view for the position \p IRP.
1875 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
1876
1877 /// Return the value with which \p I can be replaced for specific \p ICV.
1878 virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
1879 const Instruction *I,
1880 Attributor &A) const {
1881 return None;
1882 }
1883
1884 /// Return an assumed unique ICV value if a single candidate is found. If
1885 /// there cannot be one, return a nullptr. If it is not clear yet, return the
1886 /// Optional::NoneType.
1887 virtual Optional<Value *>
1888 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
1889
1890 // Currently only nthreads is being tracked.
1891 // This array will only grow with time.
1892 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
1893
1894 /// See AbstractAttribute::getName()
1895 const std::string getName() const override { return "AAICVTracker"; }
1896
1897 /// See AbstractAttribute::getIdAddr()
1898 const char *getIdAddr() const override { return &ID; }
1899
1900 /// This function should return true if the type of the \p AA is AAICVTracker
1901 static bool classof(const AbstractAttribute *AA) {
1902 return (AA->getIdAddr() == &ID);
1903 }
1904
1905 static const char ID;
1906 };
1907
1908 struct AAICVTrackerFunction : public AAICVTracker {
1909 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
1910 : AAICVTracker(IRP, A) {}
1911
1912 // FIXME: come up with better string.
1913 const std::string getAsStr() const override { return "ICVTrackerFunction"; }
1914
1915 // FIXME: come up with some stats.
1916 void trackStatistics() const override {}
1917
1918 /// We don't manifest anything for this AA.
1919 ChangeStatus manifest(Attributor &A) override {
1920 return ChangeStatus::UNCHANGED;
1921 }
1922
1923 // Map of ICV to their values at specific program point.
1924 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
1925 InternalControlVar::ICV___last>
1926 ICVReplacementValuesMap;
1927
1928 ChangeStatus updateImpl(Attributor &A) override {
1929 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
1930
1931 Function *F = getAnchorScope();
1932
1933 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
1934
1935 for (InternalControlVar ICV : TrackableICVs) {
1936 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
1937
1938 auto &ValuesMap = ICVReplacementValuesMap[ICV];
1939 auto TrackValues = [&](Use &U, Function &) {
1940 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
1941 if (!CI)
1942 return false;
1943
1944 // FIXME: handle setters with more than one argument.
1945 /// Track new value.
1946 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
1947 HasChanged = ChangeStatus::CHANGED;
1948
1949 return false;
1950 };
1951
1952 auto CallCheck = [&](Instruction &I) {
1953 Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
1954 if (ReplVal.hasValue() &&
1955 ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
1956 HasChanged = ChangeStatus::CHANGED;
1957
1958 return true;
1959 };
1960
1961 // Track all changes of an ICV.
1962 SetterRFI.foreachUse(TrackValues, F);
1963
1964 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
1965 /* CheckBBLivenessOnly */ true);
1966
1967 /// TODO: Figure out a way to avoid adding entry in
1968 /// ICVReplacementValuesMap
1969 Instruction *Entry = &F->getEntryBlock().front();
1970 if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
1971 ValuesMap.insert(std::make_pair(Entry, nullptr));
1972 }
1973
1974 return HasChanged;
1975 }
1976
1977 /// Helper to check if \p I is a call and get the value for it if it is
1978 /// unique.
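/// Returns None if the call is known not to change the ICV, nullptr if it
/// may change it to an unknown value, and the new value if the call sets it
/// to a known value.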
1979 Optional<Value *> getValueForCall(Attributor &A, const Instruction *I, 1980 InternalControlVar &ICV) const { 1981 1982 const auto *CB = dyn_cast<CallBase>(I); 1983 if (!CB || CB->hasFnAttr("no_openmp") || 1984 CB->hasFnAttr("no_openmp_routines")) 1985 return None; 1986 1987 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 1988 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter]; 1989 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter]; 1990 Function *CalledFunction = CB->getCalledFunction(); 1991 1992 // Indirect call, assume ICV changes. 1993 if (CalledFunction == nullptr) 1994 return nullptr; 1995 if (CalledFunction == GetterRFI.Declaration) 1996 return None; 1997 if (CalledFunction == SetterRFI.Declaration) { 1998 if (ICVReplacementValuesMap[ICV].count(I)) 1999 return ICVReplacementValuesMap[ICV].lookup(I); 2000 2001 return nullptr; 2002 } 2003 2004 // Since we don't know, assume it changes the ICV. 2005 if (CalledFunction->isDeclaration()) 2006 return nullptr; 2007 2008 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2009 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED); 2010 2011 if (ICVTrackingAA.isAssumedTracked()) 2012 return ICVTrackingAA.getUniqueReplacementValue(ICV); 2013 2014 // If we don't know, assume it changes. 2015 return nullptr; 2016 } 2017 2018 // We don't check unique value for a function, so return None. 2019 Optional<Value *> 2020 getUniqueReplacementValue(InternalControlVar ICV) const override { 2021 return None; 2022 } 2023 2024 /// Return the value with which \p I can be replaced for specific \p ICV. 2025 Optional<Value *> getReplacementValue(InternalControlVar ICV, 2026 const Instruction *I, 2027 Attributor &A) const override { 2028 const auto &ValuesMap = ICVReplacementValuesMap[ICV]; 2029 if (ValuesMap.count(I)) 2030 return ValuesMap.lookup(I); 2031 2032 SmallVector<const Instruction *, 16> Worklist; 2033 SmallPtrSet<const Instruction *, 16> Visited; 2034 Worklist.push_back(I); 2035 2036 Optional<Value *> ReplVal; 2037 2038 while (!Worklist.empty()) { 2039 const Instruction *CurrInst = Worklist.pop_back_val(); 2040 if (!Visited.insert(CurrInst).second) 2041 continue; 2042 2043 const BasicBlock *CurrBB = CurrInst->getParent(); 2044 2045 // Go up and look for all potential setters/calls that might change the 2046 // ICV. 2047 while ((CurrInst = CurrInst->getPrevNode())) { 2048 if (ValuesMap.count(CurrInst)) { 2049 Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst); 2050 // Unknown value, track new. 2051 if (!ReplVal.hasValue()) { 2052 ReplVal = NewReplVal; 2053 break; 2054 } 2055 2056 // If we found a new value, we can't know the icv value anymore. 2057 if (NewReplVal.hasValue()) 2058 if (ReplVal != NewReplVal) 2059 return nullptr; 2060 2061 break; 2062 } 2063 2064 Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV); 2065 if (!NewReplVal.hasValue()) 2066 continue; 2067 2068 // Unknown value, track new. 2069 if (!ReplVal.hasValue()) { 2070 ReplVal = NewReplVal; 2071 break; 2072 } 2073 2074 // if (NewReplVal.hasValue()) 2075 // We found a new value, we can't know the icv value anymore. 2076 if (ReplVal != NewReplVal) 2077 return nullptr; 2078 } 2079 2080 // If we are in the same BB and we have a value, we are done. 2081 if (CurrBB == I->getParent() && ReplVal.hasValue()) 2082 return ReplVal; 2083 2084 // Go through all predecessors and add terminators for analysis. 
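// Popping a terminator resumes the backwards walk above at the end of that
// predecessor block, so predecessors are scanned block by block until a
// tracked value is found or all reachable predecessors have been visited.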
2085 for (const BasicBlock *Pred : predecessors(CurrBB)) 2086 if (const Instruction *Terminator = Pred->getTerminator()) 2087 Worklist.push_back(Terminator); 2088 } 2089 2090 return ReplVal; 2091 } 2092 }; 2093 2094 struct AAICVTrackerFunctionReturned : AAICVTracker { 2095 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A) 2096 : AAICVTracker(IRP, A) {} 2097 2098 // FIXME: come up with better string. 2099 const std::string getAsStr() const override { 2100 return "ICVTrackerFunctionReturned"; 2101 } 2102 2103 // FIXME: come up with some stats. 2104 void trackStatistics() const override {} 2105 2106 /// We don't manifest anything for this AA. 2107 ChangeStatus manifest(Attributor &A) override { 2108 return ChangeStatus::UNCHANGED; 2109 } 2110 2111 // Map of ICV to their values at specific program point. 2112 EnumeratedArray<Optional<Value *>, InternalControlVar, 2113 InternalControlVar::ICV___last> 2114 ICVReplacementValuesMap; 2115 2116 /// Return the value with which \p I can be replaced for specific \p ICV. 2117 Optional<Value *> 2118 getUniqueReplacementValue(InternalControlVar ICV) const override { 2119 return ICVReplacementValuesMap[ICV]; 2120 } 2121 2122 ChangeStatus updateImpl(Attributor &A) override { 2123 ChangeStatus Changed = ChangeStatus::UNCHANGED; 2124 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2125 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 2126 2127 if (!ICVTrackingAA.isAssumedTracked()) 2128 return indicatePessimisticFixpoint(); 2129 2130 for (InternalControlVar ICV : TrackableICVs) { 2131 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; 2132 Optional<Value *> UniqueICVValue; 2133 2134 auto CheckReturnInst = [&](Instruction &I) { 2135 Optional<Value *> NewReplVal = 2136 ICVTrackingAA.getReplacementValue(ICV, &I, A); 2137 2138 // If we found a second ICV value there is no unique returned value. 2139 if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal) 2140 return false; 2141 2142 UniqueICVValue = NewReplVal; 2143 2144 return true; 2145 }; 2146 2147 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret}, 2148 /* CheckBBLivenessOnly */ true)) 2149 UniqueICVValue = nullptr; 2150 2151 if (UniqueICVValue == ReplVal) 2152 continue; 2153 2154 ReplVal = UniqueICVValue; 2155 Changed = ChangeStatus::CHANGED; 2156 } 2157 2158 return Changed; 2159 } 2160 }; 2161 2162 struct AAICVTrackerCallSite : AAICVTracker { 2163 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A) 2164 : AAICVTracker(IRP, A) {} 2165 2166 void initialize(Attributor &A) override { 2167 Function *F = getAnchorScope(); 2168 if (!F || !A.isFunctionIPOAmendable(*F)) 2169 indicatePessimisticFixpoint(); 2170 2171 // We only initialize this AA for getters, so we need to know which ICV it 2172 // gets. 2173 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2174 for (InternalControlVar ICV : TrackableICVs) { 2175 auto ICVInfo = OMPInfoCache.ICVs[ICV]; 2176 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter]; 2177 if (Getter.Declaration == getAssociatedFunction()) { 2178 AssociatedICV = ICVInfo.Kind; 2179 return; 2180 } 2181 } 2182 2183 /// Unknown ICV. 
2184 indicatePessimisticFixpoint(); 2185 } 2186 2187 ChangeStatus manifest(Attributor &A) override { 2188 if (!ReplVal.hasValue() || !ReplVal.getValue()) 2189 return ChangeStatus::UNCHANGED; 2190 2191 A.changeValueAfterManifest(*getCtxI(), **ReplVal); 2192 A.deleteAfterManifest(*getCtxI()); 2193 2194 return ChangeStatus::CHANGED; 2195 } 2196 2197 // FIXME: come up with better string. 2198 const std::string getAsStr() const override { return "ICVTrackerCallSite"; } 2199 2200 // FIXME: come up with some stats. 2201 void trackStatistics() const override {} 2202 2203 InternalControlVar AssociatedICV; 2204 Optional<Value *> ReplVal; 2205 2206 ChangeStatus updateImpl(Attributor &A) override { 2207 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2208 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 2209 2210 // We don't have any information, so we assume it changes the ICV. 2211 if (!ICVTrackingAA.isAssumedTracked()) 2212 return indicatePessimisticFixpoint(); 2213 2214 Optional<Value *> NewReplVal = 2215 ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A); 2216 2217 if (ReplVal == NewReplVal) 2218 return ChangeStatus::UNCHANGED; 2219 2220 ReplVal = NewReplVal; 2221 return ChangeStatus::CHANGED; 2222 } 2223 2224 // Return the value with which associated value can be replaced for specific 2225 // \p ICV. 2226 Optional<Value *> 2227 getUniqueReplacementValue(InternalControlVar ICV) const override { 2228 return ReplVal; 2229 } 2230 }; 2231 2232 struct AAICVTrackerCallSiteReturned : AAICVTracker { 2233 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A) 2234 : AAICVTracker(IRP, A) {} 2235 2236 // FIXME: come up with better string. 2237 const std::string getAsStr() const override { 2238 return "ICVTrackerCallSiteReturned"; 2239 } 2240 2241 // FIXME: come up with some stats. 2242 void trackStatistics() const override {} 2243 2244 /// We don't manifest anything for this AA. 2245 ChangeStatus manifest(Attributor &A) override { 2246 return ChangeStatus::UNCHANGED; 2247 } 2248 2249 // Map of ICV to their values at specific program point. 2250 EnumeratedArray<Optional<Value *>, InternalControlVar, 2251 InternalControlVar::ICV___last> 2252 ICVReplacementValuesMap; 2253 2254 /// Return the value with which associated value can be replaced for specific 2255 /// \p ICV. 2256 Optional<Value *> 2257 getUniqueReplacementValue(InternalControlVar ICV) const override { 2258 return ICVReplacementValuesMap[ICV]; 2259 } 2260 2261 ChangeStatus updateImpl(Attributor &A) override { 2262 ChangeStatus Changed = ChangeStatus::UNCHANGED; 2263 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2264 *this, IRPosition::returned(*getAssociatedFunction()), 2265 DepClassTy::REQUIRED); 2266 2267 // We don't have any information, so we assume it changes the ICV. 
2268 if (!ICVTrackingAA.isAssumedTracked())
2269 return indicatePessimisticFixpoint();
2270
2271 for (InternalControlVar ICV : TrackableICVs) {
2272 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2273 Optional<Value *> NewReplVal =
2274 ICVTrackingAA.getUniqueReplacementValue(ICV);
2275
2276 if (ReplVal == NewReplVal)
2277 continue;
2278
2279 ReplVal = NewReplVal;
2280 Changed = ChangeStatus::CHANGED;
2281 }
2282 return Changed;
2283 }
2284 };
2285
2286 struct AAExecutionDomainFunction : public AAExecutionDomain {
2287 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2288 : AAExecutionDomain(IRP, A) {}
2289
2290 const std::string getAsStr() const override {
2291 return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2292 "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2293 }
2294
2295 /// See AbstractAttribute::trackStatistics().
2296 void trackStatistics() const override {}
2297
2298 void initialize(Attributor &A) override {
2299 Function *F = getAnchorScope();
2300 for (const auto &BB : *F)
2301 SingleThreadedBBs.insert(&BB);
2302 NumBBs = SingleThreadedBBs.size();
2303 }
2304
2305 ChangeStatus manifest(Attributor &A) override {
2306 LLVM_DEBUG({
2307 for (const BasicBlock *BB : SingleThreadedBBs)
2308 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2309 << BB->getName() << " is executed by a single thread.\n";
2310 });
2311 return ChangeStatus::UNCHANGED;
2312 }
2313
2314 ChangeStatus updateImpl(Attributor &A) override;
2315
2316 /// Check if an instruction is executed by a single thread.
2317 bool isSingleThreadExecution(const Instruction &I) const override {
2318 return isSingleThreadExecution(*I.getParent());
2319 }
2320
2321 bool isSingleThreadExecution(const BasicBlock &BB) const override {
2322 return SingleThreadedBBs.contains(&BB);
2323 }
2324
2325 /// Set of basic blocks that are executed by a single thread.
2326 DenseSet<const BasicBlock *> SingleThreadedBBs;
2327
2328 /// Total number of basic blocks in this function.
2329 long unsigned NumBBs;
2330 };
2331
2332 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2333 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2334 Function *F = getAnchorScope();
2335 ReversePostOrderTraversal<Function *> RPOT(F);
2336 auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2337
2338 bool AllCallSitesKnown;
2339 auto PredForCallSite = [&](AbstractCallSite ACS) {
2340 const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2341 *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2342 DepClassTy::REQUIRED);
2343 return ExecutionDomainAA.isSingleThreadExecution(*ACS.getInstruction());
2344 };
2345
2346 if (!A.checkForAllCallSites(PredForCallSite, *this,
2347 /* RequiresAllCallSites */ true,
2348 AllCallSitesKnown))
2349 SingleThreadedBBs.erase(&F->getEntryBlock());
2350
2351 // Check if the edge into the successor block compares a thread-id function to
2352 // a constant zero.
2353 // TODO: Use AAValueSimplify to simplify and propagate constants.
2354 // TODO: Check more than a single use for thread IDs.
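// For example, a guard of the form
//   if (__kmpc_global_thread_num(...) == 0) { /* SuccessorBB */ }
// keeps SuccessorBB single-threaded: the taken edge is guarded by an equality
// compare of a thread-id runtime call against the constant zero.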
2355 auto IsSingleThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) { 2356 if (!Edge || !Edge->isConditional()) 2357 return false; 2358 if (Edge->getSuccessor(0) != SuccessorBB) 2359 return false; 2360 2361 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition()); 2362 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality()) 2363 return false; 2364 2365 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1)); 2366 if (!C || !C->isZero()) 2367 return false; 2368 2369 if (auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0))) { 2370 RuntimeFunction ThreadNumRuntimeIDs[] = {OMPRTL_omp_get_thread_num, 2371 OMPRTL___kmpc_master, 2372 OMPRTL___kmpc_global_thread_num}; 2373 2374 for (const auto ThreadNumRuntimeID : ThreadNumRuntimeIDs) { 2375 auto &RFI = OMPInfoCache.RFIs[ThreadNumRuntimeID]; 2376 if (CB->getCalledFunction() == RFI.Declaration) 2377 return true; 2378 } 2379 } 2380 2381 return false; 2382 }; 2383 2384 // Merge all the predecessor states into the current basic block. A basic 2385 // block is executed by a single thread if all of its predecessors are. 2386 auto MergePredecessorStates = [&](BasicBlock *BB) { 2387 if (pred_begin(BB) == pred_end(BB)) 2388 return SingleThreadedBBs.contains(BB); 2389 2390 bool IsSingleThreaded = true; 2391 for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB); 2392 PredBB != PredEndBB; ++PredBB) { 2393 if (!IsSingleThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()), 2394 BB)) 2395 IsSingleThreaded &= SingleThreadedBBs.contains(*PredBB); 2396 } 2397 2398 return IsSingleThreaded; 2399 }; 2400 2401 for (auto *BB : RPOT) { 2402 if (!MergePredecessorStates(BB)) 2403 SingleThreadedBBs.erase(BB); 2404 } 2405 2406 return (NumSingleThreadedBBs == SingleThreadedBBs.size()) 2407 ? ChangeStatus::UNCHANGED 2408 : ChangeStatus::CHANGED; 2409 } 2410 2411 } // namespace 2412 2413 const char AAICVTracker::ID = 0; 2414 const char AAExecutionDomain::ID = 0; 2415 2416 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP, 2417 Attributor &A) { 2418 AAICVTracker *AA = nullptr; 2419 switch (IRP.getPositionKind()) { 2420 case IRPosition::IRP_INVALID: 2421 case IRPosition::IRP_FLOAT: 2422 case IRPosition::IRP_ARGUMENT: 2423 case IRPosition::IRP_CALL_SITE_ARGUMENT: 2424 llvm_unreachable("ICVTracker can only be created for function position!"); 2425 case IRPosition::IRP_RETURNED: 2426 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A); 2427 break; 2428 case IRPosition::IRP_CALL_SITE_RETURNED: 2429 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A); 2430 break; 2431 case IRPosition::IRP_CALL_SITE: 2432 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A); 2433 break; 2434 case IRPosition::IRP_FUNCTION: 2435 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A); 2436 break; 2437 } 2438 2439 return *AA; 2440 } 2441 2442 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP, 2443 Attributor &A) { 2444 AAExecutionDomainFunction *AA = nullptr; 2445 switch (IRP.getPositionKind()) { 2446 case IRPosition::IRP_INVALID: 2447 case IRPosition::IRP_FLOAT: 2448 case IRPosition::IRP_ARGUMENT: 2449 case IRPosition::IRP_CALL_SITE_ARGUMENT: 2450 case IRPosition::IRP_RETURNED: 2451 case IRPosition::IRP_CALL_SITE_RETURNED: 2452 case IRPosition::IRP_CALL_SITE: 2453 llvm_unreachable( 2454 "AAExecutionDomain can only be created for function position!"); 2455 case IRPosition::IRP_FUNCTION: 2456 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A); 2457 break; 2458 } 2459 2460 return *AA; 2461 } 2462 2463 PreservedAnalyses 
OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { 2464 if (!containsOpenMP(M, OMPInModule)) 2465 return PreservedAnalyses::all(); 2466 2467 if (DisableOpenMPOptimizations) 2468 return PreservedAnalyses::all(); 2469 2470 // Look at every function definition in the Module. 2471 SmallVector<Function *, 16> SCC; 2472 for (Function &Fn : M) 2473 if (!Fn.isDeclaration()) 2474 SCC.push_back(&Fn); 2475 2476 if (SCC.empty()) 2477 return PreservedAnalyses::all(); 2478 2479 FunctionAnalysisManager &FAM = 2480 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 2481 2482 AnalysisGetter AG(FAM); 2483 2484 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & { 2485 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 2486 }; 2487 2488 BumpPtrAllocator Allocator; 2489 CallGraphUpdater CGUpdater; 2490 2491 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 2492 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, 2493 OMPInModule.getKernels()); 2494 2495 Attributor A(Functions, InfoCache, CGUpdater); 2496 2497 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 2498 bool Changed = OMPOpt.run(true); 2499 if (Changed) 2500 return PreservedAnalyses::none(); 2501 2502 return PreservedAnalyses::all(); 2503 } 2504 2505 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, 2506 CGSCCAnalysisManager &AM, 2507 LazyCallGraph &CG, 2508 CGSCCUpdateResult &UR) { 2509 if (!containsOpenMP(*C.begin()->getFunction().getParent(), OMPInModule)) 2510 return PreservedAnalyses::all(); 2511 2512 if (DisableOpenMPOptimizations) 2513 return PreservedAnalyses::all(); 2514 2515 SmallVector<Function *, 16> SCC; 2516 // If there are kernels in the module, we have to run on all SCC's. 2517 bool SCCIsInteresting = !OMPInModule.getKernels().empty(); 2518 for (LazyCallGraph::Node &N : C) { 2519 Function *Fn = &N.getFunction(); 2520 SCC.push_back(Fn); 2521 2522 // Do we already know that the SCC contains kernels, 2523 // or that OpenMP functions are called from this SCC? 2524 if (SCCIsInteresting) 2525 continue; 2526 // If not, let's check that. 
2527 SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn); 2528 } 2529 2530 if (!SCCIsInteresting || SCC.empty()) 2531 return PreservedAnalyses::all(); 2532 2533 FunctionAnalysisManager &FAM = 2534 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); 2535 2536 AnalysisGetter AG(FAM); 2537 2538 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & { 2539 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 2540 }; 2541 2542 BumpPtrAllocator Allocator; 2543 CallGraphUpdater CGUpdater; 2544 CGUpdater.initialize(CG, C, AM, UR); 2545 2546 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 2547 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator, 2548 /*CGSCC*/ Functions, OMPInModule.getKernels()); 2549 2550 Attributor A(Functions, InfoCache, CGUpdater, nullptr, false); 2551 2552 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 2553 bool Changed = OMPOpt.run(false); 2554 if (Changed) 2555 return PreservedAnalyses::none(); 2556 2557 return PreservedAnalyses::all(); 2558 } 2559 2560 namespace { 2561 2562 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass { 2563 CallGraphUpdater CGUpdater; 2564 OpenMPInModule OMPInModule; 2565 static char ID; 2566 2567 OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) { 2568 initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry()); 2569 } 2570 2571 void getAnalysisUsage(AnalysisUsage &AU) const override { 2572 CallGraphSCCPass::getAnalysisUsage(AU); 2573 } 2574 2575 bool doInitialization(CallGraph &CG) override { 2576 // Disable the pass if there is no OpenMP (runtime call) in the module. 2577 containsOpenMP(CG.getModule(), OMPInModule); 2578 return false; 2579 } 2580 2581 bool runOnSCC(CallGraphSCC &CGSCC) override { 2582 if (!containsOpenMP(CGSCC.getCallGraph().getModule(), OMPInModule)) 2583 return false; 2584 if (DisableOpenMPOptimizations || skipSCC(CGSCC)) 2585 return false; 2586 2587 SmallVector<Function *, 16> SCC; 2588 // If there are kernels in the module, we have to run on all SCC's. 2589 bool SCCIsInteresting = !OMPInModule.getKernels().empty(); 2590 for (CallGraphNode *CGN : CGSCC) { 2591 Function *Fn = CGN->getFunction(); 2592 if (!Fn || Fn->isDeclaration()) 2593 continue; 2594 SCC.push_back(Fn); 2595 2596 // Do we already know that the SCC contains kernels, 2597 // or that OpenMP functions are called from this SCC? 2598 if (SCCIsInteresting) 2599 continue; 2600 // If not, let's check that. 
2601 SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn); 2602 } 2603 2604 if (!SCCIsInteresting || SCC.empty()) 2605 return false; 2606 2607 CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); 2608 CGUpdater.initialize(CG, CGSCC); 2609 2610 // Maintain a map of functions to avoid rebuilding the ORE 2611 DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap; 2612 auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & { 2613 std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F]; 2614 if (!ORE) 2615 ORE = std::make_unique<OptimizationRemarkEmitter>(F); 2616 return *ORE; 2617 }; 2618 2619 AnalysisGetter AG; 2620 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 2621 BumpPtrAllocator Allocator; 2622 OMPInformationCache InfoCache( 2623 *(Functions.back()->getParent()), AG, Allocator, 2624 /*CGSCC*/ Functions, OMPInModule.getKernels()); 2625 2626 Attributor A(Functions, InfoCache, CGUpdater, nullptr, false); 2627 2628 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 2629 return OMPOpt.run(false); 2630 } 2631 2632 bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); } 2633 }; 2634 2635 } // end anonymous namespace 2636 2637 void OpenMPInModule::identifyKernels(Module &M) { 2638 2639 NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations"); 2640 if (!MD) 2641 return; 2642 2643 for (auto *Op : MD->operands()) { 2644 if (Op->getNumOperands() < 2) 2645 continue; 2646 MDString *KindID = dyn_cast<MDString>(Op->getOperand(1)); 2647 if (!KindID || KindID->getString() != "kernel") 2648 continue; 2649 2650 Function *KernelFn = 2651 mdconst::dyn_extract_or_null<Function>(Op->getOperand(0)); 2652 if (!KernelFn) 2653 continue; 2654 2655 ++NumOpenMPTargetRegionKernels; 2656 2657 Kernels.insert(KernelFn); 2658 } 2659 } 2660 2661 bool llvm::omp::containsOpenMP(Module &M, OpenMPInModule &OMPInModule) { 2662 if (OMPInModule.isKnown()) 2663 return OMPInModule; 2664 2665 auto RecordFunctionsContainingUsesOf = [&](Function *F) { 2666 for (User *U : F->users()) 2667 if (auto *I = dyn_cast<Instruction>(U)) 2668 OMPInModule.FuncsWithOMPRuntimeCalls.insert(I->getFunction()); 2669 }; 2670 2671 // MSVC doesn't like long if-else chains for some reason and instead just 2672 // issues an error. Work around it.. 2673 do { 2674 #define OMP_RTL(_Enum, _Name, ...) \ 2675 if (Function *F = M.getFunction(_Name)) { \ 2676 RecordFunctionsContainingUsesOf(F); \ 2677 OMPInModule = true; \ 2678 } 2679 #include "llvm/Frontend/OpenMP/OMPKinds.def" 2680 } while (false); 2681 2682 // Identify kernels once. TODO: We should split the OMPInformationCache into a 2683 // module and an SCC part. The kernel information, among other things, could 2684 // go into the module part. 2685 if (OMPInModule.isKnown() && OMPInModule) { 2686 OMPInModule.identifyKernels(M); 2687 return true; 2688 } 2689 2690 return OMPInModule = false; 2691 } 2692 2693 char OpenMPOptCGSCCLegacyPass::ID = 0; 2694 2695 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", 2696 "OpenMP specific optimizations", false, false) 2697 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) 2698 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", 2699 "OpenMP specific optimizations", false, false) 2700 2701 Pass *llvm::createOpenMPOptCGSCCLegacyPass() { 2702 return new OpenMPOptCGSCCLegacyPass(); 2703 } 2704