1 //===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a general divergence analysis for loop vectorization 10 // and GPU programs. It determines which branches and values in a loop or GPU 11 // program are divergent. It can help branch optimizations such as jump 12 // threading and loop unswitching to make better decisions. 13 // 14 // GPU programs typically use the SIMD execution model, where multiple threads 15 // in the same execution group have to execute in lock-step. Therefore, if the 16 // code contains divergent branches (i.e., threads in a group do not agree on 17 // which path of the branch to take), the group of threads has to execute all 18 // the paths from that branch with different subsets of threads enabled until 19 // they re-converge. 20 // 21 // Due to this execution model, some optimizations such as jump 22 // threading and loop unswitching can interfere with thread re-convergence. 23 // Therefore, an analysis that computes which branches in a GPU program are 24 // divergent can help the compiler to selectively run these optimizations. 25 // 26 // This implementation is derived from the Vectorization Analysis of the 27 // Region Vectorizer (RV). The analysis is based on the approach described in 28 // 29 // An abstract interpretation for SPMD divergence 30 // on reducible control flow graphs. 31 // Julian Rosemann, Simon Moll and Sebastian Hack 32 // POPL '21 33 // 34 // This implementation is generic in the sense that it does 35 // not itself identify original sources of divergence. 36 // Instead specialized adapter classes, (LoopDivergenceAnalysis) for loops and 37 // (DivergenceAnalysis) for functions, identify the sources of divergence 38 // (e.g., special variables that hold the thread ID or the iteration variable). 39 // 40 // The generic implementation propagates divergence to variables that are data 41 // or sync dependent on a source of divergence. 42 // 43 // While data dependency is a well-known concept, the notion of sync dependency 44 // is worth more explanation. Sync dependence characterizes the control flow 45 // aspect of the propagation of branch divergence. For example, 46 // 47 // %cond = icmp slt i32 %tid, 10 48 // br i1 %cond, label %then, label %else 49 // then: 50 // br label %merge 51 // else: 52 // br label %merge 53 // merge: 54 // %a = phi i32 [ 0, %then ], [ 1, %else ] 55 // 56 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid 57 // because %tid is not on its use-def chains, %a is sync dependent on %tid 58 // because the branch "br i1 %cond" depends on %tid and affects which value %a 59 // is assigned to. 60 // 61 // The sync dependence detection (which branch induces divergence in which join 62 // points) is implemented in the SyncDependenceAnalysis. 63 // 64 // The current implementation has the following limitations: 65 // 1. intra-procedural. It conservatively considers the arguments of a 66 // non-kernel-entry function and the return value of a function call as 67 // divergent. 68 // 2. memory as black box. It conservatively considers values loaded from 69 // generic or local address as divergent. This can be improved by leveraging 70 // pointer analysis and/or by modelling non-escaping memory objects in SSA 71 // as done in RV. 72 // 73 //===----------------------------------------------------------------------===// 74 75 #include "llvm/Analysis/DivergenceAnalysis.h" 76 #include "llvm/ADT/PostOrderIterator.h" 77 #include "llvm/Analysis/CFG.h" 78 #include "llvm/Analysis/LoopInfo.h" 79 #include "llvm/Analysis/PostDominators.h" 80 #include "llvm/Analysis/TargetTransformInfo.h" 81 #include "llvm/IR/Dominators.h" 82 #include "llvm/IR/InstIterator.h" 83 #include "llvm/IR/Instructions.h" 84 #include "llvm/IR/Value.h" 85 #include "llvm/Support/Debug.h" 86 #include "llvm/Support/raw_ostream.h" 87 88 using namespace llvm; 89 90 #define DEBUG_TYPE "divergence" 91 92 DivergenceAnalysisImpl::DivergenceAnalysisImpl( 93 const Function &F, const Loop *RegionLoop, const DominatorTree &DT, 94 const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm) 95 : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA), 96 IsLCSSAForm(IsLCSSAForm) {} 97 98 bool DivergenceAnalysisImpl::markDivergent(const Value &DivVal) { 99 if (isAlwaysUniform(DivVal)) 100 return false; 101 assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal)); 102 assert(!isAlwaysUniform(DivVal) && "cannot be a divergent"); 103 return DivergentValues.insert(&DivVal).second; 104 } 105 106 void DivergenceAnalysisImpl::addUniformOverride(const Value &UniVal) { 107 UniformOverrides.insert(&UniVal); 108 } 109 110 bool DivergenceAnalysisImpl::isTemporalDivergent( 111 const BasicBlock &ObservingBlock, const Value &Val) const { 112 const auto *Inst = dyn_cast<const Instruction>(&Val); 113 if (!Inst) 114 return false; 115 // check whether any divergent loop carrying Val terminates before control 116 // proceeds to ObservingBlock 117 for (const auto *Loop = LI.getLoopFor(Inst->getParent()); 118 Loop != RegionLoop && !Loop->contains(&ObservingBlock); 119 Loop = Loop->getParentLoop()) { 120 if (DivergentLoops.contains(Loop)) 121 return true; 122 } 123 124 return false; 125 } 126 127 bool DivergenceAnalysisImpl::inRegion(const Instruction &I) const { 128 return I.getParent() && inRegion(*I.getParent()); 129 } 130 131 bool DivergenceAnalysisImpl::inRegion(const BasicBlock &BB) const { 132 return RegionLoop ? RegionLoop->contains(&BB) : (BB.getParent() == &F); 133 } 134 135 void DivergenceAnalysisImpl::pushUsers(const Value &V) { 136 const auto *I = dyn_cast<const Instruction>(&V); 137 138 if (I && I->isTerminator()) { 139 analyzeControlDivergence(*I); 140 return; 141 } 142 143 for (const auto *User : V.users()) { 144 const auto *UserInst = dyn_cast<const Instruction>(User); 145 if (!UserInst) 146 continue; 147 148 // only compute divergent inside loop 149 if (!inRegion(*UserInst)) 150 continue; 151 152 // All users of divergent values are immediate divergent 153 if (markDivergent(*UserInst)) 154 Worklist.push_back(UserInst); 155 } 156 } 157 158 static const Instruction *getIfCarriedInstruction(const Use &U, 159 const Loop &DivLoop) { 160 const auto *I = dyn_cast<const Instruction>(&U); 161 if (!I) 162 return nullptr; 163 if (!DivLoop.contains(I)) 164 return nullptr; 165 return I; 166 } 167 168 void DivergenceAnalysisImpl::analyzeTemporalDivergence( 169 const Instruction &I, const Loop &OuterDivLoop) { 170 if (isAlwaysUniform(I)) 171 return; 172 if (isDivergent(I)) 173 return; 174 175 LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n"); 176 assert((isa<PHINode>(I) || !IsLCSSAForm) && 177 "In LCSSA form all users of loop-exiting defs are Phi nodes."); 178 for (const Use &Op : I.operands()) { 179 const auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop); 180 if (!OpInst) 181 continue; 182 if (markDivergent(I)) 183 pushUsers(I); 184 return; 185 } 186 } 187 188 // marks all users of loop-carried values of the loop headed by LoopHeader as 189 // divergent 190 void DivergenceAnalysisImpl::analyzeLoopExitDivergence( 191 const BasicBlock &DivExit, const Loop &OuterDivLoop) { 192 // All users are in immediate exit blocks 193 if (IsLCSSAForm) { 194 for (const auto &Phi : DivExit.phis()) { 195 analyzeTemporalDivergence(Phi, OuterDivLoop); 196 } 197 return; 198 } 199 200 // For non-LCSSA we have to follow all live out edges wherever they may lead. 201 const BasicBlock &LoopHeader = *OuterDivLoop.getHeader(); 202 SmallVector<const BasicBlock *, 8> TaintStack; 203 TaintStack.push_back(&DivExit); 204 205 // Otherwise potential users of loop-carried values could be anywhere in the 206 // dominance region of DivLoop (including its fringes for phi nodes) 207 DenseSet<const BasicBlock *> Visited; 208 Visited.insert(&DivExit); 209 210 do { 211 auto *UserBlock = TaintStack.pop_back_val(); 212 213 // don't spread divergence beyond the region 214 if (!inRegion(*UserBlock)) 215 continue; 216 217 assert(!OuterDivLoop.contains(UserBlock) && 218 "irreducible control flow detected"); 219 220 // phi nodes at the fringes of the dominance region 221 if (!DT.dominates(&LoopHeader, UserBlock)) { 222 // all PHI nodes of UserBlock become divergent 223 for (auto &Phi : UserBlock->phis()) { 224 analyzeTemporalDivergence(Phi, OuterDivLoop); 225 } 226 continue; 227 } 228 229 // Taint outside users of values carried by OuterDivLoop. 230 for (auto &I : *UserBlock) { 231 analyzeTemporalDivergence(I, OuterDivLoop); 232 } 233 234 // visit all blocks in the dominance region 235 for (auto *SuccBlock : successors(UserBlock)) { 236 if (!Visited.insert(SuccBlock).second) { 237 continue; 238 } 239 TaintStack.push_back(SuccBlock); 240 } 241 } while (!TaintStack.empty()); 242 } 243 244 void DivergenceAnalysisImpl::propagateLoopExitDivergence( 245 const BasicBlock &DivExit, const Loop &InnerDivLoop) { 246 LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n"); 247 248 // Find outer-most loop that does not contain \p DivExit 249 const Loop *DivLoop = &InnerDivLoop; 250 const Loop *OuterDivLoop = DivLoop; 251 const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit); 252 const unsigned LoopExitDepth = 253 ExitLevelLoop ? ExitLevelLoop->getLoopDepth() : 0; 254 while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) { 255 DivergentLoops.insert(DivLoop); // all crossed loops are divergent 256 OuterDivLoop = DivLoop; 257 DivLoop = DivLoop->getParentLoop(); 258 } 259 LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName() 260 << "\n"); 261 262 analyzeLoopExitDivergence(DivExit, *OuterDivLoop); 263 } 264 265 // this is a divergent join point - mark all phi nodes as divergent and push 266 // them onto the stack. 267 void DivergenceAnalysisImpl::taintAndPushPhiNodes(const BasicBlock &JoinBlock) { 268 LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName() 269 << "\n"); 270 271 // ignore divergence outside the region 272 if (!inRegion(JoinBlock)) { 273 return; 274 } 275 276 // push non-divergent phi nodes in JoinBlock to the worklist 277 for (const auto &Phi : JoinBlock.phis()) { 278 if (isDivergent(Phi)) 279 continue; 280 // FIXME Theoretically ,the 'undef' value could be replaced by any other 281 // value causing spurious divergence. 282 if (Phi.hasConstantOrUndefValue()) 283 continue; 284 if (markDivergent(Phi)) 285 Worklist.push_back(&Phi); 286 } 287 } 288 289 void DivergenceAnalysisImpl::analyzeControlDivergence(const Instruction &Term) { 290 LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName() 291 << "\n"); 292 293 // Don't propagate divergence from unreachable blocks. 294 if (!DT.isReachableFromEntry(Term.getParent())) 295 return; 296 297 const auto *BranchLoop = LI.getLoopFor(Term.getParent()); 298 299 const auto &DivDesc = SDA.getJoinBlocks(Term); 300 301 // Iterate over all blocks now reachable by a disjoint path join 302 for (const auto *JoinBlock : DivDesc.JoinDivBlocks) { 303 taintAndPushPhiNodes(*JoinBlock); 304 } 305 306 assert(DivDesc.LoopDivBlocks.empty() || BranchLoop); 307 for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) { 308 propagateLoopExitDivergence(*DivExitBlock, *BranchLoop); 309 } 310 } 311 312 void DivergenceAnalysisImpl::compute() { 313 // Initialize worklist. 314 auto DivValuesCopy = DivergentValues; 315 for (const auto *DivVal : DivValuesCopy) { 316 assert(isDivergent(*DivVal) && "Worklist invariant violated!"); 317 pushUsers(*DivVal); 318 } 319 320 // All values on the Worklist are divergent. 321 // Their users may not have been updated yed. 322 while (!Worklist.empty()) { 323 const Instruction &I = *Worklist.back(); 324 Worklist.pop_back(); 325 326 // propagate value divergence to users 327 assert(isDivergent(I) && "Worklist invariant violated!"); 328 pushUsers(I); 329 } 330 } 331 332 bool DivergenceAnalysisImpl::isAlwaysUniform(const Value &V) const { 333 return UniformOverrides.contains(&V); 334 } 335 336 bool DivergenceAnalysisImpl::isDivergent(const Value &V) const { 337 return DivergentValues.contains(&V); 338 } 339 340 bool DivergenceAnalysisImpl::isDivergentUse(const Use &U) const { 341 Value &V = *U.get(); 342 Instruction &I = *cast<Instruction>(U.getUser()); 343 return isDivergent(V) || isTemporalDivergent(*I.getParent(), V); 344 } 345 346 DivergenceInfo::DivergenceInfo(Function &F, const DominatorTree &DT, 347 const PostDominatorTree &PDT, const LoopInfo &LI, 348 const TargetTransformInfo &TTI, 349 bool KnownReducible) 350 : F(F) { 351 if (!KnownReducible) { 352 using RPOTraversal = ReversePostOrderTraversal<const Function *>; 353 RPOTraversal FuncRPOT(&F); 354 if (containsIrreducibleCFG<const BasicBlock *, const RPOTraversal, 355 const LoopInfo>(FuncRPOT, LI)) { 356 ContainsIrreducible = true; 357 return; 358 } 359 } 360 SDA = std::make_unique<SyncDependenceAnalysis>(DT, PDT, LI); 361 DA = std::make_unique<DivergenceAnalysisImpl>(F, nullptr, DT, LI, *SDA, 362 /* LCSSA */ false); 363 for (auto &I : instructions(F)) { 364 if (TTI.isSourceOfDivergence(&I)) { 365 DA->markDivergent(I); 366 } else if (TTI.isAlwaysUniform(&I)) { 367 DA->addUniformOverride(I); 368 } 369 } 370 for (auto &Arg : F.args()) { 371 if (TTI.isSourceOfDivergence(&Arg)) { 372 DA->markDivergent(Arg); 373 } 374 } 375 376 DA->compute(); 377 } 378 379 AnalysisKey DivergenceAnalysis::Key; 380 381 DivergenceAnalysis::Result 382 DivergenceAnalysis::run(Function &F, FunctionAnalysisManager &AM) { 383 auto &DT = AM.getResult<DominatorTreeAnalysis>(F); 384 auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); 385 auto &LI = AM.getResult<LoopAnalysis>(F); 386 auto &TTI = AM.getResult<TargetIRAnalysis>(F); 387 388 return DivergenceInfo(F, DT, PDT, LI, TTI, /* KnownReducible = */ false); 389 } 390 391 PreservedAnalyses 392 DivergenceAnalysisPrinterPass::run(Function &F, FunctionAnalysisManager &FAM) { 393 auto &DI = FAM.getResult<DivergenceAnalysis>(F); 394 OS << "'Divergence Analysis' for function '" << F.getName() << "':\n"; 395 if (DI.hasDivergence()) { 396 for (auto &Arg : F.args()) { 397 OS << (DI.isDivergent(Arg) ? "DIVERGENT: " : " "); 398 OS << Arg << "\n"; 399 } 400 for (const BasicBlock &BB : F) { 401 OS << "\n " << BB.getName() << ":\n"; 402 for (auto &I : BB.instructionsWithoutDebug()) { 403 OS << (DI.isDivergent(I) ? "DIVERGENT: " : " "); 404 OS << I << "\n"; 405 } 406 } 407 } 408 return PreservedAnalyses::all(); 409 } 410