//===- DivergenceAnalysis.cpp --------- Divergence Analysis Implementation -==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a general divergence analysis for loop vectorization
// and GPU programs. It determines which branches and values in a loop or GPU
// program are divergent. It can help branch optimizations such as jump
// threading and loop unswitching to make better decisions.
//
// GPU programs typically use the SIMD execution model, where multiple threads
// in the same execution group have to execute in lock-step. Therefore, if the
// code contains divergent branches (i.e., threads in a group do not agree on
// which path of the branch to take), the group of threads has to execute all
// the paths from that branch with different subsets of threads enabled until
// they re-converge.
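//
// For illustration, consider a group of four threads executing
//
//   if (tid < 2) A(); else B();
//
// Threads 0 and 1 execute A() while threads 2 and 3 are masked off, then
// threads 2 and 3 execute B() while threads 0 and 1 are masked off, and the
// full group re-converges only after the branch.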
//
// Due to this execution model, some optimizations such as jump
// threading and loop unswitching can interfere with thread re-convergence.
// Therefore, an analysis that computes which branches in a GPU program are
// divergent can help the compiler to selectively run these optimizations.
//
// This implementation is derived from the Vectorization Analysis of the
// Region Vectorizer (RV). That implementation in turn is based on the approach
// described in
//
//   Improving Performance of OpenCL on CPUs
//   Ralf Karrenberg and Sebastian Hack
//   CC '12
//
// This DivergenceAnalysis implementation is generic in the sense that it does
// not itself identify original sources of divergence. Instead, specialized
// adapter classes, LoopDivergenceAnalysis for loops and GPUDivergenceAnalysis
// for GPU programs, identify the sources of divergence (e.g., special
// variables that hold the thread ID or the loop iteration variable).
//
// The generic implementation propagates divergence to variables that are data
// or sync dependent on a source of divergence.
//
// While data dependency is a well-known concept, the notion of sync dependency
// is worth more explanation. Sync dependence characterizes the control flow
// aspect of the propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
// The sync dependence detection (which branch induces divergence in which join
// points) is implemented in the SyncDependenceAnalysis.
//
// The current DivergenceAnalysis implementation has the following limitations:
// 1. intra-procedural. It conservatively considers the arguments of a
//    non-kernel-entry function and the return value of a function call as
//    divergent.
// 2. memory as black box. It conservatively considers values loaded from
//    generic or local address spaces as divergent. This can be improved by
//    leveraging pointer analysis and/or by modelling non-escaping memory
//    objects in SSA, as done in RV.
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/DivergenceAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/Passes.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <vector>

using namespace llvm;

#define DEBUG_TYPE "divergence-analysis"

// class DivergenceAnalysis
DivergenceAnalysis::DivergenceAnalysis(
    const Function &F, const Loop *RegionLoop, const DominatorTree &DT,
    const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm)
    : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA),
      IsLCSSAForm(IsLCSSAForm) {}

void DivergenceAnalysis::markDivergent(const Value &DivVal) {
  assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal));
  assert(!isAlwaysUniform(DivVal) && "cannot be divergent");
  DivergentValues.insert(&DivVal);
}

void DivergenceAnalysis::addUniformOverride(const Value &UniVal) {
  UniformOverrides.insert(&UniVal);
}

bool DivergenceAnalysis::updateTerminator(const Instruction &Term) const {
  if (Term.getNumSuccessors() <= 1)
    return false;
  if (auto *BranchTerm = dyn_cast<BranchInst>(&Term)) {
    assert(BranchTerm->isConditional());
    return isDivergent(*BranchTerm->getCondition());
  }
  if (auto *SwitchTerm = dyn_cast<SwitchInst>(&Term)) {
    return isDivergent(*SwitchTerm->getCondition());
  }
  if (isa<InvokeInst>(Term)) {
    return false; // ignore abnormal executions through landingpad
  }

  llvm_unreachable("unexpected terminator");
}

bool DivergenceAnalysis::updateNormalInstruction(const Instruction &I) const {
  // TODO: function calls with side effects, etc.
  for (const auto &Op : I.operands()) {
    if (isDivergent(*Op))
      return true;
  }
  return false;
}

bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock,
                                             const Value &Val) const {
  const auto *Inst = dyn_cast<const Instruction>(&Val);
  if (!Inst)
    return false;
  // check whether any divergent loop carrying Val terminates before control
  // proceeds to ObservingBlock
  for (const auto *Loop = LI.getLoopFor(Inst->getParent());
       Loop != RegionLoop && !Loop->contains(&ObservingBlock);
       Loop = Loop->getParentLoop()) {
    if (DivergentLoops.find(Loop) != DivergentLoops.end())
      return true;
  }

  return false;
}

bool DivergenceAnalysis::updatePHINode(const PHINode &Phi) const {
  // joining of divergent disjoint paths in Phi's parent block
  if (!Phi.hasConstantOrUndefValue() && isJoinDivergent(*Phi.getParent())) {
    return true;
  }

  // An incoming value could be divergent by itself.
  // Otherwise, an incoming value could be uniform within the loop
  // that carries its definition but it may appear divergent
  // from outside the loop. This happens when divergent loop exits
  // drop definitions of that uniform value in different iterations.
  //
  //   for (int i = 0; i < n; ++i) {    // 'i' is uniform inside the loop
  //     if (i % thread_id == 0) break; // divergent loop exit
  //   }
  //   int divI = i;                    // divI is divergent
  for (size_t i = 0; i < Phi.getNumIncomingValues(); ++i) {
    const auto *InVal = Phi.getIncomingValue(i);
    if (isDivergent(*InVal) ||
        isTemporalDivergent(*Phi.getParent(), *InVal)) {
      return true;
    }
  }
  return false;
}

bool DivergenceAnalysis::inRegion(const Instruction &I) const {
  return I.getParent() && inRegion(*I.getParent());
}

bool DivergenceAnalysis::inRegion(const BasicBlock &BB) const {
  return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB);
}

// returns whether I uses a value that is defined inside DivLoop
static bool usesLiveOut(const Instruction &I, const Loop *DivLoop) {
  for (auto &Op : I.operands()) {
    auto *OpInst = dyn_cast<Instruction>(&Op);
    if (!OpInst)
      continue;
    if (DivLoop->contains(OpInst->getParent()))
      return true;
  }
  return false;
}

// marks all users of loop-carried values of the loop headed by LoopHeader as
// divergent
void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) {
  auto *DivLoop = LI.getLoopFor(&LoopHeader);
  assert(DivLoop && "loopHeader is not actually part of a loop");

  SmallVector<BasicBlock *, 8> TaintStack;
  DivLoop->getExitBlocks(TaintStack);

  // Potential users of loop-carried values could be anywhere in the dominance
  // region of DivLoop (including its fringes, for phi nodes).
  DenseSet<const BasicBlock *> Visited;
  for (auto *Block : TaintStack) {
    Visited.insert(Block);
  }
  Visited.insert(&LoopHeader);

  while (!TaintStack.empty()) {
    auto *UserBlock = TaintStack.back();
    TaintStack.pop_back();

    // don't spread divergence beyond the region
    if (!inRegion(*UserBlock))
      continue;

    assert(!DivLoop->contains(UserBlock) &&
           "irreducible control flow detected");

    // phi nodes at the fringes of the dominance region
    if (!DT.dominates(&LoopHeader, UserBlock)) {
      // all PHI nodes of UserBlock become divergent
      for (auto &Phi : UserBlock->phis()) {
        Worklist.push_back(&Phi);
      }
      continue;
    }

    // taint outside users of values carried by DivLoop
    for (auto &I : *UserBlock) {
      if (isAlwaysUniform(I))
        continue;
      if (isDivergent(I))
        continue;
      if (!usesLiveOut(I, DivLoop))
        continue;

      markDivergent(I);
      if (I.isTerminator()) {
        propagateBranchDivergence(I);
      } else {
        pushUsers(I);
      }
    }

    // visit all blocks in the dominance region
    for (auto *SuccBlock : successors(UserBlock)) {
      if (!Visited.insert(SuccBlock).second) {
        continue;
      }
      TaintStack.push_back(SuccBlock);
    }
  }
}

void DivergenceAnalysis::pushPHINodes(const BasicBlock &Block) {
  for (const auto &Phi : Block.phis()) {
    if (isDivergent(Phi))
      continue;
    Worklist.push_back(&Phi);
  }
}

void DivergenceAnalysis::pushUsers(const Value &V) {
  for (const auto *User : V.users()) {
    const auto *UserInst = dyn_cast<const Instruction>(User);
    if (!UserInst)
      continue;

    if (isDivergent(*UserInst))
      continue;

    // only compute divergence inside the region
    if (!inRegion(*UserInst))
      continue;
    Worklist.push_back(UserInst);
  }
}

bool DivergenceAnalysis::propagateJoinDivergence(const BasicBlock &JoinBlock,
                                                 const Loop *BranchLoop) {
  LLVM_DEBUG(dbgs() << "\tpropJoinDiv " << JoinBlock.getName() << "\n");

  // ignore divergence outside the region
  if (!inRegion(JoinBlock)) {
    return false;
  }

  // push non-divergent phi nodes in JoinBlock to the worklist
  pushPHINodes(JoinBlock);

  // disjoint-paths divergent at JoinBlock
  markBlockJoinDivergent(JoinBlock);

  // JoinBlock is a divergent loop exit
  return BranchLoop && !BranchLoop->contains(&JoinBlock);
}

void DivergenceAnalysis::propagateBranchDivergence(const Instruction &Term) {
  LLVM_DEBUG(dbgs() << "propBranchDiv " << Term.getParent()->getName() << "\n");

  markDivergent(Term);

  // Don't propagate divergence from unreachable blocks.
  if (!DT.isReachableFromEntry(Term.getParent()))
    return;

  const auto *BranchLoop = LI.getLoopFor(Term.getParent());

  // whether there is a divergent loop exit from BranchLoop (if any)
  bool IsBranchLoopDivergent = false;

  // iterate over all blocks reachable by disjoint paths from Term within the
  // loop; this also covers loop exits that become divergent due to Term.
  for (const auto *JoinBlock : SDA.join_blocks(Term)) {
    IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop);
  }

  // BranchLoop is a divergent loop due to the divergent branch Term
  if (IsBranchLoopDivergent) {
    assert(BranchLoop);
    if (!DivergentLoops.insert(BranchLoop).second) {
      return;
    }
    propagateLoopDivergence(*BranchLoop);
  }
}
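
// For illustration, a sketch in the spirit of the comment in updatePHINode
// above: a divergent loop exit makes a value that is uniform inside the loop
// observe divergent values at its uses outside of the loop.
//
//   loop:
//     %i = phi i32 [ 0, %entry ], [ %i.next, %loop ]
//     %i.next = add i32 %i, 1
//     %cond = icmp eq i32 %i, %tid   ; divergent exit condition
//     br i1 %cond, label %exit, label %loop
//   exit:
//     %divI = phi i32 [ %i, %loop ]  ; temporally divergent
//
// Each thread leaves the loop in a different iteration, so %divI receives a
// different value of the otherwise uniform %i in each thread.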

void DivergenceAnalysis::propagateLoopDivergence(const Loop &ExitingLoop) {
  LLVM_DEBUG(dbgs() << "propLoopDiv " << ExitingLoop.getName() << "\n");

  // don't propagate beyond the region
  if (!inRegion(*ExitingLoop.getHeader()))
    return;

  const auto *BranchLoop = ExitingLoop.getParentLoop();

  // Uses of loop-carried values could occur anywhere
  // within the dominance region of the definition. All loop-carried
  // definitions are dominated by the loop header (reducible control).
  // Thus all users have to be in the dominance region of the loop header,
  // except PHI nodes that can also live at the fringes of the dom region
  // (incoming defining value).
  if (!IsLCSSAForm)
    taintLoopLiveOuts(*ExitingLoop.getHeader());

  // whether there is a divergent loop exit from BranchLoop (if any)
  bool IsBranchLoopDivergent = false;

  // iterate over all blocks reachable by disjoint paths from the exits of
  // ExitingLoop; this also covers loop exits (of BranchLoop) that in turn
  // become divergent.
  for (const auto *JoinBlock : SDA.join_blocks(ExitingLoop)) {
    IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop);
  }

  // BranchLoop is divergent due to a divergent loop exit of ExitingLoop
  if (IsBranchLoopDivergent) {
    assert(BranchLoop);
    if (!DivergentLoops.insert(BranchLoop).second) {
      return;
    }
    propagateLoopDivergence(*BranchLoop);
  }
}

// propagate divergence from the seed values to a fixed point using a worklist
void DivergenceAnalysis::compute() {
  for (auto *DivVal : DivergentValues) {
    pushUsers(*DivVal);
  }

  // propagate divergence
  while (!Worklist.empty()) {
    const Instruction &I = *Worklist.back();
    Worklist.pop_back();

    // maintain uniformity of overrides
    if (isAlwaysUniform(I))
      continue;

    // already divergent; nothing left to do
    if (isDivergent(I))
      continue;

    // propagate divergence caused by a divergent terminator
    if (I.isTerminator()) {
      if (updateTerminator(I)) {
        // propagate control divergence to affected instructions
        propagateBranchDivergence(I);
        continue;
      }
    }

    // update divergence of I due to divergent operands
    bool DivergentUpd = false;
    const auto *Phi = dyn_cast<const PHINode>(&I);
    if (Phi) {
      DivergentUpd = updatePHINode(*Phi);
    } else {
      DivergentUpd = updateNormalInstruction(I);
    }

    // propagate value divergence to users
    if (DivergentUpd) {
      markDivergent(I);
      pushUsers(I);
    }
  }
}

bool DivergenceAnalysis::isAlwaysUniform(const Value &V) const {
  return UniformOverrides.find(&V) != UniformOverrides.end();
}

bool DivergenceAnalysis::isDivergent(const Value &V) const {
  return DivergentValues.find(&V) != DivergentValues.end();
}

bool DivergenceAnalysis::isDivergentUse(const Use &U) const {
  Value &V = *U.get();
  Instruction &I = *cast<Instruction>(U.getUser());
  return isDivergent(V) || isTemporalDivergent(*I.getParent(), V);
}

void DivergenceAnalysis::print(raw_ostream &OS, const Module *) const {
  if (DivergentValues.empty())
    return;
  // iterate instructions using instructions() to ensure a deterministic order
  for (auto &I : instructions(F)) {
    if (isDivergent(I))
      OS << "DIVERGENT:" << I << '\n';
  }
}

// class GPUDivergenceAnalysis
GPUDivergenceAnalysis::GPUDivergenceAnalysis(Function &F,
                                             const DominatorTree &DT,
                                             const PostDominatorTree &PDT,
                                             const LoopInfo &LI,
                                             const TargetTransformInfo &TTI)
    : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, false) {
  for (auto &I : instructions(F)) {
    if (TTI.isSourceOfDivergence(&I)) {
      DA.markDivergent(I);
    } else if (TTI.isAlwaysUniform(&I)) {
      DA.addUniformOverride(I);
    }
  }
  for (auto &Arg : F.args()) {
    if (TTI.isSourceOfDivergence(&Arg)) {
      DA.markDivergent(Arg);
    }
  }

  DA.compute();
}

bool GPUDivergenceAnalysis::isDivergent(const Value &val) const {
  return DA.isDivergent(val);
}

bool GPUDivergenceAnalysis::isDivergentUse(const Use &use) const {
  return DA.isDivergentUse(use);
}

void GPUDivergenceAnalysis::print(raw_ostream &OS, const Module *mod) const {
  OS << "Divergence of kernel " << DA.getFunction().getName() << " {\n";
  DA.print(OS, mod);
  OS << "}\n";
}
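
// A minimal usage sketch (illustrative, assuming the required analyses have
// already been computed, e.g. by a pass manager):
//
//   GPUDivergenceAnalysis GDA(F, DT, PDT, LI, TTI);
//   for (auto &I : instructions(F)) {
//     if (GDA.isDivergent(I))
//       ; // e.g. prevent jump threading across this branch
//   }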