1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // OpenMP specific optimizations: 10 // 11 // - Deduplication of runtime calls, e.g., omp_get_thread_num. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "llvm/Transforms/IPO/OpenMPOpt.h" 16 17 #include "llvm/ADT/EnumeratedArray.h" 18 #include "llvm/ADT/PostOrderIterator.h" 19 #include "llvm/ADT/Statistic.h" 20 #include "llvm/Analysis/CallGraph.h" 21 #include "llvm/Analysis/CallGraphSCCPass.h" 22 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 23 #include "llvm/Analysis/ValueTracking.h" 24 #include "llvm/Frontend/OpenMP/OMPConstants.h" 25 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" 26 #include "llvm/IR/IntrinsicInst.h" 27 #include "llvm/IR/IntrinsicsAMDGPU.h" 28 #include "llvm/IR/IntrinsicsNVPTX.h" 29 #include "llvm/InitializePasses.h" 30 #include "llvm/Support/CommandLine.h" 31 #include "llvm/Transforms/IPO.h" 32 #include "llvm/Transforms/IPO/Attributor.h" 33 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 34 #include "llvm/Transforms/Utils/CallGraphUpdater.h" 35 #include "llvm/Transforms/Utils/CodeExtractor.h" 36 37 using namespace llvm; 38 using namespace omp; 39 40 #define DEBUG_TYPE "openmp-opt" 41 42 static cl::opt<bool> DisableOpenMPOptimizations( 43 "openmp-opt-disable", cl::ZeroOrMore, 44 cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, 45 cl::init(false)); 46 47 static cl::opt<bool> EnableParallelRegionMerging( 48 "openmp-opt-enable-merging", cl::ZeroOrMore, 49 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, 50 cl::init(false)); 51 52 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false), 53 cl::Hidden); 54 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels", 55 cl::init(false), cl::Hidden); 56 57 static cl::opt<bool> HideMemoryTransferLatency( 58 "openmp-hide-memory-transfer-latency", 59 cl::desc("[WIP] Tries to hide the latency of host to device memory" 60 " transfers"), 61 cl::Hidden, cl::init(false)); 62 63 STATISTIC(NumOpenMPRuntimeCallsDeduplicated, 64 "Number of OpenMP runtime calls deduplicated"); 65 STATISTIC(NumOpenMPParallelRegionsDeleted, 66 "Number of OpenMP parallel regions deleted"); 67 STATISTIC(NumOpenMPRuntimeFunctionsIdentified, 68 "Number of OpenMP runtime functions identified"); 69 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified, 70 "Number of OpenMP runtime function uses identified"); 71 STATISTIC(NumOpenMPTargetRegionKernels, 72 "Number of OpenMP target region entry points (=kernels) identified"); 73 STATISTIC( 74 NumOpenMPParallelRegionsReplacedInGPUStateMachine, 75 "Number of OpenMP parallel regions replaced with ID in GPU state machines"); 76 STATISTIC(NumOpenMPParallelRegionsMerged, 77 "Number of OpenMP parallel regions merged"); 78 79 #if !defined(NDEBUG) 80 static constexpr auto TAG = "[" DEBUG_TYPE "]"; 81 #endif 82 83 namespace { 84 85 struct AAExecutionDomain 86 : public StateWrapper<BooleanState, AbstractAttribute> { 87 using Base = StateWrapper<BooleanState, AbstractAttribute>; 88 AAExecutionDomain(const IRPosition &IRP, Attributor &A) : Base(IRP) {} 89 90 /// Create an abstract attribute view for the position \p IRP. 
91 static AAExecutionDomain &createForPosition(const IRPosition &IRP, 92 Attributor &A); 93 94 /// See AbstractAttribute::getName(). 95 const std::string getName() const override { return "AAExecutionDomain"; } 96 97 /// See AbstractAttribute::getIdAddr(). 98 const char *getIdAddr() const override { return &ID; } 99 100 /// Check if an instruction is executed by a single thread. 101 virtual bool isSingleThreadExecution(const Instruction &) const = 0; 102 103 virtual bool isSingleThreadExecution(const BasicBlock &) const = 0; 104 105 /// This function should return true if the type of the \p AA is 106 /// AAExecutionDomain. 107 static bool classof(const AbstractAttribute *AA) { 108 return (AA->getIdAddr() == &ID); 109 } 110 111 /// Unique ID (due to the unique address) 112 static const char ID; 113 }; 114 115 struct AAICVTracker; 116 117 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for 118 /// Attributor runs. 119 struct OMPInformationCache : public InformationCache { 120 OMPInformationCache(Module &M, AnalysisGetter &AG, 121 BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC, 122 SmallPtrSetImpl<Kernel> &Kernels) 123 : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M), 124 Kernels(Kernels) { 125 126 OMPBuilder.initialize(); 127 initializeRuntimeFunctions(); 128 initializeInternalControlVars(); 129 } 130 131 /// Generic information that describes an internal control variable. 132 struct InternalControlVarInfo { 133 /// The kind, as described by InternalControlVar enum. 134 InternalControlVar Kind; 135 136 /// The name of the ICV. 137 StringRef Name; 138 139 /// Environment variable associated with this ICV. 140 StringRef EnvVarName; 141 142 /// Initial value kind. 143 ICVInitValue InitKind; 144 145 /// Initial value. 146 ConstantInt *InitValue; 147 148 /// Setter RTL function associated with this ICV. 149 RuntimeFunction Setter; 150 151 /// Getter RTL function associated with this ICV. 152 RuntimeFunction Getter; 153 154 /// RTL Function corresponding to the override clause of this ICV 155 RuntimeFunction Clause; 156 }; 157 158 /// Generic information that describes a runtime function 159 struct RuntimeFunctionInfo { 160 161 /// The kind, as described by the RuntimeFunction enum. 162 RuntimeFunction Kind; 163 164 /// The name of the function. 165 StringRef Name; 166 167 /// Flag to indicate a variadic function. 168 bool IsVarArg; 169 170 /// The return type of the function. 171 Type *ReturnType; 172 173 /// The argument types of the function. 174 SmallVector<Type *, 8> ArgumentTypes; 175 176 /// The declaration if available. 177 Function *Declaration = nullptr; 178 179 /// Uses of this runtime function per function containing the use. 180 using UseVector = SmallVector<Use *, 16>; 181 182 /// Clear UsesMap for runtime function. 183 void clearUsesMap() { UsesMap.clear(); } 184 185 /// Boolean conversion that is true if the runtime function was found. 186 operator bool() const { return Declaration; } 187 188 /// Return the vector of uses in function \p F. 189 UseVector &getOrCreateUseVector(Function *F) { 190 std::shared_ptr<UseVector> &UV = UsesMap[F]; 191 if (!UV) 192 UV = std::make_shared<UseVector>(); 193 return *UV; 194 } 195 196 /// Return the vector of uses in function \p F or `nullptr` if there are 197 /// none. 
198 const UseVector *getUseVector(Function &F) const { 199 auto I = UsesMap.find(&F); 200 if (I != UsesMap.end()) 201 return I->second.get(); 202 return nullptr; 203 } 204 205 /// Return how many functions contain uses of this runtime function. 206 size_t getNumFunctionsWithUses() const { return UsesMap.size(); } 207 208 /// Return the number of arguments (or the minimal number for variadic 209 /// functions). 210 size_t getNumArgs() const { return ArgumentTypes.size(); } 211 212 /// Run the callback \p CB on each use and forget the use if the result is 213 /// true. The callback will be fed the function in which the use was 214 /// encountered as second argument. 215 void foreachUse(SmallVectorImpl<Function *> &SCC, 216 function_ref<bool(Use &, Function &)> CB) { 217 for (Function *F : SCC) 218 foreachUse(CB, F); 219 } 220 221 /// Run the callback \p CB on each use within the function \p F and forget 222 /// the use if the result is true. 223 void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) { 224 SmallVector<unsigned, 8> ToBeDeleted; 225 ToBeDeleted.clear(); 226 227 unsigned Idx = 0; 228 UseVector &UV = getOrCreateUseVector(F); 229 230 for (Use *U : UV) { 231 if (CB(*U, *F)) 232 ToBeDeleted.push_back(Idx); 233 ++Idx; 234 } 235 236 // Remove the to-be-deleted indices in reverse order as prior 237 // modifications will not modify the smaller indices. 238 while (!ToBeDeleted.empty()) { 239 unsigned Idx = ToBeDeleted.pop_back_val(); 240 UV[Idx] = UV.back(); 241 UV.pop_back(); 242 } 243 } 244 245 private: 246 /// Map from functions to all uses of this runtime function contained in 247 /// them. 248 DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap; 249 }; 250 251 /// An OpenMP-IR-Builder instance 252 OpenMPIRBuilder OMPBuilder; 253 254 /// Map from runtime function kind to the runtime function description. 255 EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction, 256 RuntimeFunction::OMPRTL___last> 257 RFIs; 258 259 /// Map from ICV kind to the ICV description. 260 EnumeratedArray<InternalControlVarInfo, InternalControlVar, 261 InternalControlVar::ICV___last> 262 ICVs; 263 264 /// Helper to initialize all internal control variable information for those 265 /// defined in OMPKinds.def. 266 void initializeInternalControlVars() { 267 #define ICV_RT_SET(_Name, RTL) \ 268 { \ 269 auto &ICV = ICVs[_Name]; \ 270 ICV.Setter = RTL; \ 271 } 272 #define ICV_RT_GET(Name, RTL) \ 273 { \ 274 auto &ICV = ICVs[Name]; \ 275 ICV.Getter = RTL; \ 276 } 277 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \ 278 { \ 279 auto &ICV = ICVs[Enum]; \ 280 ICV.Name = _Name; \ 281 ICV.Kind = Enum; \ 282 ICV.InitKind = Init; \ 283 ICV.EnvVarName = _EnvVarName; \ 284 switch (ICV.InitKind) { \ 285 case ICV_IMPLEMENTATION_DEFINED: \ 286 ICV.InitValue = nullptr; \ 287 break; \ 288 case ICV_ZERO: \ 289 ICV.InitValue = ConstantInt::get( \ 290 Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \ 291 break; \ 292 case ICV_FALSE: \ 293 ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \ 294 break; \ 295 case ICV_LAST: \ 296 break; \ 297 } \ 298 } 299 #include "llvm/Frontend/OpenMP/OMPKinds.def" 300 } 301 302 /// Returns true if the function declaration \p F matches the runtime 303 /// function types, that is, return type \p RTFRetType, and argument types 304 /// \p RTFArgTypes. 
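  /// For example (assuming the usual OMPKinds.def entry), `omp_get_thread_num`
  /// is only recognized if it is declared as `i32 ()`; declarations with a
  /// mismatching signature are simply not registered as runtime functions.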
305 static bool declMatchesRTFTypes(Function *F, Type *RTFRetType, 306 SmallVector<Type *, 8> &RTFArgTypes) { 307 // TODO: We should output information to the user (under debug output 308 // and via remarks). 309 310 if (!F) 311 return false; 312 if (F->getReturnType() != RTFRetType) 313 return false; 314 if (F->arg_size() != RTFArgTypes.size()) 315 return false; 316 317 auto RTFTyIt = RTFArgTypes.begin(); 318 for (Argument &Arg : F->args()) { 319 if (Arg.getType() != *RTFTyIt) 320 return false; 321 322 ++RTFTyIt; 323 } 324 325 return true; 326 } 327 328 // Helper to collect all uses of the declaration in the UsesMap. 329 unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) { 330 unsigned NumUses = 0; 331 if (!RFI.Declaration) 332 return NumUses; 333 OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration); 334 335 if (CollectStats) { 336 NumOpenMPRuntimeFunctionsIdentified += 1; 337 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses(); 338 } 339 340 // TODO: We directly convert uses into proper calls and unknown uses. 341 for (Use &U : RFI.Declaration->uses()) { 342 if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) { 343 if (ModuleSlice.count(UserI->getFunction())) { 344 RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U); 345 ++NumUses; 346 } 347 } else { 348 RFI.getOrCreateUseVector(nullptr).push_back(&U); 349 ++NumUses; 350 } 351 } 352 return NumUses; 353 } 354 355 // Helper function to recollect uses of a runtime function. 356 void recollectUsesForFunction(RuntimeFunction RTF) { 357 auto &RFI = RFIs[RTF]; 358 RFI.clearUsesMap(); 359 collectUses(RFI, /*CollectStats*/ false); 360 } 361 362 // Helper function to recollect uses of all runtime functions. 363 void recollectUses() { 364 for (int Idx = 0; Idx < RFIs.size(); ++Idx) 365 recollectUsesForFunction(static_cast<RuntimeFunction>(Idx)); 366 } 367 368 /// Helper to initialize all runtime function information for those defined 369 /// in OpenMPKinds.def. 370 void initializeRuntimeFunctions() { 371 Module &M = *((*ModuleSlice.begin())->getParent()); 372 373 // Helper macros for handling __VA_ARGS__ in OMP_RTL 374 #define OMP_TYPE(VarName, ...) \ 375 Type *VarName = OMPBuilder.VarName; \ 376 (void)VarName; 377 378 #define OMP_ARRAY_TYPE(VarName, ...) \ 379 ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \ 380 (void)VarName##Ty; \ 381 PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \ 382 (void)VarName##PtrTy; 383 384 #define OMP_FUNCTION_TYPE(VarName, ...) \ 385 FunctionType *VarName = OMPBuilder.VarName; \ 386 (void)VarName; \ 387 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \ 388 (void)VarName##Ptr; 389 390 #define OMP_STRUCT_TYPE(VarName, ...) \ 391 StructType *VarName = OMPBuilder.VarName; \ 392 (void)VarName; \ 393 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \ 394 (void)VarName##Ptr; 395 396 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \ 397 { \ 398 SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \ 399 Function *F = M.getFunction(_Name); \ 400 if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \ 401 auto &RFI = RFIs[_Enum]; \ 402 RFI.Kind = _Enum; \ 403 RFI.Name = _Name; \ 404 RFI.IsVarArg = _IsVarArg; \ 405 RFI.ReturnType = OMPBuilder._ReturnType; \ 406 RFI.ArgumentTypes = std::move(ArgsTypes); \ 407 RFI.Declaration = F; \ 408 unsigned NumUses = collectUses(RFI); \ 409 (void)NumUses; \ 410 LLVM_DEBUG({ \ 411 dbgs() << TAG << RFI.Name << (RFI.Declaration ? 
"" : " not") \ 412 << " found\n"; \ 413 if (RFI.Declaration) \ 414 dbgs() << TAG << "-> got " << NumUses << " uses in " \ 415 << RFI.getNumFunctionsWithUses() \ 416 << " different functions.\n"; \ 417 }); \ 418 } \ 419 } 420 #include "llvm/Frontend/OpenMP/OMPKinds.def" 421 422 // TODO: We should attach the attributes defined in OMPKinds.def. 423 } 424 425 /// Collection of known kernels (\see Kernel) in the module. 426 SmallPtrSetImpl<Kernel> &Kernels; 427 }; 428 429 /// Used to map the values physically (in the IR) stored in an offload 430 /// array, to a vector in memory. 431 struct OffloadArray { 432 /// Physical array (in the IR). 433 AllocaInst *Array = nullptr; 434 /// Mapped values. 435 SmallVector<Value *, 8> StoredValues; 436 /// Last stores made in the offload array. 437 SmallVector<StoreInst *, 8> LastAccesses; 438 439 OffloadArray() = default; 440 441 /// Initializes the OffloadArray with the values stored in \p Array before 442 /// instruction \p Before is reached. Returns false if the initialization 443 /// fails. 444 /// This MUST be used immediately after the construction of the object. 445 bool initialize(AllocaInst &Array, Instruction &Before) { 446 if (!Array.getAllocatedType()->isArrayTy()) 447 return false; 448 449 if (!getValues(Array, Before)) 450 return false; 451 452 this->Array = &Array; 453 return true; 454 } 455 456 static const unsigned DeviceIDArgNum = 1; 457 static const unsigned BasePtrsArgNum = 3; 458 static const unsigned PtrsArgNum = 4; 459 static const unsigned SizesArgNum = 5; 460 461 private: 462 /// Traverses the BasicBlock where \p Array is, collecting the stores made to 463 /// \p Array, leaving StoredValues with the values stored before the 464 /// instruction \p Before is reached. 465 bool getValues(AllocaInst &Array, Instruction &Before) { 466 // Initialize container. 467 const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements(); 468 StoredValues.assign(NumValues, nullptr); 469 LastAccesses.assign(NumValues, nullptr); 470 471 // TODO: This assumes the instruction \p Before is in the same 472 // BasicBlock as Array. Make it general, for any control flow graph. 473 BasicBlock *BB = Array.getParent(); 474 if (BB != Before.getParent()) 475 return false; 476 477 const DataLayout &DL = Array.getModule()->getDataLayout(); 478 const unsigned int PointerSize = DL.getPointerSize(); 479 480 for (Instruction &I : *BB) { 481 if (&I == &Before) 482 break; 483 484 if (!isa<StoreInst>(&I)) 485 continue; 486 487 auto *S = cast<StoreInst>(&I); 488 int64_t Offset = -1; 489 auto *Dst = 490 GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL); 491 if (Dst == &Array) { 492 int64_t Idx = Offset / PointerSize; 493 StoredValues[Idx] = getUnderlyingObject(S->getValueOperand()); 494 LastAccesses[Idx] = S; 495 } 496 } 497 498 return isFilled(); 499 } 500 501 /// Returns true if all values in StoredValues and 502 /// LastAccesses are not nullptrs. 
503 bool isFilled() { 504 const unsigned NumValues = StoredValues.size(); 505 for (unsigned I = 0; I < NumValues; ++I) { 506 if (!StoredValues[I] || !LastAccesses[I]) 507 return false; 508 } 509 510 return true; 511 } 512 }; 513 514 struct OpenMPOpt { 515 516 using OptimizationRemarkGetter = 517 function_ref<OptimizationRemarkEmitter &(Function *)>; 518 519 OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater, 520 OptimizationRemarkGetter OREGetter, 521 OMPInformationCache &OMPInfoCache, Attributor &A) 522 : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater), 523 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {} 524 525 /// Check if any remarks are enabled for openmp-opt 526 bool remarksEnabled() { 527 auto &Ctx = M.getContext(); 528 return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE); 529 } 530 531 /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice. 532 bool run(bool IsModulePass) { 533 if (SCC.empty()) 534 return false; 535 536 bool Changed = false; 537 538 LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size() 539 << " functions in a slice with " 540 << OMPInfoCache.ModuleSlice.size() << " functions\n"); 541 542 if (IsModulePass) { 543 Changed |= runAttributor(); 544 545 if (remarksEnabled()) 546 analysisGlobalization(); 547 } else { 548 if (PrintICVValues) 549 printICVs(); 550 if (PrintOpenMPKernels) 551 printKernels(); 552 553 Changed |= rewriteDeviceCodeStateMachine(); 554 555 Changed |= runAttributor(); 556 557 // Recollect uses, in case Attributor deleted any. 558 OMPInfoCache.recollectUses(); 559 560 Changed |= deleteParallelRegions(); 561 if (HideMemoryTransferLatency) 562 Changed |= hideMemTransfersLatency(); 563 Changed |= deduplicateRuntimeCalls(); 564 if (EnableParallelRegionMerging) { 565 if (mergeParallelRegions()) { 566 deduplicateRuntimeCalls(); 567 Changed = true; 568 } 569 } 570 } 571 572 return Changed; 573 } 574 575 /// Print initial ICV values for testing. 576 /// FIXME: This should be done from the Attributor once it is added. 577 void printICVs() const { 578 InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel, 579 ICV_proc_bind}; 580 581 for (Function *F : OMPInfoCache.ModuleSlice) { 582 for (auto ICV : ICVs) { 583 auto ICVInfo = OMPInfoCache.ICVs[ICV]; 584 auto Remark = [&](OptimizationRemark OR) { 585 return OR << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name) 586 << " Value: " 587 << (ICVInfo.InitValue 588 ? ICVInfo.InitValue->getValue().toString(10, true) 589 : "IMPLEMENTATION_DEFINED"); 590 }; 591 592 emitRemarkOnFunction(F, "OpenMPICVTracker", Remark); 593 } 594 } 595 } 596 597 /// Print OpenMP GPU kernels for testing. 598 void printKernels() const { 599 for (Function *F : SCC) { 600 if (!OMPInfoCache.Kernels.count(F)) 601 continue; 602 603 auto Remark = [&](OptimizationRemark OR) { 604 return OR << "OpenMP GPU kernel " 605 << ore::NV("OpenMPGPUKernel", F->getName()) << "\n"; 606 }; 607 608 emitRemarkOnFunction(F, "OpenMPGPU", Remark); 609 } 610 } 611 612 /// Return the call if \p U is a callee use in a regular call. If \p RFI is 613 /// given it has to be the callee or a nullptr is returned. 
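  /// A "regular" call here is a CallInst without operand bundles in which \p U
  /// is the callee operand; uses as call arguments or inside casts do not
  /// qualify.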
614 static CallInst *getCallIfRegularCall( 615 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) { 616 CallInst *CI = dyn_cast<CallInst>(U.getUser()); 617 if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() && 618 (!RFI || CI->getCalledFunction() == RFI->Declaration)) 619 return CI; 620 return nullptr; 621 } 622 623 /// Return the call if \p V is a regular call. If \p RFI is given it has to be 624 /// the callee or a nullptr is returned. 625 static CallInst *getCallIfRegularCall( 626 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) { 627 CallInst *CI = dyn_cast<CallInst>(&V); 628 if (CI && !CI->hasOperandBundles() && 629 (!RFI || CI->getCalledFunction() == RFI->Declaration)) 630 return CI; 631 return nullptr; 632 } 633 634 private: 635 /// Merge parallel regions when it is safe. 636 bool mergeParallelRegions() { 637 const unsigned CallbackCalleeOperand = 2; 638 const unsigned CallbackFirstArgOperand = 3; 639 using InsertPointTy = OpenMPIRBuilder::InsertPointTy; 640 641 // Check if there are any __kmpc_fork_call calls to merge. 642 OMPInformationCache::RuntimeFunctionInfo &RFI = 643 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call]; 644 645 if (!RFI.Declaration) 646 return false; 647 648 // Unmergable calls that prevent merging a parallel region. 649 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = { 650 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind], 651 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads], 652 }; 653 654 bool Changed = false; 655 LoopInfo *LI = nullptr; 656 DominatorTree *DT = nullptr; 657 658 SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap; 659 660 BasicBlock *StartBB = nullptr, *EndBB = nullptr; 661 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, 662 BasicBlock &ContinuationIP) { 663 BasicBlock *CGStartBB = CodeGenIP.getBlock(); 664 BasicBlock *CGEndBB = 665 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI); 666 assert(StartBB != nullptr && "StartBB should not be null"); 667 CGStartBB->getTerminator()->setSuccessor(0, StartBB); 668 assert(EndBB != nullptr && "EndBB should not be null"); 669 EndBB->getTerminator()->setSuccessor(0, CGEndBB); 670 }; 671 672 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &, 673 Value &Inner, Value *&ReplacementValue) -> InsertPointTy { 674 ReplacementValue = &Inner; 675 return CodeGenIP; 676 }; 677 678 auto FiniCB = [&](InsertPointTy CodeGenIP) {}; 679 680 /// Create a sequential execution region within a merged parallel region, 681 /// encapsulated in a master construct with a barrier for synchronization. 682 auto CreateSequentialRegion = [&](Function *OuterFn, 683 BasicBlock *OuterPredBB, 684 Instruction *SeqStartI, 685 Instruction *SeqEndI) { 686 // Isolate the instructions of the sequential region to a separate 687 // block. 
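      // The splits below turn ParentBB into the chain
      //   ParentBB -> SeqStartBB ("seq.par.merged") -> SeqEndBB -> SeqAfterBB
      // where SeqStartBB contains [SeqStartI, SeqEndI] and is subsequently
      // wrapped in a master region followed by a barrier.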
688 BasicBlock *ParentBB = SeqStartI->getParent(); 689 BasicBlock *SeqEndBB = 690 SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI); 691 BasicBlock *SeqAfterBB = 692 SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI); 693 BasicBlock *SeqStartBB = 694 SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged"); 695 696 assert(ParentBB->getUniqueSuccessor() == SeqStartBB && 697 "Expected a different CFG"); 698 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc(); 699 ParentBB->getTerminator()->eraseFromParent(); 700 701 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, 702 BasicBlock &ContinuationIP) { 703 BasicBlock *CGStartBB = CodeGenIP.getBlock(); 704 BasicBlock *CGEndBB = 705 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI); 706 assert(SeqStartBB != nullptr && "SeqStartBB should not be null"); 707 CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB); 708 assert(SeqEndBB != nullptr && "SeqEndBB should not be null"); 709 SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB); 710 }; 711 auto FiniCB = [&](InsertPointTy CodeGenIP) {}; 712 713 // Find outputs from the sequential region to outside users and 714 // broadcast their values to them. 715 for (Instruction &I : *SeqStartBB) { 716 SmallPtrSet<Instruction *, 4> OutsideUsers; 717 for (User *Usr : I.users()) { 718 Instruction &UsrI = *cast<Instruction>(Usr); 719 // Ignore outputs to LT intrinsics, code extraction for the merged 720 // parallel region will fix them. 721 if (UsrI.isLifetimeStartOrEnd()) 722 continue; 723 724 if (UsrI.getParent() != SeqStartBB) 725 OutsideUsers.insert(&UsrI); 726 } 727 728 if (OutsideUsers.empty()) 729 continue; 730 731 // Emit an alloca in the outer region to store the broadcasted 732 // value. 733 const DataLayout &DL = M.getDataLayout(); 734 AllocaInst *AllocaI = new AllocaInst( 735 I.getType(), DL.getAllocaAddrSpace(), nullptr, 736 I.getName() + ".seq.output.alloc", &OuterFn->front().front()); 737 738 // Emit a store instruction in the sequential BB to update the 739 // value. 740 new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()); 741 742 // Emit a load instruction and replace the use of the output value 743 // with it. 744 for (Instruction *UsrI : OutsideUsers) { 745 LoadInst *LoadI = new LoadInst( 746 I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI); 747 UsrI->replaceUsesOfWith(&I, LoadI); 748 } 749 } 750 751 OpenMPIRBuilder::LocationDescription Loc( 752 InsertPointTy(ParentBB, ParentBB->end()), DL); 753 InsertPointTy SeqAfterIP = 754 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB); 755 756 OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel); 757 758 BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock()); 759 760 LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn 761 << "\n"); 762 }; 763 764 // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all 765 // contained in BB and only separated by instructions that can be 766 // redundantly executed in parallel. The block BB is split before the first 767 // call (in MergableCIs) and after the last so the entire region we merge 768 // into a single parallel region is contained in a single basic block 769 // without any other instructions. We use the OpenMPIRBuilder to outline 770 // that block and call the resulting function via __kmpc_fork_call. 
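    // Conceptually (callback names are illustrative):
    //   __kmpc_fork_call(..., @outlined1, ...)
    //   <in-between instructions>
    //   __kmpc_fork_call(..., @outlined2, ...)
    // becomes a single __kmpc_fork_call of a new parallel region that calls
    // @outlined1 and @outlined2 directly and runs the in-between instructions
    // under a master construct followed by a barrier.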
771 auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) { 772 // TODO: Change the interface to allow single CIs expanded, e.g, to 773 // include an outer loop. 774 assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs"); 775 776 auto Remark = [&](OptimizationRemark OR) { 777 OR << "Parallel region at " 778 << ore::NV("OpenMPParallelMergeFront", 779 MergableCIs.front()->getDebugLoc()) 780 << " merged with parallel regions at "; 781 for (auto *CI : llvm::drop_begin(MergableCIs)) { 782 OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc()); 783 if (CI != MergableCIs.back()) 784 OR << ", "; 785 } 786 return OR; 787 }; 788 789 emitRemark<OptimizationRemark>(MergableCIs.front(), 790 "OpenMPParallelRegionMerging", Remark); 791 792 Function *OriginalFn = BB->getParent(); 793 LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size() 794 << " parallel regions in " << OriginalFn->getName() 795 << "\n"); 796 797 // Isolate the calls to merge in a separate block. 798 EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI); 799 BasicBlock *AfterBB = 800 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI); 801 StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr, 802 "omp.par.merged"); 803 804 assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG"); 805 const DebugLoc DL = BB->getTerminator()->getDebugLoc(); 806 BB->getTerminator()->eraseFromParent(); 807 808 // Create sequential regions for sequential instructions that are 809 // in-between mergable parallel regions. 810 for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1; 811 It != End; ++It) { 812 Instruction *ForkCI = *It; 813 Instruction *NextForkCI = *(It + 1); 814 815 // Continue if there are not in-between instructions. 816 if (ForkCI->getNextNode() == NextForkCI) 817 continue; 818 819 CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(), 820 NextForkCI->getPrevNode()); 821 } 822 823 OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()), 824 DL); 825 IRBuilder<>::InsertPoint AllocaIP( 826 &OriginalFn->getEntryBlock(), 827 OriginalFn->getEntryBlock().getFirstInsertionPt()); 828 // Create the merged parallel region with default proc binding, to 829 // avoid overriding binding settings, and without explicit cancellation. 830 InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel( 831 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr, 832 OMP_PROC_BIND_default, /* IsCancellable */ false); 833 BranchInst::Create(AfterBB, AfterIP.getBlock()); 834 835 // Perform the actual outlining. 836 OMPInfoCache.OMPBuilder.finalize(OriginalFn, 837 /* AllowExtractorSinking */ true); 838 839 Function *OutlinedFn = MergableCIs.front()->getCaller(); 840 841 // Replace the __kmpc_fork_call calls with direct calls to the outlined 842 // callbacks. 843 SmallVector<Value *, 8> Args; 844 for (auto *CI : MergableCIs) { 845 Value *Callee = 846 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts(); 847 FunctionType *FT = 848 cast<FunctionType>(Callee->getType()->getPointerElementType()); 849 Args.clear(); 850 Args.push_back(OutlinedFn->getArg(0)); 851 Args.push_back(OutlinedFn->getArg(1)); 852 for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands(); 853 U < E; ++U) 854 Args.push_back(CI->getArgOperand(U)); 855 856 CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI); 857 if (CI->getDebugLoc()) 858 NewCI->setDebugLoc(CI->getDebugLoc()); 859 860 // Forward parameter attributes from the callback to the callee. 
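        // The callback arguments start at operand CallbackFirstArgOperand of
        // the old __kmpc_fork_call, while NewCI receives the two thread-id
        // arguments of the outlined function first, hence the parameter index
        // is shifted by CallbackFirstArgOperand - CallbackCalleeOperand.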
861 for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
862 U < E; ++U)
863 for (const Attribute &A : CI->getAttributes().getParamAttributes(U))
864 NewCI->addParamAttr(
865 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
866
867 // Emit an explicit barrier to replace the implicit fork-join barrier.
868 if (CI != MergableCIs.back()) {
869 // TODO: Remove barrier if the merged parallel region includes the
870 // 'nowait' clause.
871 OMPInfoCache.OMPBuilder.createBarrier(
872 InsertPointTy(NewCI->getParent(),
873 NewCI->getNextNode()->getIterator()),
874 OMPD_parallel);
875 }
876
877 auto Remark = [&](OptimizationRemark OR) {
878 return OR << "Parallel region at "
879 << ore::NV("OpenMPParallelMerge", CI->getDebugLoc())
880 << " merged with "
881 << ore::NV("OpenMPParallelMergeFront",
882 MergableCIs.front()->getDebugLoc());
883 };
884 if (CI != MergableCIs.front())
885 emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionMerging",
886 Remark);
887
888 CI->eraseFromParent();
889 }
890
891 assert(OutlinedFn != OriginalFn && "Outlining failed");
892 CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
893 CGUpdater.reanalyzeFunction(*OriginalFn);
894
895 NumOpenMPParallelRegionsMerged += MergableCIs.size();
896
897 return true;
898 };
899
900 // Helper function that identifies sequences of
901 // __kmpc_fork_call uses in a basic block.
902 auto DetectPRsCB = [&](Use &U, Function &F) {
903 CallInst *CI = getCallIfRegularCall(U, &RFI);
904 BB2PRMap[CI->getParent()].insert(CI);
905
906 return false;
907 };
908
909 BB2PRMap.clear();
910 RFI.foreachUse(SCC, DetectPRsCB);
911 SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
912 // Find mergable parallel regions within a basic block that are
913 // safe to merge, that is, any in-between instructions can safely
914 // execute in parallel after merging.
915 // TODO: support merging across basic-blocks.
916 for (auto &It : BB2PRMap) {
917 auto &CIs = It.getSecond();
918 if (CIs.size() < 2)
919 continue;
920
921 BasicBlock *BB = It.getFirst();
922 SmallVector<CallInst *, 4> MergableCIs;
923
924 /// Returns true if the instruction is mergable, false otherwise.
925 /// A terminator instruction is unmergable by definition since merging
926 /// works within a BB. Instructions before the mergable region are
927 /// mergable if they are not calls to OpenMP runtime functions that may
928 /// set different execution parameters for subsequent parallel regions.
929 /// Instructions in-between parallel regions are mergable if they are not
930 /// calls to any non-intrinsic function since that may call a non-mergable
931 /// OpenMP runtime function.
932 auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
933 // We do not merge across BBs, hence return false (unmergable) if the
934 // instruction is a terminator.
935 if (I.isTerminator())
936 return false;
937
938 if (!isa<CallInst>(&I))
939 return true;
940
941 CallInst *CI = cast<CallInst>(&I);
942 if (IsBeforeMergableRegion) {
943 Function *CalledFunction = CI->getCalledFunction();
944 if (!CalledFunction)
945 return false;
946 // Return false (unmergable) if the call before the parallel
947 // region calls an explicit affinity (proc_bind) or number of
948 // threads (num_threads) compiler-generated function. Those settings
949 // may be incompatible with following parallel regions.
950 // TODO: ICV tracking to detect compatibility.
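          // Currently this only rejects calls to __kmpc_push_proc_bind and
          // __kmpc_push_num_threads, the entries of UnmergableCallsInfo above.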
951 for (const auto &RFI : UnmergableCallsInfo) { 952 if (CalledFunction == RFI.Declaration) 953 return false; 954 } 955 } else { 956 // Return false (unmergable) if there is a call instruction 957 // in-between parallel regions when it is not an intrinsic. It 958 // may call an unmergable OpenMP runtime function in its callpath. 959 // TODO: Keep track of possible OpenMP calls in the callpath. 960 if (!isa<IntrinsicInst>(CI)) 961 return false; 962 } 963 964 return true; 965 }; 966 // Find maximal number of parallel region CIs that are safe to merge. 967 for (auto It = BB->begin(), End = BB->end(); It != End;) { 968 Instruction &I = *It; 969 ++It; 970 971 if (CIs.count(&I)) { 972 MergableCIs.push_back(cast<CallInst>(&I)); 973 continue; 974 } 975 976 // Continue expanding if the instruction is mergable. 977 if (IsMergable(I, MergableCIs.empty())) 978 continue; 979 980 // Forward the instruction iterator to skip the next parallel region 981 // since there is an unmergable instruction which can affect it. 982 for (; It != End; ++It) { 983 Instruction &SkipI = *It; 984 if (CIs.count(&SkipI)) { 985 LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI 986 << " due to " << I << "\n"); 987 ++It; 988 break; 989 } 990 } 991 992 // Store mergable regions found. 993 if (MergableCIs.size() > 1) { 994 MergableCIsVector.push_back(MergableCIs); 995 LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size() 996 << " parallel regions in block " << BB->getName() 997 << " of function " << BB->getParent()->getName() 998 << "\n";); 999 } 1000 1001 MergableCIs.clear(); 1002 } 1003 1004 if (!MergableCIsVector.empty()) { 1005 Changed = true; 1006 1007 for (auto &MergableCIs : MergableCIsVector) 1008 Merge(MergableCIs, BB); 1009 MergableCIsVector.clear(); 1010 } 1011 } 1012 1013 if (Changed) { 1014 /// Re-collect use for fork calls, emitted barrier calls, and 1015 /// any emitted master/end_master calls. 1016 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call); 1017 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier); 1018 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master); 1019 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master); 1020 } 1021 1022 return Changed; 1023 } 1024 1025 /// Try to delete parallel regions if possible. 
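  /// A __kmpc_fork_call is removable if its outlined callback only reads
  /// memory and is guaranteed to return (willreturn), since the fork then has
  /// no observable effect.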
1026 bool deleteParallelRegions() {
1027 const unsigned CallbackCalleeOperand = 2;
1028
1029 OMPInformationCache::RuntimeFunctionInfo &RFI =
1030 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1031
1032 if (!RFI.Declaration)
1033 return false;
1034
1035 bool Changed = false;
1036 auto DeleteCallCB = [&](Use &U, Function &) {
1037 CallInst *CI = getCallIfRegularCall(U);
1038 if (!CI)
1039 return false;
1040 auto *Fn = dyn_cast<Function>(
1041 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1042 if (!Fn)
1043 return false;
1044 if (!Fn->onlyReadsMemory())
1045 return false;
1046 if (!Fn->hasFnAttribute(Attribute::WillReturn))
1047 return false;
1048
1049 LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1050 << CI->getCaller()->getName() << "\n");
1051
1052 auto Remark = [&](OptimizationRemark OR) {
1053 return OR << "Parallel region in "
1054 << ore::NV("OpenMPParallelDelete", CI->getCaller()->getName())
1055 << " deleted";
1056 };
1057 emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionDeletion",
1058 Remark);
1059
1060 CGUpdater.removeCallSite(*CI);
1061 CI->eraseFromParent();
1062 Changed = true;
1063 ++NumOpenMPParallelRegionsDeleted;
1064 return true;
1065 };
1066
1067 RFI.foreachUse(SCC, DeleteCallCB);
1068
1069 return Changed;
1070 }
1071
1072 /// Try to eliminate runtime calls by reusing existing ones.
1073 bool deduplicateRuntimeCalls() {
1074 bool Changed = false;
1075
1076 RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1077 OMPRTL_omp_get_num_threads,
1078 OMPRTL_omp_in_parallel,
1079 OMPRTL_omp_get_cancellation,
1080 OMPRTL_omp_get_thread_limit,
1081 OMPRTL_omp_get_supported_active_levels,
1082 OMPRTL_omp_get_level,
1083 OMPRTL_omp_get_ancestor_thread_num,
1084 OMPRTL_omp_get_team_size,
1085 OMPRTL_omp_get_active_level,
1086 OMPRTL_omp_in_final,
1087 OMPRTL_omp_get_proc_bind,
1088 OMPRTL_omp_get_num_places,
1089 OMPRTL_omp_get_num_procs,
1090 OMPRTL_omp_get_place_num,
1091 OMPRTL_omp_get_partition_num_places,
1092 OMPRTL_omp_get_partition_place_nums};
1093
1094 // Global-tid is handled separately.
1095 SmallSetVector<Value *, 16> GTIdArgs;
1096 collectGlobalThreadIdArguments(GTIdArgs);
1097 LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1098 << " global thread ID arguments\n");
1099
1100 for (Function *F : SCC) {
1101 for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1102 Changed |= deduplicateRuntimeCalls(
1103 *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1104
1105 // __kmpc_global_thread_num is special as we can replace it with an
1106 // argument in enough cases to make it worth trying.
1107 Value *GTIdArg = nullptr;
1108 for (Argument &Arg : F->args())
1109 if (GTIdArgs.count(&Arg)) {
1110 GTIdArg = &Arg;
1111 break;
1112 }
1113 Changed |= deduplicateRuntimeCalls(
1114 *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1115 }
1116
1117 return Changed;
1118 }
1119
1120 /// Tries to hide the latency of runtime calls that involve host to
1121 /// device memory transfers by splitting them into their "issue" and "wait"
1122 /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1123 /// moved downwards as much as possible. The "issue" issues the memory transfer
1124 /// asynchronously, returning a handle. The "wait" waits on the returned
1125 /// handle for the memory transfer to finish.
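  /// Rough sketch of the rewrite (value names are illustrative):
  ///   __tgt_target_data_begin_mapper(...)
  /// becomes
  ///   %handle = alloca %struct.__tgt_async_info
  ///   __tgt_target_data_begin_mapper_issue(..., %handle)
  ///   ...                                ; code independent of the transfer
  ///   __tgt_target_data_begin_mapper_wait(%device_id, %handle)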
1126 bool hideMemTransfersLatency() {
1127 auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1128 bool Changed = false;
1129 auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1130 auto *RTCall = getCallIfRegularCall(U, &RFI);
1131 if (!RTCall)
1132 return false;
1133
1134 OffloadArray OffloadArrays[3];
1135 if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1136 return false;
1137
1138 LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1139
1140 // TODO: Check if can be moved upwards.
1141 bool WasSplit = false;
1142 Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1143 if (WaitMovementPoint)
1144 WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1145
1146 Changed |= WasSplit;
1147 return WasSplit;
1148 };
1149 RFI.foreachUse(SCC, SplitMemTransfers);
1150
1151 return Changed;
1152 }
1153
1154 void analysisGlobalization() {
1155 RuntimeFunction GlobalizationRuntimeIDs[] = {
1156 OMPRTL___kmpc_data_sharing_coalesced_push_stack,
1157 OMPRTL___kmpc_data_sharing_push_stack};
1158
1159 for (const auto GlobalizationCallID : GlobalizationRuntimeIDs) {
1160 auto &RFI = OMPInfoCache.RFIs[GlobalizationCallID];
1161
1162 auto CheckGlobalization = [&](Use &U, Function &Decl) {
1163 if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1164 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1165 return ORA
1166 << "Found thread data sharing on the GPU. "
1167 << "Expect degraded performance due to data globalization.";
1168 };
1169 emitRemark<OptimizationRemarkAnalysis>(CI, "OpenMPGlobalization",
1170 Remark);
1171 }
1172
1173 return false;
1174 };
1175
1176 RFI.foreachUse(SCC, CheckGlobalization);
1177 }
1178 }
1179
1180 /// Maps the values stored in the offload arrays passed as arguments to
1181 /// \p RuntimeCall into the offload arrays in \p OAs.
1182 bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1183 MutableArrayRef<OffloadArray> OAs) {
1184 assert(OAs.size() == 3 && "Need space for three offload arrays!");
1185
1186 // A runtime call that involves memory offloading looks something like:
1187 // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1188 // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1189 // ...)
1190 // So, the idea is to access the allocas that allocate space for these
1191 // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1192 // Therefore:
1193 // i8** %offload_baseptrs.
1194 Value *BasePtrsArg =
1195 RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1196 // i8** %offload_ptrs.
1197 Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1198 // i64* %offload_sizes.
1199 Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1200
1201 // Get values stored in **offload_baseptrs.
1202 auto *V = getUnderlyingObject(BasePtrsArg);
1203 if (!isa<AllocaInst>(V))
1204 return false;
1205 auto *BasePtrsArray = cast<AllocaInst>(V);
1206 if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1207 return false;
1208
1209 // Get values stored in **offload_ptrs.
1210 V = getUnderlyingObject(PtrsArg);
1211 if (!isa<AllocaInst>(V))
1212 return false;
1213 auto *PtrsArray = cast<AllocaInst>(V);
1214 if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1215 return false;
1216
1217 // Get values stored in **offload_sizes.
1218 V = getUnderlyingObject(SizesArg);
1219 // If it's a [constant] global array don't analyze it.
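    // A constant global sizes array is acceptable without further analysis;
    // any other non-alloca base means the stored values cannot be tracked.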
1220 if (isa<GlobalValue>(V)) 1221 return isa<Constant>(V); 1222 if (!isa<AllocaInst>(V)) 1223 return false; 1224 1225 auto *SizesArray = cast<AllocaInst>(V); 1226 if (!OAs[2].initialize(*SizesArray, RuntimeCall)) 1227 return false; 1228 1229 return true; 1230 } 1231 1232 /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG. 1233 /// For now this is a way to test that the function getValuesInOffloadArrays 1234 /// is working properly. 1235 /// TODO: Move this to a unittest when unittests are available for OpenMPOpt. 1236 void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) { 1237 assert(OAs.size() == 3 && "There are three offload arrays to debug!"); 1238 1239 LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n"); 1240 std::string ValuesStr; 1241 raw_string_ostream Printer(ValuesStr); 1242 std::string Separator = " --- "; 1243 1244 for (auto *BP : OAs[0].StoredValues) { 1245 BP->print(Printer); 1246 Printer << Separator; 1247 } 1248 LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n"); 1249 ValuesStr.clear(); 1250 1251 for (auto *P : OAs[1].StoredValues) { 1252 P->print(Printer); 1253 Printer << Separator; 1254 } 1255 LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n"); 1256 ValuesStr.clear(); 1257 1258 for (auto *S : OAs[2].StoredValues) { 1259 S->print(Printer); 1260 Printer << Separator; 1261 } 1262 LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n"); 1263 } 1264 1265 /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be 1266 /// moved. Returns nullptr if the movement is not possible, or not worth it. 1267 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) { 1268 // FIXME: This traverses only the BasicBlock where RuntimeCall is. 1269 // Make it traverse the CFG. 1270 1271 Instruction *CurrentI = &RuntimeCall; 1272 bool IsWorthIt = false; 1273 while ((CurrentI = CurrentI->getNextNode())) { 1274 1275 // TODO: Once we detect the regions to be offloaded we should use the 1276 // alias analysis manager to check if CurrentI may modify one of 1277 // the offloaded regions. 1278 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) { 1279 if (IsWorthIt) 1280 return CurrentI; 1281 1282 return nullptr; 1283 } 1284 1285 // FIXME: For now if we move it over anything without side effect 1286 // is worth it. 1287 IsWorthIt = true; 1288 } 1289 1290 // Return end of BasicBlock. 1291 return RuntimeCall.getParent()->getTerminator(); 1292 } 1293 1294 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts. 1295 bool splitTargetDataBeginRTC(CallInst &RuntimeCall, 1296 Instruction &WaitMovementPoint) { 1297 // Create stack allocated handle (__tgt_async_info) at the beginning of the 1298 // function. Used for storing information of the async transfer, allowing to 1299 // wait on it later. 1300 auto &IRBuilder = OMPInfoCache.OMPBuilder; 1301 auto *F = RuntimeCall.getCaller(); 1302 Instruction *FirstInst = &(F->getEntryBlock().front()); 1303 AllocaInst *Handle = new AllocaInst( 1304 IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst); 1305 1306 // Add "issue" runtime call declaration: 1307 // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32, 1308 // i8**, i8**, i64*, i64*) 1309 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction( 1310 M, OMPRTL___tgt_target_data_begin_mapper_issue); 1311 1312 // Change RuntimeCall call site for its asynchronous version. 
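    // The "issue" variant takes the original arguments plus the async handle;
    // the original synchronous call is erased once it has been replaced.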
1313 SmallVector<Value *, 16> Args; 1314 for (auto &Arg : RuntimeCall.args()) 1315 Args.push_back(Arg.get()); 1316 Args.push_back(Handle); 1317 1318 CallInst *IssueCallsite = 1319 CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall); 1320 RuntimeCall.eraseFromParent(); 1321 1322 // Add "wait" runtime call declaration: 1323 // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info) 1324 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction( 1325 M, OMPRTL___tgt_target_data_begin_mapper_wait); 1326 1327 Value *WaitParams[2] = { 1328 IssueCallsite->getArgOperand( 1329 OffloadArray::DeviceIDArgNum), // device_id. 1330 Handle // handle to wait on. 1331 }; 1332 CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint); 1333 1334 return true; 1335 } 1336 1337 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent, 1338 bool GlobalOnly, bool &SingleChoice) { 1339 if (CurrentIdent == NextIdent) 1340 return CurrentIdent; 1341 1342 // TODO: Figure out how to actually combine multiple debug locations. For 1343 // now we just keep an existing one if there is a single choice. 1344 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) { 1345 SingleChoice = !CurrentIdent; 1346 return NextIdent; 1347 } 1348 return nullptr; 1349 } 1350 1351 /// Return an `struct ident_t*` value that represents the ones used in the 1352 /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not 1353 /// return a local `struct ident_t*`. For now, if we cannot find a suitable 1354 /// return value we create one from scratch. We also do not yet combine 1355 /// information, e.g., the source locations, see combinedIdentStruct. 1356 Value * 1357 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI, 1358 Function &F, bool GlobalOnly) { 1359 bool SingleChoice = true; 1360 Value *Ident = nullptr; 1361 auto CombineIdentStruct = [&](Use &U, Function &Caller) { 1362 CallInst *CI = getCallIfRegularCall(U, &RFI); 1363 if (!CI || &F != &Caller) 1364 return false; 1365 Ident = combinedIdentStruct(Ident, CI->getArgOperand(0), 1366 /* GlobalOnly */ true, SingleChoice); 1367 return false; 1368 }; 1369 RFI.foreachUse(SCC, CombineIdentStruct); 1370 1371 if (!Ident || !SingleChoice) { 1372 // The IRBuilder uses the insertion block to get to the module, this is 1373 // unfortunate but we work around it for now. 1374 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock()) 1375 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy( 1376 &F.getEntryBlock(), F.getEntryBlock().begin())); 1377 // Create a fallback location if non was found. 1378 // TODO: Use the debug locations of the calls instead. 1379 Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(); 1380 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc); 1381 } 1382 return Ident; 1383 } 1384 1385 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or 1386 /// \p ReplVal if given. 1387 bool deduplicateRuntimeCalls(Function &F, 1388 OMPInformationCache::RuntimeFunctionInfo &RFI, 1389 Value *ReplVal = nullptr) { 1390 auto *UV = RFI.getUseVector(F); 1391 if (!UV || UV->size() + (ReplVal != nullptr) < 2) 1392 return false; 1393 1394 LLVM_DEBUG( 1395 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name 1396 << (ReplVal ? 
" with an existing value\n" : "\n") << "\n"); 1397 1398 assert((!ReplVal || (isa<Argument>(ReplVal) && 1399 cast<Argument>(ReplVal)->getParent() == &F)) && 1400 "Unexpected replacement value!"); 1401 1402 // TODO: Use dominance to find a good position instead. 1403 auto CanBeMoved = [this](CallBase &CB) { 1404 unsigned NumArgs = CB.getNumArgOperands(); 1405 if (NumArgs == 0) 1406 return true; 1407 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr) 1408 return false; 1409 for (unsigned u = 1; u < NumArgs; ++u) 1410 if (isa<Instruction>(CB.getArgOperand(u))) 1411 return false; 1412 return true; 1413 }; 1414 1415 if (!ReplVal) { 1416 for (Use *U : *UV) 1417 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) { 1418 if (!CanBeMoved(*CI)) 1419 continue; 1420 1421 auto Remark = [&](OptimizationRemark OR) { 1422 auto newLoc = &*F.getEntryBlock().getFirstInsertionPt(); 1423 return OR << "OpenMP runtime call " 1424 << ore::NV("OpenMPOptRuntime", RFI.Name) << " moved to " 1425 << ore::NV("OpenMPRuntimeMoves", newLoc->getDebugLoc()); 1426 }; 1427 emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeCodeMotion", Remark); 1428 1429 CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt()); 1430 ReplVal = CI; 1431 break; 1432 } 1433 if (!ReplVal) 1434 return false; 1435 } 1436 1437 // If we use a call as a replacement value we need to make sure the ident is 1438 // valid at the new location. For now we just pick a global one, either 1439 // existing and used by one of the calls, or created from scratch. 1440 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) { 1441 if (CI->getNumArgOperands() > 0 && 1442 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) { 1443 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F, 1444 /* GlobalOnly */ true); 1445 CI->setArgOperand(0, Ident); 1446 } 1447 } 1448 1449 bool Changed = false; 1450 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) { 1451 CallInst *CI = getCallIfRegularCall(U, &RFI); 1452 if (!CI || CI == ReplVal || &F != &Caller) 1453 return false; 1454 assert(CI->getCaller() == &F && "Unexpected call!"); 1455 1456 auto Remark = [&](OptimizationRemark OR) { 1457 return OR << "OpenMP runtime call " 1458 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated"; 1459 }; 1460 emitRemark<OptimizationRemark>(CI, "OpenMPRuntimeDeduplicated", Remark); 1461 1462 CGUpdater.removeCallSite(*CI); 1463 CI->replaceAllUsesWith(ReplVal); 1464 CI->eraseFromParent(); 1465 ++NumOpenMPRuntimeCallsDeduplicated; 1466 Changed = true; 1467 return true; 1468 }; 1469 RFI.foreachUse(SCC, ReplaceAndDeleteCB); 1470 1471 return Changed; 1472 } 1473 1474 /// Collect arguments that represent the global thread id in \p GTIdArgs. 1475 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> >IdArgs) { 1476 // TODO: Below we basically perform a fixpoint iteration with a pessimistic 1477 // initialization. We could define an AbstractAttribute instead and 1478 // run the Attributor here once it can be run as an SCC pass. 1479 1480 // Helper to check the argument \p ArgNo at all call sites of \p F for 1481 // a GTId. 
1482 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) { 1483 if (!F.hasLocalLinkage()) 1484 return false; 1485 for (Use &U : F.uses()) { 1486 if (CallInst *CI = getCallIfRegularCall(U)) { 1487 Value *ArgOp = CI->getArgOperand(ArgNo); 1488 if (CI == &RefCI || GTIdArgs.count(ArgOp) || 1489 getCallIfRegularCall( 1490 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num])) 1491 continue; 1492 } 1493 return false; 1494 } 1495 return true; 1496 }; 1497 1498 // Helper to identify uses of a GTId as GTId arguments. 1499 auto AddUserArgs = [&](Value >Id) { 1500 for (Use &U : GTId.uses()) 1501 if (CallInst *CI = dyn_cast<CallInst>(U.getUser())) 1502 if (CI->isArgOperand(&U)) 1503 if (Function *Callee = CI->getCalledFunction()) 1504 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI)) 1505 GTIdArgs.insert(Callee->getArg(U.getOperandNo())); 1506 }; 1507 1508 // The argument users of __kmpc_global_thread_num calls are GTIds. 1509 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI = 1510 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]; 1511 1512 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) { 1513 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI)) 1514 AddUserArgs(*CI); 1515 return false; 1516 }); 1517 1518 // Transitively search for more arguments by looking at the users of the 1519 // ones we know already. During the search the GTIdArgs vector is extended 1520 // so we cannot cache the size nor can we use a range based for. 1521 for (unsigned u = 0; u < GTIdArgs.size(); ++u) 1522 AddUserArgs(*GTIdArgs[u]); 1523 } 1524 1525 /// Kernel (=GPU) optimizations and utility functions 1526 /// 1527 ///{{ 1528 1529 /// Check if \p F is a kernel, hence entry point for target offloading. 1530 bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); } 1531 1532 /// Cache to remember the unique kernel for a function. 1533 DenseMap<Function *, Optional<Kernel>> UniqueKernelMap; 1534 1535 /// Find the unique kernel that will execute \p F, if any. 1536 Kernel getUniqueKernelFor(Function &F); 1537 1538 /// Find the unique kernel that will execute \p I, if any. 1539 Kernel getUniqueKernelFor(Instruction &I) { 1540 return getUniqueKernelFor(*I.getFunction()); 1541 } 1542 1543 /// Rewrite the device (=GPU) code state machine create in non-SPMD mode in 1544 /// the cases we can avoid taking the address of a function. 1545 bool rewriteDeviceCodeStateMachine(); 1546 1547 /// 1548 ///}} 1549 1550 /// Emit a remark generically 1551 /// 1552 /// This template function can be used to generically emit a remark. The 1553 /// RemarkKind should be one of the following: 1554 /// - OptimizationRemark to indicate a successful optimization attempt 1555 /// - OptimizationRemarkMissed to report a failed optimization attempt 1556 /// - OptimizationRemarkAnalysis to provide additional information about an 1557 /// optimization attempt 1558 /// 1559 /// The remark is built using a callback function provided by the caller that 1560 /// takes a RemarkKind as input and returns a RemarkKind. 1561 template <typename RemarkKind, 1562 typename RemarkCallBack = function_ref<RemarkKind(RemarkKind &&)>> 1563 void emitRemark(Instruction *Inst, StringRef RemarkName, 1564 RemarkCallBack &&RemarkCB) const { 1565 Function *F = Inst->getParent()->getParent(); 1566 auto &ORE = OREGetter(F); 1567 1568 ORE.emit( 1569 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, Inst)); }); 1570 } 1571 1572 /// Emit a remark on a function. 
Since only OptimizationRemark is supporting 1573 /// this, it can't be made generic. 1574 void 1575 emitRemarkOnFunction(Function *F, StringRef RemarkName, 1576 function_ref<OptimizationRemark(OptimizationRemark &&)> 1577 &&RemarkCB) const { 1578 auto &ORE = OREGetter(F); 1579 1580 ORE.emit([&]() { 1581 return RemarkCB(OptimizationRemark(DEBUG_TYPE, RemarkName, F)); 1582 }); 1583 } 1584 1585 /// The underlying module. 1586 Module &M; 1587 1588 /// The SCC we are operating on. 1589 SmallVectorImpl<Function *> &SCC; 1590 1591 /// Callback to update the call graph, the first argument is a removed call, 1592 /// the second an optional replacement call. 1593 CallGraphUpdater &CGUpdater; 1594 1595 /// Callback to get an OptimizationRemarkEmitter from a Function * 1596 OptimizationRemarkGetter OREGetter; 1597 1598 /// OpenMP-specific information cache. Also Used for Attributor runs. 1599 OMPInformationCache &OMPInfoCache; 1600 1601 /// Attributor instance. 1602 Attributor &A; 1603 1604 /// Helper function to run Attributor on SCC. 1605 bool runAttributor() { 1606 if (SCC.empty()) 1607 return false; 1608 1609 registerAAs(); 1610 1611 ChangeStatus Changed = A.run(); 1612 1613 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size() 1614 << " functions, result: " << Changed << ".\n"); 1615 1616 return Changed == ChangeStatus::CHANGED; 1617 } 1618 1619 /// Populate the Attributor with abstract attribute opportunities in the 1620 /// function. 1621 void registerAAs() { 1622 if (SCC.empty()) 1623 return; 1624 1625 // Create CallSite AA for all Getters. 1626 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) { 1627 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)]; 1628 1629 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter]; 1630 1631 auto CreateAA = [&](Use &U, Function &Caller) { 1632 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI); 1633 if (!CI) 1634 return false; 1635 1636 auto &CB = cast<CallBase>(*CI); 1637 1638 IRPosition CBPos = IRPosition::callsite_function(CB); 1639 A.getOrCreateAAFor<AAICVTracker>(CBPos); 1640 return false; 1641 }; 1642 1643 GetterRFI.foreachUse(SCC, CreateAA); 1644 } 1645 1646 for (auto &F : M) { 1647 if (!F.isDeclaration()) 1648 A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F)); 1649 } 1650 } 1651 }; 1652 1653 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) { 1654 if (!OMPInfoCache.ModuleSlice.count(&F)) 1655 return nullptr; 1656 1657 // Use a scope to keep the lifetime of the CachedKernel short. 1658 { 1659 Optional<Kernel> &CachedKernel = UniqueKernelMap[&F]; 1660 if (CachedKernel) 1661 return *CachedKernel; 1662 1663 // TODO: We should use an AA to create an (optimistic and callback 1664 // call-aware) call graph. For now we stick to simple patterns that 1665 // are less powerful, basically the worst fixpoint. 1666 if (isKernel(F)) { 1667 CachedKernel = Kernel(&F); 1668 return *CachedKernel; 1669 } 1670 1671 CachedKernel = nullptr; 1672 if (!F.hasLocalLinkage()) { 1673 1674 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html 1675 auto Remark = [&](OptimizationRemark OR) { 1676 return OR << "[OMP100] Potentially unknown OpenMP target region caller"; 1677 }; 1678 emitRemarkOnFunction(&F, "OMP100", Remark); 1679 1680 return nullptr; 1681 } 1682 } 1683 1684 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel { 1685 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) { 1686 // Allow use in equality comparisons. 
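    // (Such compares are emitted, e.g., by the non-SPMD GPU state machine; we
    //  attribute the use to the kernel of the function containing the compare.)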
1687 if (Cmp->isEquality()) 1688 return getUniqueKernelFor(*Cmp); 1689 return nullptr; 1690 } 1691 if (auto *CB = dyn_cast<CallBase>(U.getUser())) { 1692 // Allow direct calls. 1693 if (CB->isCallee(&U)) 1694 return getUniqueKernelFor(*CB); 1695 1696 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI = 1697 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51]; 1698 // Allow the use in __kmpc_parallel_51 calls. 1699 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI)) 1700 return getUniqueKernelFor(*CB); 1701 return nullptr; 1702 } 1703 // Disallow every other use. 1704 return nullptr; 1705 }; 1706 1707 // TODO: In the future we want to track more than just a unique kernel. 1708 SmallPtrSet<Kernel, 2> PotentialKernels; 1709 OMPInformationCache::foreachUse(F, [&](const Use &U) { 1710 PotentialKernels.insert(GetUniqueKernelForUse(U)); 1711 }); 1712 1713 Kernel K = nullptr; 1714 if (PotentialKernels.size() == 1) 1715 K = *PotentialKernels.begin(); 1716 1717 // Cache the result. 1718 UniqueKernelMap[&F] = K; 1719 1720 return K; 1721 } 1722 1723 bool OpenMPOpt::rewriteDeviceCodeStateMachine() { 1724 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI = 1725 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51]; 1726 1727 bool Changed = false; 1728 if (!KernelParallelRFI) 1729 return Changed; 1730 1731 for (Function *F : SCC) { 1732 1733 // Check if the function is used in a __kmpc_parallel_51 call at 1734 // all. 1735 bool UnknownUse = false; 1736 bool KernelParallelUse = false; 1737 unsigned NumDirectCalls = 0; 1738 1739 SmallVector<Use *, 2> ToBeReplacedStateMachineUses; 1740 OMPInformationCache::foreachUse(*F, [&](Use &U) { 1741 if (auto *CB = dyn_cast<CallBase>(U.getUser())) 1742 if (CB->isCallee(&U)) { 1743 ++NumDirectCalls; 1744 return; 1745 } 1746 1747 if (isa<ICmpInst>(U.getUser())) { 1748 ToBeReplacedStateMachineUses.push_back(&U); 1749 return; 1750 } 1751 1752 // Find wrapper functions that represent parallel kernels. 1753 CallInst *CI = 1754 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI); 1755 const unsigned int WrapperFunctionArgNo = 6; 1756 if (!KernelParallelUse && CI && 1757 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) { 1758 KernelParallelUse = true; 1759 ToBeReplacedStateMachineUses.push_back(&U); 1760 return; 1761 } 1762 UnknownUse = true; 1763 }); 1764 1765 // Do not emit a remark if we haven't seen a __kmpc_parallel_51 1766 // use. 1767 if (!KernelParallelUse) 1768 continue; 1769 1770 { 1771 auto Remark = [&](OptimizationRemark OR) { 1772 return OR << "Found a parallel region that is called in a target " 1773 "region but not part of a combined target construct nor " 1774 "nested inside a target construct without intermediate " 1775 "code. This can lead to excessive register usage for " 1776 "unrelated target regions in the same translation unit " 1777 "due to spurious call edges assumed by ptxas."; 1778 }; 1779 emitRemarkOnFunction(F, "OpenMPParallelRegionInNonSPMD", Remark); 1780 } 1781 1782 // If this ever hits, we should investigate. 1783 // TODO: Checking the number of uses is not a necessary restriction and 1784 // should be lifted. 1785 if (UnknownUse || NumDirectCalls != 1 || 1786 ToBeReplacedStateMachineUses.size() != 2) { 1787 { 1788 auto Remark = [&](OptimizationRemark OR) { 1789 return OR << "Parallel region is used in " 1790 << (UnknownUse ?
"unknown" : "unexpected") 1791 << " ways; will not attempt to rewrite the state machine."; 1792 }; 1793 emitRemarkOnFunction(F, "OpenMPParallelRegionInNonSPMD", Remark); 1794 } 1795 continue; 1796 } 1797 1798 // Even if we have __kmpc_parallel_51 calls, we (for now) give 1799 // up if the function is not called from a unique kernel. 1800 Kernel K = getUniqueKernelFor(*F); 1801 if (!K) { 1802 { 1803 auto Remark = [&](OptimizationRemark OR) { 1804 return OR << "Parallel region is not known to be called from a " 1805 "unique single target region, maybe the surrounding " 1806 "function has external linkage?; will not attempt to " 1807 "rewrite the state machine use."; 1808 }; 1809 emitRemarkOnFunction(F, "OpenMPParallelRegionInMultipleKernels", 1810 Remark); 1811 } 1812 continue; 1813 } 1814 1815 // We now know F is a parallel body function called only from the kernel K. 1816 // We also identified the state machine uses in which we replace the 1817 // function pointer by a new global symbol for identification purposes. This 1818 // ensures only direct calls to the function are left. 1819 1820 { 1821 auto RemarkParallelRegion = [&](OptimizationRemark OR) { 1822 return OR << "Specialize parallel region that is only reached from a " 1823 "single target region to avoid spurious call edges and " 1824 "excessive register usage in other target regions. " 1825 "(parallel region ID: " 1826 << ore::NV("OpenMPParallelRegion", F->getName()) 1827 << ", kernel ID: " 1828 << ore::NV("OpenMPTargetRegion", K->getName()) << ")"; 1829 }; 1830 emitRemarkOnFunction(F, "OpenMPParallelRegionInNonSPMD", 1831 RemarkParallelRegion); 1832 auto RemarkKernel = [&](OptimizationRemark OR) { 1833 return OR << "Target region containing the parallel region that is " 1834 "specialized. (parallel region ID: " 1835 << ore::NV("OpenMPParallelRegion", F->getName()) 1836 << ", kernel ID: " 1837 << ore::NV("OpenMPTargetRegion", K->getName()) << ")"; 1838 }; 1839 emitRemarkOnFunction(K, "OpenMPParallelRegionInNonSPMD", RemarkKernel); 1840 } 1841 1842 Module &M = *F->getParent(); 1843 Type *Int8Ty = Type::getInt8Ty(M.getContext()); 1844 1845 auto *ID = new GlobalVariable( 1846 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage, 1847 UndefValue::get(Int8Ty), F->getName() + ".ID"); 1848 1849 for (Use *U : ToBeReplacedStateMachineUses) 1850 U->set(ConstantExpr::getBitCast(ID, U->get()->getType())); 1851 1852 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine; 1853 1854 Changed = true; 1855 } 1856 1857 return Changed; 1858 } 1859 1860 /// Abstract Attribute for tracking ICV values. 1861 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> { 1862 using Base = StateWrapper<BooleanState, AbstractAttribute>; 1863 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {} 1864 1865 void initialize(Attributor &A) override { 1866 Function *F = getAnchorScope(); 1867 if (!F || !A.isFunctionIPOAmendable(*F)) 1868 indicatePessimisticFixpoint(); 1869 } 1870 1871 /// Returns true if value is assumed to be tracked. 1872 bool isAssumedTracked() const { return getAssumed(); } 1873 1874 /// Returns true if value is known to be tracked. 1875 bool isKnownTracked() const { return getAssumed(); } 1876 1877 /// Create an abstract attribute view for the position \p IRP. 1878 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A); 1879 1880 /// Return the value with which \p I can be replaced for specific \p ICV.
1881 virtual Optional<Value *> getReplacementValue(InternalControlVar ICV, 1882 const Instruction *I, 1883 Attributor &A) const { 1884 return None; 1885 } 1886 1887 /// Return an assumed unique ICV value if a single candidate is found. If 1888 /// there cannot be one, return nullptr. If it is not clear yet, return 1889 /// None. 1890 virtual Optional<Value *> 1891 getUniqueReplacementValue(InternalControlVar ICV) const = 0; 1892 1893 // Currently only nthreads is being tracked. 1894 // This array will only grow over time. 1895 InternalControlVar TrackableICVs[1] = {ICV_nthreads}; 1896 1897 /// See AbstractAttribute::getName() 1898 const std::string getName() const override { return "AAICVTracker"; } 1899 1900 /// See AbstractAttribute::getIdAddr() 1901 const char *getIdAddr() const override { return &ID; } 1902 1903 /// This function should return true if the type of the \p AA is AAICVTracker. 1904 static bool classof(const AbstractAttribute *AA) { 1905 return (AA->getIdAddr() == &ID); 1906 } 1907 1908 static const char ID; 1909 }; 1910 1911 struct AAICVTrackerFunction : public AAICVTracker { 1912 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A) 1913 : AAICVTracker(IRP, A) {} 1914 1915 // FIXME: come up with better string. 1916 const std::string getAsStr() const override { return "ICVTrackerFunction"; } 1917 1918 // FIXME: come up with some stats. 1919 void trackStatistics() const override {} 1920 1921 /// We don't manifest anything for this AA. 1922 ChangeStatus manifest(Attributor &A) override { 1923 return ChangeStatus::UNCHANGED; 1924 } 1925 1926 // Map of ICVs to their values at specific program points. 1927 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar, 1928 InternalControlVar::ICV___last> 1929 ICVReplacementValuesMap; 1930 1931 ChangeStatus updateImpl(Attributor &A) override { 1932 ChangeStatus HasChanged = ChangeStatus::UNCHANGED; 1933 1934 Function *F = getAnchorScope(); 1935 1936 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 1937 1938 for (InternalControlVar ICV : TrackableICVs) { 1939 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter]; 1940 1941 auto &ValuesMap = ICVReplacementValuesMap[ICV]; 1942 auto TrackValues = [&](Use &U, Function &) { 1943 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U); 1944 if (!CI) 1945 return false; 1946 1947 // FIXME: handle setters with more than one argument. 1948 // Track the new value. 1949 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second) 1950 HasChanged = ChangeStatus::CHANGED; 1951 1952 return false; 1953 }; 1954 1955 auto CallCheck = [&](Instruction &I) { 1956 Optional<Value *> ReplVal = getValueForCall(A, &I, ICV); 1957 if (ReplVal.hasValue() && 1958 ValuesMap.insert(std::make_pair(&I, *ReplVal)).second) 1959 HasChanged = ChangeStatus::CHANGED; 1960 1961 return true; 1962 }; 1963 1964 // Track all changes of an ICV. 1965 SetterRFI.foreachUse(TrackValues, F); 1966 1967 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call}, 1968 /* CheckBBLivenessOnly */ true); 1969 1970 // TODO: Figure out a way to avoid adding an entry in 1971 // ICVReplacementValuesMap. 1972 Instruction *Entry = &F->getEntryBlock().front(); 1973 if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry)) 1974 ValuesMap.insert(std::make_pair(Entry, nullptr)); 1975 } 1976 1977 return HasChanged; 1978 } 1979 1980 /// Helper to check if \p I is a call and get the value for it if it is 1981 /// unique.
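  /// A rough summary of the cases handled below, using the nthreads ICV as
  /// the running example: a call to the setter (omp_set_num_threads) yields
  /// the tracked argument value, a call to the getter (omp_get_max_threads)
  /// yields None because it does not change the ICV, and an indirect or
  /// otherwise unknown call yields nullptr, i.e., the ICV is conservatively
  /// assumed to change.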
1982 Optional<Value *> getValueForCall(Attributor &A, const Instruction *I, 1983 InternalControlVar &ICV) const { 1984 1985 const auto *CB = dyn_cast<CallBase>(I); 1986 if (!CB || CB->hasFnAttr("no_openmp") || 1987 CB->hasFnAttr("no_openmp_routines")) 1988 return None; 1989 1990 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 1991 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter]; 1992 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter]; 1993 Function *CalledFunction = CB->getCalledFunction(); 1994 1995 // Indirect call, assume ICV changes. 1996 if (CalledFunction == nullptr) 1997 return nullptr; 1998 if (CalledFunction == GetterRFI.Declaration) 1999 return None; 2000 if (CalledFunction == SetterRFI.Declaration) { 2001 if (ICVReplacementValuesMap[ICV].count(I)) 2002 return ICVReplacementValuesMap[ICV].lookup(I); 2003 2004 return nullptr; 2005 } 2006 2007 // Since we don't know, assume it changes the ICV. 2008 if (CalledFunction->isDeclaration()) 2009 return nullptr; 2010 2011 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2012 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED); 2013 2014 if (ICVTrackingAA.isAssumedTracked()) 2015 return ICVTrackingAA.getUniqueReplacementValue(ICV); 2016 2017 // If we don't know, assume it changes. 2018 return nullptr; 2019 } 2020 2021 // We do not track a unique value for the whole function here, so return None. 2022 Optional<Value *> 2023 getUniqueReplacementValue(InternalControlVar ICV) const override { 2024 return None; 2025 } 2026 2027 /// Return the value with which \p I can be replaced for specific \p ICV. 2028 Optional<Value *> getReplacementValue(InternalControlVar ICV, 2029 const Instruction *I, 2030 Attributor &A) const override { 2031 const auto &ValuesMap = ICVReplacementValuesMap[ICV]; 2032 if (ValuesMap.count(I)) 2033 return ValuesMap.lookup(I); 2034 2035 SmallVector<const Instruction *, 16> Worklist; 2036 SmallPtrSet<const Instruction *, 16> Visited; 2037 Worklist.push_back(I); 2038 2039 Optional<Value *> ReplVal; 2040 2041 while (!Worklist.empty()) { 2042 const Instruction *CurrInst = Worklist.pop_back_val(); 2043 if (!Visited.insert(CurrInst).second) 2044 continue; 2045 2046 const BasicBlock *CurrBB = CurrInst->getParent(); 2047 2048 // Go up and look for all potential setters/calls that might change the 2049 // ICV. 2050 while ((CurrInst = CurrInst->getPrevNode())) { 2051 if (ValuesMap.count(CurrInst)) { 2052 Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst); 2053 // Unknown value, track new. 2054 if (!ReplVal.hasValue()) { 2055 ReplVal = NewReplVal; 2056 break; 2057 } 2058 2059 // If we found a new value, we can't know the ICV value anymore. 2060 if (NewReplVal.hasValue()) 2061 if (ReplVal != NewReplVal) 2062 return nullptr; 2063 2064 break; 2065 } 2066 2067 Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV); 2068 if (!NewReplVal.hasValue()) 2069 continue; 2070 2071 // Unknown value, track new. 2072 if (!ReplVal.hasValue()) { 2073 ReplVal = NewReplVal; 2074 break; 2075 } 2076 2077 2078 // We found a new value, so we can't know the ICV value anymore. 2079 if (ReplVal != NewReplVal) 2080 return nullptr; 2081 } 2082 2083 // If we are in the same BB and we have a value, we are done. 2084 if (CurrBB == I->getParent() && ReplVal.hasValue()) 2085 return ReplVal; 2086 2087 // Go through all predecessors and add terminators for analysis.
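      // Queuing each predecessor's terminator restarts the backward scan in
      // that block, so the reaching ICV value is effectively merged over all
      // paths into this block; conflicting values are detected during the
      // scan and reported as nullptr (unknown).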
2088 for (const BasicBlock *Pred : predecessors(CurrBB)) 2089 if (const Instruction *Terminator = Pred->getTerminator()) 2090 Worklist.push_back(Terminator); 2091 } 2092 2093 return ReplVal; 2094 } 2095 }; 2096 2097 struct AAICVTrackerFunctionReturned : AAICVTracker { 2098 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A) 2099 : AAICVTracker(IRP, A) {} 2100 2101 // FIXME: come up with better string. 2102 const std::string getAsStr() const override { 2103 return "ICVTrackerFunctionReturned"; 2104 } 2105 2106 // FIXME: come up with some stats. 2107 void trackStatistics() const override {} 2108 2109 /// We don't manifest anything for this AA. 2110 ChangeStatus manifest(Attributor &A) override { 2111 return ChangeStatus::UNCHANGED; 2112 } 2113 2114 // Map of ICV to their values at specific program point. 2115 EnumeratedArray<Optional<Value *>, InternalControlVar, 2116 InternalControlVar::ICV___last> 2117 ICVReplacementValuesMap; 2118 2119 /// Return the value with which \p I can be replaced for specific \p ICV. 2120 Optional<Value *> 2121 getUniqueReplacementValue(InternalControlVar ICV) const override { 2122 return ICVReplacementValuesMap[ICV]; 2123 } 2124 2125 ChangeStatus updateImpl(Attributor &A) override { 2126 ChangeStatus Changed = ChangeStatus::UNCHANGED; 2127 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2128 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 2129 2130 if (!ICVTrackingAA.isAssumedTracked()) 2131 return indicatePessimisticFixpoint(); 2132 2133 for (InternalControlVar ICV : TrackableICVs) { 2134 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; 2135 Optional<Value *> UniqueICVValue; 2136 2137 auto CheckReturnInst = [&](Instruction &I) { 2138 Optional<Value *> NewReplVal = 2139 ICVTrackingAA.getReplacementValue(ICV, &I, A); 2140 2141 // If we found a second ICV value there is no unique returned value. 2142 if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal) 2143 return false; 2144 2145 UniqueICVValue = NewReplVal; 2146 2147 return true; 2148 }; 2149 2150 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret}, 2151 /* CheckBBLivenessOnly */ true)) 2152 UniqueICVValue = nullptr; 2153 2154 if (UniqueICVValue == ReplVal) 2155 continue; 2156 2157 ReplVal = UniqueICVValue; 2158 Changed = ChangeStatus::CHANGED; 2159 } 2160 2161 return Changed; 2162 } 2163 }; 2164 2165 struct AAICVTrackerCallSite : AAICVTracker { 2166 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A) 2167 : AAICVTracker(IRP, A) {} 2168 2169 void initialize(Attributor &A) override { 2170 Function *F = getAnchorScope(); 2171 if (!F || !A.isFunctionIPOAmendable(*F)) 2172 indicatePessimisticFixpoint(); 2173 2174 // We only initialize this AA for getters, so we need to know which ICV it 2175 // gets. 2176 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); 2177 for (InternalControlVar ICV : TrackableICVs) { 2178 auto ICVInfo = OMPInfoCache.ICVs[ICV]; 2179 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter]; 2180 if (Getter.Declaration == getAssociatedFunction()) { 2181 AssociatedICV = ICVInfo.Kind; 2182 return; 2183 } 2184 } 2185 2186 /// Unknown ICV. 
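    // The associated callee is not a getter for any ICV we track (currently
    // only the nthreads getter is recognized), so give up on this call site.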
2187 indicatePessimisticFixpoint(); 2188 } 2189 2190 ChangeStatus manifest(Attributor &A) override { 2191 if (!ReplVal.hasValue() || !ReplVal.getValue()) 2192 return ChangeStatus::UNCHANGED; 2193 2194 A.changeValueAfterManifest(*getCtxI(), **ReplVal); 2195 A.deleteAfterManifest(*getCtxI()); 2196 2197 return ChangeStatus::CHANGED; 2198 } 2199 2200 // FIXME: come up with better string. 2201 const std::string getAsStr() const override { return "ICVTrackerCallSite"; } 2202 2203 // FIXME: come up with some stats. 2204 void trackStatistics() const override {} 2205 2206 InternalControlVar AssociatedICV; 2207 Optional<Value *> ReplVal; 2208 2209 ChangeStatus updateImpl(Attributor &A) override { 2210 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2211 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); 2212 2213 // We don't have any information, so we assume it changes the ICV. 2214 if (!ICVTrackingAA.isAssumedTracked()) 2215 return indicatePessimisticFixpoint(); 2216 2217 Optional<Value *> NewReplVal = 2218 ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A); 2219 2220 if (ReplVal == NewReplVal) 2221 return ChangeStatus::UNCHANGED; 2222 2223 ReplVal = NewReplVal; 2224 return ChangeStatus::CHANGED; 2225 } 2226 2227 // Return the value with which associated value can be replaced for specific 2228 // \p ICV. 2229 Optional<Value *> 2230 getUniqueReplacementValue(InternalControlVar ICV) const override { 2231 return ReplVal; 2232 } 2233 }; 2234 2235 struct AAICVTrackerCallSiteReturned : AAICVTracker { 2236 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A) 2237 : AAICVTracker(IRP, A) {} 2238 2239 // FIXME: come up with better string. 2240 const std::string getAsStr() const override { 2241 return "ICVTrackerCallSiteReturned"; 2242 } 2243 2244 // FIXME: come up with some stats. 2245 void trackStatistics() const override {} 2246 2247 /// We don't manifest anything for this AA. 2248 ChangeStatus manifest(Attributor &A) override { 2249 return ChangeStatus::UNCHANGED; 2250 } 2251 2252 // Map of ICV to their values at specific program point. 2253 EnumeratedArray<Optional<Value *>, InternalControlVar, 2254 InternalControlVar::ICV___last> 2255 ICVReplacementValuesMap; 2256 2257 /// Return the value with which associated value can be replaced for specific 2258 /// \p ICV. 2259 Optional<Value *> 2260 getUniqueReplacementValue(InternalControlVar ICV) const override { 2261 return ICVReplacementValuesMap[ICV]; 2262 } 2263 2264 ChangeStatus updateImpl(Attributor &A) override { 2265 ChangeStatus Changed = ChangeStatus::UNCHANGED; 2266 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( 2267 *this, IRPosition::returned(*getAssociatedFunction()), 2268 DepClassTy::REQUIRED); 2269 2270 // We don't have any information, so we assume it changes the ICV. 
2271 if (!ICVTrackingAA.isAssumedTracked()) 2272 return indicatePessimisticFixpoint(); 2273 2274 for (InternalControlVar ICV : TrackableICVs) { 2275 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; 2276 Optional<Value *> NewReplVal = 2277 ICVTrackingAA.getUniqueReplacementValue(ICV); 2278 2279 if (ReplVal == NewReplVal) 2280 continue; 2281 2282 ReplVal = NewReplVal; 2283 Changed = ChangeStatus::CHANGED; 2284 } 2285 return Changed; 2286 } 2287 }; 2288 2289 struct AAExecutionDomainFunction : public AAExecutionDomain { 2290 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A) 2291 : AAExecutionDomain(IRP, A) {} 2292 2293 const std::string getAsStr() const override { 2294 return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) + 2295 "/" + std::to_string(NumBBs) + " BBs thread 0 only."; 2296 } 2297 2298 /// See AbstractAttribute::trackStatistics(). 2299 void trackStatistics() const override {} 2300 2301 void initialize(Attributor &A) override { 2302 Function *F = getAnchorScope(); 2303 for (const auto &BB : *F) 2304 SingleThreadedBBs.insert(&BB); 2305 NumBBs = SingleThreadedBBs.size(); 2306 } 2307 2308 ChangeStatus manifest(Attributor &A) override { 2309 LLVM_DEBUG({ 2310 for (const BasicBlock *BB : SingleThreadedBBs) 2311 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " " 2312 << BB->getName() << " is executed by a single thread.\n"; 2313 }); 2314 return ChangeStatus::UNCHANGED; 2315 } 2316 2317 ChangeStatus updateImpl(Attributor &A) override; 2318 2319 /// Check if an instruction is executed by a single thread. 2320 bool isSingleThreadExecution(const Instruction &I) const override { 2321 return isSingleThreadExecution(*I.getParent()); 2322 } 2323 2324 bool isSingleThreadExecution(const BasicBlock &BB) const override { 2325 return SingleThreadedBBs.contains(&BB); 2326 } 2327 2328 /// Set of basic blocks that are executed by a single thread. 2329 DenseSet<const BasicBlock *> SingleThreadedBBs; 2330 2331 /// Total number of basic blocks in this function. 2332 unsigned long NumBBs; 2333 }; 2334 2335 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { 2336 Function *F = getAnchorScope(); 2337 ReversePostOrderTraversal<Function *> RPOT(F); 2338 auto NumSingleThreadedBBs = SingleThreadedBBs.size(); 2339 2340 bool AllCallSitesKnown; 2341 auto PredForCallSite = [&](AbstractCallSite ACS) { 2342 const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>( 2343 *this, IRPosition::function(*ACS.getInstruction()->getFunction()), 2344 DepClassTy::REQUIRED); 2345 return ExecutionDomainAA.isSingleThreadExecution(*ACS.getInstruction()); 2346 }; 2347 2348 if (!A.checkForAllCallSites(PredForCallSite, *this, 2349 /* RequiresAllCallSites */ true, 2350 AllCallSitesKnown)) 2351 SingleThreadedBBs.erase(&F->getEntryBlock()); 2352 2353 // Check if the edge into the successor block compares the result of a 2354 // thread-id intrinsic call to a constant zero. 2355 // TODO: Use AAValueSimplify to simplify and propagate constants. 2356 // TODO: Check more than a single use for thread IDs.
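  // Illustrative IR shape this matches (value names invented), e.g., on
  // NVPTX:
  //   %tid  = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  //   %cond = icmp eq i32 %tid, 0
  //   br i1 %cond, label %single.threaded.succ, label %other.succ
  // The AMDGPU case is analogous with @llvm.amdgcn.workitem.id.x().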
2357 auto IsSingleThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) { 2358 if (!Edge || !Edge->isConditional()) 2359 return false; 2360 if (Edge->getSuccessor(0) != SuccessorBB) 2361 return false; 2362 2363 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition()); 2364 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality()) 2365 return false; 2366 2367 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1)); 2368 if (!C || !C->isZero()) 2369 return false; 2370 2371 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0))) 2372 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x) 2373 return true; 2374 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0))) 2375 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x) 2376 return true; 2377 2378 return false; 2379 }; 2380 2381 // Merge all the predecessor states into the current basic block. A basic 2382 // block is executed by a single thread if all of its predecessors are. 2383 auto MergePredecessorStates = [&](BasicBlock *BB) { 2384 if (pred_begin(BB) == pred_end(BB)) 2385 return SingleThreadedBBs.contains(BB); 2386 2387 bool IsSingleThreaded = true; 2388 for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB); 2389 PredBB != PredEndBB; ++PredBB) { 2390 if (!IsSingleThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()), 2391 BB)) 2392 IsSingleThreaded &= SingleThreadedBBs.contains(*PredBB); 2393 } 2394 2395 return IsSingleThreaded; 2396 }; 2397 2398 for (auto *BB : RPOT) { 2399 if (!MergePredecessorStates(BB)) 2400 SingleThreadedBBs.erase(BB); 2401 } 2402 2403 return (NumSingleThreadedBBs == SingleThreadedBBs.size()) 2404 ? ChangeStatus::UNCHANGED 2405 : ChangeStatus::CHANGED; 2406 } 2407 2408 } // namespace 2409 2410 const char AAICVTracker::ID = 0; 2411 const char AAExecutionDomain::ID = 0; 2412 2413 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP, 2414 Attributor &A) { 2415 AAICVTracker *AA = nullptr; 2416 switch (IRP.getPositionKind()) { 2417 case IRPosition::IRP_INVALID: 2418 case IRPosition::IRP_FLOAT: 2419 case IRPosition::IRP_ARGUMENT: 2420 case IRPosition::IRP_CALL_SITE_ARGUMENT: 2421 llvm_unreachable("ICVTracker can only be created for function position!"); 2422 case IRPosition::IRP_RETURNED: 2423 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A); 2424 break; 2425 case IRPosition::IRP_CALL_SITE_RETURNED: 2426 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A); 2427 break; 2428 case IRPosition::IRP_CALL_SITE: 2429 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A); 2430 break; 2431 case IRPosition::IRP_FUNCTION: 2432 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A); 2433 break; 2434 } 2435 2436 return *AA; 2437 } 2438 2439 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP, 2440 Attributor &A) { 2441 AAExecutionDomainFunction *AA = nullptr; 2442 switch (IRP.getPositionKind()) { 2443 case IRPosition::IRP_INVALID: 2444 case IRPosition::IRP_FLOAT: 2445 case IRPosition::IRP_ARGUMENT: 2446 case IRPosition::IRP_CALL_SITE_ARGUMENT: 2447 case IRPosition::IRP_RETURNED: 2448 case IRPosition::IRP_CALL_SITE_RETURNED: 2449 case IRPosition::IRP_CALL_SITE: 2450 llvm_unreachable( 2451 "AAExecutionDomain can only be created for function position!"); 2452 case IRPosition::IRP_FUNCTION: 2453 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A); 2454 break; 2455 } 2456 2457 return *AA; 2458 } 2459 2460 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { 2461 if (!containsOpenMP(M, OMPInModule)) 2462 return 
PreservedAnalyses::all(); 2463 2464 if (DisableOpenMPOptimizations) 2465 return PreservedAnalyses::all(); 2466 2467 // Look at every function definition in the Module. 2468 SmallVector<Function *, 16> SCC; 2469 for (Function &Fn : M) 2470 if (!Fn.isDeclaration()) 2471 SCC.push_back(&Fn); 2472 2473 if (SCC.empty()) 2474 return PreservedAnalyses::all(); 2475 2476 FunctionAnalysisManager &FAM = 2477 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 2478 2479 AnalysisGetter AG(FAM); 2480 2481 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & { 2482 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 2483 }; 2484 2485 BumpPtrAllocator Allocator; 2486 CallGraphUpdater CGUpdater; 2487 2488 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 2489 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, 2490 OMPInModule.getKernels()); 2491 2492 Attributor A(Functions, InfoCache, CGUpdater); 2493 2494 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 2495 bool Changed = OMPOpt.run(true); 2496 if (Changed) 2497 return PreservedAnalyses::none(); 2498 2499 return PreservedAnalyses::all(); 2500 } 2501 2502 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, 2503 CGSCCAnalysisManager &AM, 2504 LazyCallGraph &CG, 2505 CGSCCUpdateResult &UR) { 2506 if (!containsOpenMP(*C.begin()->getFunction().getParent(), OMPInModule)) 2507 return PreservedAnalyses::all(); 2508 2509 if (DisableOpenMPOptimizations) 2510 return PreservedAnalyses::all(); 2511 2512 SmallVector<Function *, 16> SCC; 2513 // If there are kernels in the module, we have to run on all SCC's. 2514 bool SCCIsInteresting = !OMPInModule.getKernels().empty(); 2515 for (LazyCallGraph::Node &N : C) { 2516 Function *Fn = &N.getFunction(); 2517 SCC.push_back(Fn); 2518 2519 // Do we already know that the SCC contains kernels, 2520 // or that OpenMP functions are called from this SCC? 2521 if (SCCIsInteresting) 2522 continue; 2523 // If not, let's check that. 
2524 SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn); 2525 } 2526 2527 if (!SCCIsInteresting || SCC.empty()) 2528 return PreservedAnalyses::all(); 2529 2530 FunctionAnalysisManager &FAM = 2531 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); 2532 2533 AnalysisGetter AG(FAM); 2534 2535 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & { 2536 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 2537 }; 2538 2539 BumpPtrAllocator Allocator; 2540 CallGraphUpdater CGUpdater; 2541 CGUpdater.initialize(CG, C, AM, UR); 2542 2543 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 2544 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator, 2545 /*CGSCC*/ Functions, OMPInModule.getKernels()); 2546 2547 Attributor A(Functions, InfoCache, CGUpdater, nullptr, false); 2548 2549 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 2550 bool Changed = OMPOpt.run(false); 2551 if (Changed) 2552 return PreservedAnalyses::none(); 2553 2554 return PreservedAnalyses::all(); 2555 } 2556 2557 namespace { 2558 2559 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass { 2560 CallGraphUpdater CGUpdater; 2561 OpenMPInModule OMPInModule; 2562 static char ID; 2563 2564 OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) { 2565 initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry()); 2566 } 2567 2568 void getAnalysisUsage(AnalysisUsage &AU) const override { 2569 CallGraphSCCPass::getAnalysisUsage(AU); 2570 } 2571 2572 bool doInitialization(CallGraph &CG) override { 2573 // Disable the pass if there is no OpenMP (runtime call) in the module. 2574 containsOpenMP(CG.getModule(), OMPInModule); 2575 return false; 2576 } 2577 2578 bool runOnSCC(CallGraphSCC &CGSCC) override { 2579 if (!containsOpenMP(CGSCC.getCallGraph().getModule(), OMPInModule)) 2580 return false; 2581 if (DisableOpenMPOptimizations || skipSCC(CGSCC)) 2582 return false; 2583 2584 SmallVector<Function *, 16> SCC; 2585 // If there are kernels in the module, we have to run on all SCC's. 2586 bool SCCIsInteresting = !OMPInModule.getKernels().empty(); 2587 for (CallGraphNode *CGN : CGSCC) { 2588 Function *Fn = CGN->getFunction(); 2589 if (!Fn || Fn->isDeclaration()) 2590 continue; 2591 SCC.push_back(Fn); 2592 2593 // Do we already know that the SCC contains kernels, 2594 // or that OpenMP functions are called from this SCC? 2595 if (SCCIsInteresting) 2596 continue; 2597 // If not, let's check that. 
2598 SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn); 2599 } 2600 2601 if (!SCCIsInteresting || SCC.empty()) 2602 return false; 2603 2604 CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); 2605 CGUpdater.initialize(CG, CGSCC); 2606 2607 // Maintain a map of functions to avoid rebuilding the ORE 2608 DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap; 2609 auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & { 2610 std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F]; 2611 if (!ORE) 2612 ORE = std::make_unique<OptimizationRemarkEmitter>(F); 2613 return *ORE; 2614 }; 2615 2616 AnalysisGetter AG; 2617 SetVector<Function *> Functions(SCC.begin(), SCC.end()); 2618 BumpPtrAllocator Allocator; 2619 OMPInformationCache InfoCache( 2620 *(Functions.back()->getParent()), AG, Allocator, 2621 /*CGSCC*/ Functions, OMPInModule.getKernels()); 2622 2623 Attributor A(Functions, InfoCache, CGUpdater, nullptr, false); 2624 2625 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); 2626 return OMPOpt.run(false); 2627 } 2628 2629 bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); } 2630 }; 2631 2632 } // end anonymous namespace 2633 2634 void OpenMPInModule::identifyKernels(Module &M) { 2635 2636 NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations"); 2637 if (!MD) 2638 return; 2639 2640 for (auto *Op : MD->operands()) { 2641 if (Op->getNumOperands() < 2) 2642 continue; 2643 MDString *KindID = dyn_cast<MDString>(Op->getOperand(1)); 2644 if (!KindID || KindID->getString() != "kernel") 2645 continue; 2646 2647 Function *KernelFn = 2648 mdconst::dyn_extract_or_null<Function>(Op->getOperand(0)); 2649 if (!KernelFn) 2650 continue; 2651 2652 ++NumOpenMPTargetRegionKernels; 2653 2654 Kernels.insert(KernelFn); 2655 } 2656 } 2657 2658 bool llvm::omp::containsOpenMP(Module &M, OpenMPInModule &OMPInModule) { 2659 if (OMPInModule.isKnown()) 2660 return OMPInModule; 2661 2662 auto RecordFunctionsContainingUsesOf = [&](Function *F) { 2663 for (User *U : F->users()) 2664 if (auto *I = dyn_cast<Instruction>(U)) 2665 OMPInModule.FuncsWithOMPRuntimeCalls.insert(I->getFunction()); 2666 }; 2667 2668 // MSVC doesn't like long if-else chains for some reason and instead just 2669 // issues an error. Work around it.. 2670 do { 2671 #define OMP_RTL(_Enum, _Name, ...) \ 2672 if (Function *F = M.getFunction(_Name)) { \ 2673 RecordFunctionsContainingUsesOf(F); \ 2674 OMPInModule = true; \ 2675 } 2676 #include "llvm/Frontend/OpenMP/OMPKinds.def" 2677 } while (false); 2678 2679 // Identify kernels once. TODO: We should split the OMPInformationCache into a 2680 // module and an SCC part. The kernel information, among other things, could 2681 // go into the module part. 2682 if (OMPInModule.isKnown() && OMPInModule) { 2683 OMPInModule.identifyKernels(M); 2684 return true; 2685 } 2686 2687 return OMPInModule = false; 2688 } 2689 2690 char OpenMPOptCGSCCLegacyPass::ID = 0; 2691 2692 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", 2693 "OpenMP specific optimizations", false, false) 2694 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) 2695 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", 2696 "OpenMP specific optimizations", false, false) 2697 2698 Pass *llvm::createOpenMPOptCGSCCLegacyPass() { 2699 return new OpenMPOptCGSCCLegacyPass(); 2700 } 2701