//===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a general divergence analysis for loop vectorization
// and GPU programs. It determines which branches and values in a loop or GPU
// program are divergent. It can help branch optimizations such as jump
// threading and loop unswitching to make better decisions.
//
// GPU programs typically use the SIMD execution model, where multiple threads
// in the same execution group have to execute in lock-step. Therefore, if the
// code contains divergent branches (i.e., threads in a group do not agree on
// which path of the branch to take), the group of threads has to execute all
// the paths from that branch with different subsets of threads enabled until
// they re-converge.
//
// Due to this execution model, some optimizations such as jump
// threading and loop unswitching can interfere with thread re-convergence.
// Therefore, an analysis that computes which branches in a GPU program are
// divergent can help the compiler to selectively run these optimizations.
//
// This implementation is derived from the Vectorization Analysis of the
// Region Vectorizer (RV). That implementation in turn is based on the approach
// described in
//
//   Improving Performance of OpenCL on CPUs
//   Ralf Karrenberg and Sebastian Hack
//   CC '12
//
// This implementation is generic in the sense that it does
// not itself identify original sources of divergence.
// Instead, specialized adapter classes (LoopDivergenceAnalysis for loops and
// DivergenceAnalysis for functions) identify the sources of divergence
// (e.g., special variables that hold the thread ID or the iteration variable).
//
// The generic implementation propagates divergence to variables that are data
// or sync dependent on a source of divergence.
//
// While data dependency is a well-known concept, the notion of sync dependency
// is worth more explanation. Sync dependence characterizes the control flow
// aspect of the propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
// The sync dependence detection (which branch induces divergence in which join
// points) is implemented in the SyncDependenceAnalysis.
//
// The current implementation has the following limitations:
// 1. intra-procedural. It conservatively considers the arguments of a
//    non-kernel-entry function and the return value of a function call as
//    divergent.
// 2. memory as black box. It conservatively considers values loaded from
//    generic or local address as divergent. This can be improved by leveraging
//    pointer analysis and/or by modelling non-escaping memory objects in SSA
//    as done in RV.
72 // 73 //===----------------------------------------------------------------------===// 74 75 #include "llvm/Analysis/DivergenceAnalysis.h" 76 #include "llvm/Analysis/CFG.h" 77 #include "llvm/Analysis/LoopInfo.h" 78 #include "llvm/Analysis/Passes.h" 79 #include "llvm/Analysis/PostDominators.h" 80 #include "llvm/Analysis/TargetTransformInfo.h" 81 #include "llvm/IR/Dominators.h" 82 #include "llvm/IR/InstIterator.h" 83 #include "llvm/IR/Instructions.h" 84 #include "llvm/IR/IntrinsicInst.h" 85 #include "llvm/IR/Value.h" 86 #include "llvm/Support/Debug.h" 87 #include "llvm/Support/raw_ostream.h" 88 89 using namespace llvm; 90 91 #define DEBUG_TYPE "divergence" 92 93 DivergenceAnalysisImpl::DivergenceAnalysisImpl( 94 const Function &F, const Loop *RegionLoop, const DominatorTree &DT, 95 const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm) 96 : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA), 97 IsLCSSAForm(IsLCSSAForm) {} 98 99 bool DivergenceAnalysisImpl::markDivergent(const Value &DivVal) { 100 if (isAlwaysUniform(DivVal)) 101 return false; 102 assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal)); 103 assert(!isAlwaysUniform(DivVal) && "cannot be a divergent"); 104 return DivergentValues.insert(&DivVal).second; 105 } 106 107 void DivergenceAnalysisImpl::addUniformOverride(const Value &UniVal) { 108 UniformOverrides.insert(&UniVal); 109 } 110 111 bool DivergenceAnalysisImpl::isTemporalDivergent( 112 const BasicBlock &ObservingBlock, const Value &Val) const { 113 const auto *Inst = dyn_cast<const Instruction>(&Val); 114 if (!Inst) 115 return false; 116 // check whether any divergent loop carrying Val terminates before control 117 // proceeds to ObservingBlock 118 for (const auto *Loop = LI.getLoopFor(Inst->getParent()); 119 Loop != RegionLoop && !Loop->contains(&ObservingBlock); 120 Loop = Loop->getParentLoop()) { 121 if (DivergentLoops.contains(Loop)) 122 return true; 123 } 124 125 return false; 126 } 127 128 bool 
DivergenceAnalysisImpl::inRegion(const Instruction &I) const { 129 return I.getParent() && inRegion(*I.getParent()); 130 } 131 132 bool DivergenceAnalysisImpl::inRegion(const BasicBlock &BB) const { 133 return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB); 134 } 135 136 void DivergenceAnalysisImpl::pushUsers(const Value &V) { 137 const auto *I = dyn_cast<const Instruction>(&V); 138 139 if (I && I->isTerminator()) { 140 analyzeControlDivergence(*I); 141 return; 142 } 143 144 for (const auto *User : V.users()) { 145 const auto *UserInst = dyn_cast<const Instruction>(User); 146 if (!UserInst) 147 continue; 148 149 // only compute divergent inside loop 150 if (!inRegion(*UserInst)) 151 continue; 152 153 // All users of divergent values are immediate divergent 154 if (markDivergent(*UserInst)) 155 Worklist.push_back(UserInst); 156 } 157 } 158 159 static const Instruction *getIfCarriedInstruction(const Use &U, 160 const Loop &DivLoop) { 161 const auto *I = dyn_cast<const Instruction>(&U); 162 if (!I) 163 return nullptr; 164 if (!DivLoop.contains(I)) 165 return nullptr; 166 return I; 167 } 168 169 void DivergenceAnalysisImpl::analyzeTemporalDivergence( 170 const Instruction &I, const Loop &OuterDivLoop) { 171 if (isAlwaysUniform(I)) 172 return; 173 if (isDivergent(I)) 174 return; 175 176 LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n"); 177 assert((isa<PHINode>(I) || !IsLCSSAForm) && 178 "In LCSSA form all users of loop-exiting defs are Phi nodes."); 179 for (const Use &Op : I.operands()) { 180 const auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop); 181 if (!OpInst) 182 continue; 183 if (markDivergent(I)) 184 pushUsers(I); 185 return; 186 } 187 } 188 189 // marks all users of loop-carried values of the loop headed by LoopHeader as 190 // divergent 191 void DivergenceAnalysisImpl::analyzeLoopExitDivergence( 192 const BasicBlock &DivExit, const Loop &OuterDivLoop) { 193 // All users are in immediate exit blocks 194 if 
(IsLCSSAForm) { 195 for (const auto &Phi : DivExit.phis()) { 196 analyzeTemporalDivergence(Phi, OuterDivLoop); 197 } 198 return; 199 } 200 201 // For non-LCSSA we have to follow all live out edges wherever they may lead. 202 const BasicBlock &LoopHeader = *OuterDivLoop.getHeader(); 203 SmallVector<const BasicBlock *, 8> TaintStack; 204 TaintStack.push_back(&DivExit); 205 206 // Otherwise potential users of loop-carried values could be anywhere in the 207 // dominance region of DivLoop (including its fringes for phi nodes) 208 DenseSet<const BasicBlock *> Visited; 209 Visited.insert(&DivExit); 210 211 do { 212 auto *UserBlock = TaintStack.pop_back_val(); 213 214 // don't spread divergence beyond the region 215 if (!inRegion(*UserBlock)) 216 continue; 217 218 assert(!OuterDivLoop.contains(UserBlock) && 219 "irreducible control flow detected"); 220 221 // phi nodes at the fringes of the dominance region 222 if (!DT.dominates(&LoopHeader, UserBlock)) { 223 // all PHI nodes of UserBlock become divergent 224 for (auto &Phi : UserBlock->phis()) { 225 analyzeTemporalDivergence(Phi, OuterDivLoop); 226 } 227 continue; 228 } 229 230 // Taint outside users of values carried by OuterDivLoop. 231 for (auto &I : *UserBlock) { 232 analyzeTemporalDivergence(I, OuterDivLoop); 233 } 234 235 // visit all blocks in the dominance region 236 for (auto *SuccBlock : successors(UserBlock)) { 237 if (!Visited.insert(SuccBlock).second) { 238 continue; 239 } 240 TaintStack.push_back(SuccBlock); 241 } 242 } while (!TaintStack.empty()); 243 } 244 245 void DivergenceAnalysisImpl::propagateLoopExitDivergence( 246 const BasicBlock &DivExit, const Loop &InnerDivLoop) { 247 LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n"); 248 249 // Find outer-most loop that does not contain \p DivExit 250 const Loop *DivLoop = &InnerDivLoop; 251 const Loop *OuterDivLoop = DivLoop; 252 const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit); 253 const unsigned LoopExitDepth = 254 ExitLevelLoop ? 
ExitLevelLoop->getLoopDepth() : 0; 255 while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) { 256 DivergentLoops.insert(DivLoop); // all crossed loops are divergent 257 OuterDivLoop = DivLoop; 258 DivLoop = DivLoop->getParentLoop(); 259 } 260 LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName() 261 << "\n"); 262 263 analyzeLoopExitDivergence(DivExit, *OuterDivLoop); 264 } 265 266 // this is a divergent join point - mark all phi nodes as divergent and push 267 // them onto the stack. 268 void DivergenceAnalysisImpl::taintAndPushPhiNodes(const BasicBlock &JoinBlock) { 269 LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName() 270 << "\n"); 271 272 // ignore divergence outside the region 273 if (!inRegion(JoinBlock)) { 274 return; 275 } 276 277 // push non-divergent phi nodes in JoinBlock to the worklist 278 for (const auto &Phi : JoinBlock.phis()) { 279 if (isDivergent(Phi)) 280 continue; 281 // FIXME Theoretically ,the 'undef' value could be replaced by any other 282 // value causing spurious divergence. 283 if (Phi.hasConstantOrUndefValue()) 284 continue; 285 if (markDivergent(Phi)) 286 Worklist.push_back(&Phi); 287 } 288 } 289 290 void DivergenceAnalysisImpl::analyzeControlDivergence(const Instruction &Term) { 291 LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName() 292 << "\n"); 293 294 // Don't propagate divergence from unreachable blocks. 
295 if (!DT.isReachableFromEntry(Term.getParent())) 296 return; 297 298 const auto *BranchLoop = LI.getLoopFor(Term.getParent()); 299 300 const auto &DivDesc = SDA.getJoinBlocks(Term); 301 302 // Iterate over all blocks now reachable by a disjoint path join 303 for (const auto *JoinBlock : DivDesc.JoinDivBlocks) { 304 taintAndPushPhiNodes(*JoinBlock); 305 } 306 307 assert(DivDesc.LoopDivBlocks.empty() || BranchLoop); 308 for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) { 309 propagateLoopExitDivergence(*DivExitBlock, *BranchLoop); 310 } 311 } 312 313 void DivergenceAnalysisImpl::compute() { 314 // Initialize worklist. 315 auto DivValuesCopy = DivergentValues; 316 for (const auto *DivVal : DivValuesCopy) { 317 assert(isDivergent(*DivVal) && "Worklist invariant violated!"); 318 pushUsers(*DivVal); 319 } 320 321 // All values on the Worklist are divergent. 322 // Their users may not have been updated yed. 323 while (!Worklist.empty()) { 324 const Instruction &I = *Worklist.back(); 325 Worklist.pop_back(); 326 327 // propagate value divergence to users 328 assert(isDivergent(I) && "Worklist invariant violated!"); 329 pushUsers(I); 330 } 331 } 332 333 bool DivergenceAnalysisImpl::isAlwaysUniform(const Value &V) const { 334 return UniformOverrides.contains(&V); 335 } 336 337 bool DivergenceAnalysisImpl::isDivergent(const Value &V) const { 338 return DivergentValues.contains(&V); 339 } 340 341 bool DivergenceAnalysisImpl::isDivergentUse(const Use &U) const { 342 Value &V = *U.get(); 343 Instruction &I = *cast<Instruction>(U.getUser()); 344 return isDivergent(V) || isTemporalDivergent(*I.getParent(), V); 345 } 346 347 DivergenceInfo::DivergenceInfo(Function &F, const DominatorTree &DT, 348 const PostDominatorTree &PDT, const LoopInfo &LI, 349 const TargetTransformInfo &TTI, 350 bool KnownReducible) 351 : F(F), ContainsIrreducible(false) { 352 if (!KnownReducible) { 353 using RPOTraversal = ReversePostOrderTraversal<const Function *>; 354 RPOTraversal FuncRPOT(&F); 
355 if (containsIrreducibleCFG<const BasicBlock *, const RPOTraversal, 356 const LoopInfo>(FuncRPOT, LI)) { 357 ContainsIrreducible = true; 358 return; 359 } 360 } 361 SDA = std::make_unique<SyncDependenceAnalysis>(DT, PDT, LI); 362 DA = std::make_unique<DivergenceAnalysisImpl>(F, nullptr, DT, LI, *SDA, 363 /* LCSSA */ false); 364 for (auto &I : instructions(F)) { 365 if (TTI.isSourceOfDivergence(&I)) { 366 DA->markDivergent(I); 367 } else if (TTI.isAlwaysUniform(&I)) { 368 DA->addUniformOverride(I); 369 } 370 } 371 for (auto &Arg : F.args()) { 372 if (TTI.isSourceOfDivergence(&Arg)) { 373 DA->markDivergent(Arg); 374 } 375 } 376 377 DA->compute(); 378 } 379 380 AnalysisKey DivergenceAnalysis::Key; 381 382 DivergenceAnalysis::Result 383 DivergenceAnalysis::run(Function &F, FunctionAnalysisManager &AM) { 384 auto &DT = AM.getResult<DominatorTreeAnalysis>(F); 385 auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); 386 auto &LI = AM.getResult<LoopAnalysis>(F); 387 auto &TTI = AM.getResult<TargetIRAnalysis>(F); 388 389 return DivergenceInfo(F, DT, PDT, LI, TTI, /* KnownReducible = */ false); 390 } 391 392 PreservedAnalyses 393 DivergenceAnalysisPrinterPass::run(Function &F, FunctionAnalysisManager &FAM) { 394 auto &DI = FAM.getResult<DivergenceAnalysis>(F); 395 OS << "'Divergence Analysis' for function '" << F.getName() << "':\n"; 396 if (DI.hasDivergence()) { 397 for (auto &Arg : F.args()) { 398 OS << (DI.isDivergent(Arg) ? "DIVERGENT: " : " "); 399 OS << Arg << "\n"; 400 } 401 for (const BasicBlock &BB : F) { 402 OS << "\n " << BB.getName() << ":\n"; 403 for (auto &I : BB.instructionsWithoutDebug()) { 404 OS << (DI.isDivergent(I) ? "DIVERGENT: " : " "); 405 OS << I << "\n"; 406 } 407 } 408 } 409 return PreservedAnalyses::all(); 410 } 411