//===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a general divergence analysis for loop vectorization
// and GPU programs. It determines which branches and values in a loop or GPU
// program are divergent. It can help branch optimizations such as jump
// threading and loop unswitching to make better decisions.
//
// GPU programs typically use the SIMD execution model, where multiple threads
// in the same execution group have to execute in lock-step. Therefore, if the
// code contains divergent branches (i.e., threads in a group do not agree on
// which path of the branch to take), the group of threads has to execute all
// the paths from that branch with different subsets of threads enabled until
// they re-converge.
//
// Due to this execution model, some optimizations such as jump
// threading and loop unswitching can interfere with thread re-convergence.
// Therefore, an analysis that computes which branches in a GPU program are
// divergent can help the compiler to selectively run these optimizations.
//
// This implementation is derived from the Vectorization Analysis of the
// Region Vectorizer (RV). That implementation in turn is based on the approach
// described in
//
//   Improving Performance of OpenCL on CPUs
//   Ralf Karrenberg and Sebastian Hack
//   CC '12
//
// This DivergenceAnalysis implementation is generic in the sense that it does
// not itself identify original sources of divergence.
// Instead, specialized adapter classes (LoopDivergenceAnalysis for loops and
// GPUDivergenceAnalysis for GPU programs) identify the sources of divergence
// (e.g., special variables that hold the thread ID or the iteration variable).
//
// The generic implementation propagates divergence to variables that are data
// or sync dependent on a source of divergence.
//
// While data dependency is a well-known concept, the notion of sync dependency
// is worth more explanation. Sync dependence characterizes the control flow
// aspect of the propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
// The sync dependence detection (which branch induces divergence in which join
// points) is implemented in the SyncDependenceAnalysis.
//
// The current DivergenceAnalysis implementation has the following limitations:
// 1. intra-procedural. It conservatively considers the arguments of a
//    non-kernel-entry function and the return value of a function call as
//    divergent.
// 2. memory as black box. It conservatively considers values loaded from
//    generic or local addresses as divergent. This can be improved by
//    leveraging pointer analysis and/or by modelling non-escaping memory
//    objects in SSA as done in RV.
72 // 73 //===----------------------------------------------------------------------===// 74 75 #include "llvm/Analysis/DivergenceAnalysis.h" 76 #include "llvm/Analysis/LoopInfo.h" 77 #include "llvm/Analysis/Passes.h" 78 #include "llvm/Analysis/PostDominators.h" 79 #include "llvm/Analysis/TargetTransformInfo.h" 80 #include "llvm/IR/Dominators.h" 81 #include "llvm/IR/InstIterator.h" 82 #include "llvm/IR/Instructions.h" 83 #include "llvm/IR/IntrinsicInst.h" 84 #include "llvm/IR/Value.h" 85 #include "llvm/Support/Debug.h" 86 #include "llvm/Support/raw_ostream.h" 87 88 using namespace llvm; 89 90 #define DEBUG_TYPE "divergence-analysis" 91 92 // class DivergenceAnalysis 93 DivergenceAnalysis::DivergenceAnalysis( 94 const Function &F, const Loop *RegionLoop, const DominatorTree &DT, 95 const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm) 96 : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA), 97 IsLCSSAForm(IsLCSSAForm) {} 98 99 bool DivergenceAnalysis::markDivergent(const Value &DivVal) { 100 if (isAlwaysUniform(DivVal)) 101 return false; 102 assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal)); 103 assert(!isAlwaysUniform(DivVal) && "cannot be a divergent"); 104 return DivergentValues.insert(&DivVal).second; 105 } 106 107 void DivergenceAnalysis::addUniformOverride(const Value &UniVal) { 108 UniformOverrides.insert(&UniVal); 109 } 110 111 bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock, 112 const Value &Val) const { 113 const auto *Inst = dyn_cast<const Instruction>(&Val); 114 if (!Inst) 115 return false; 116 // check whether any divergent loop carrying Val terminates before control 117 // proceeds to ObservingBlock 118 for (const auto *Loop = LI.getLoopFor(Inst->getParent()); 119 Loop != RegionLoop && !Loop->contains(&ObservingBlock); 120 Loop = Loop->getParentLoop()) { 121 if (DivergentLoops.find(Loop) != DivergentLoops.end()) 122 return true; 123 } 124 125 return false; 126 } 127 128 bool 
DivergenceAnalysis::inRegion(const Instruction &I) const { 129 return I.getParent() && inRegion(*I.getParent()); 130 } 131 132 bool DivergenceAnalysis::inRegion(const BasicBlock &BB) const { 133 return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB); 134 } 135 136 void DivergenceAnalysis::pushUsers(const Value &V) { 137 const auto *I = dyn_cast<const Instruction>(&V); 138 139 if (I && I->isTerminator()) { 140 analyzeControlDivergence(*I); 141 return; 142 } 143 144 for (const auto *User : V.users()) { 145 const auto *UserInst = dyn_cast<const Instruction>(User); 146 if (!UserInst) 147 continue; 148 149 // only compute divergent inside loop 150 if (!inRegion(*UserInst)) 151 continue; 152 153 // All users of divergent values are immediate divergent 154 if (markDivergent(*UserInst)) 155 Worklist.push_back(UserInst); 156 } 157 } 158 159 static const Instruction *getIfCarriedInstruction(const Use &U, 160 const Loop &DivLoop) { 161 const auto *I = dyn_cast<const Instruction>(&U); 162 if (!I) 163 return nullptr; 164 if (!DivLoop.contains(I)) 165 return nullptr; 166 return I; 167 } 168 169 void DivergenceAnalysis::analyzeTemporalDivergence(const Instruction &I, 170 const Loop &OuterDivLoop) { 171 if (isAlwaysUniform(I)) 172 return; 173 if (isDivergent(I)) 174 return; 175 176 LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n"); 177 assert((isa<PHINode>(I) || !IsLCSSAForm) && 178 "In LCSSA form all users of loop-exiting defs are Phi nodes."); 179 for (const Use &Op : I.operands()) { 180 const auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop); 181 if (!OpInst) 182 continue; 183 if (markDivergent(I)) 184 pushUsers(I); 185 return; 186 } 187 } 188 189 // marks all users of loop-carried values of the loop headed by LoopHeader as 190 // divergent 191 void DivergenceAnalysis::analyzeLoopExitDivergence(const BasicBlock &DivExit, 192 const Loop &OuterDivLoop) { 193 // All users are in immediate exit blocks 194 if (IsLCSSAForm) { 195 
for (const auto &Phi : DivExit.phis()) { 196 analyzeTemporalDivergence(Phi, OuterDivLoop); 197 } 198 return; 199 } 200 201 // For non-LCSSA we have to follow all live out edges wherever they may lead. 202 const BasicBlock &LoopHeader = *OuterDivLoop.getHeader(); 203 SmallVector<const BasicBlock *, 8> TaintStack; 204 TaintStack.push_back(&DivExit); 205 206 // Otherwise potential users of loop-carried values could be anywhere in the 207 // dominance region of DivLoop (including its fringes for phi nodes) 208 DenseSet<const BasicBlock *> Visited; 209 Visited.insert(&DivExit); 210 211 do { 212 auto *UserBlock = TaintStack.back(); 213 TaintStack.pop_back(); 214 215 // don't spread divergence beyond the region 216 if (!inRegion(*UserBlock)) 217 continue; 218 219 assert(!OuterDivLoop.contains(UserBlock) && 220 "irreducible control flow detected"); 221 222 // phi nodes at the fringes of the dominance region 223 if (!DT.dominates(&LoopHeader, UserBlock)) { 224 // all PHI nodes of UserBlock become divergent 225 for (auto &Phi : UserBlock->phis()) { 226 analyzeTemporalDivergence(Phi, OuterDivLoop); 227 } 228 continue; 229 } 230 231 // Taint outside users of values carried by OuterDivLoop. 232 for (auto &I : *UserBlock) { 233 analyzeTemporalDivergence(I, OuterDivLoop); 234 } 235 236 // visit all blocks in the dominance region 237 for (auto *SuccBlock : successors(UserBlock)) { 238 if (!Visited.insert(SuccBlock).second) { 239 continue; 240 } 241 TaintStack.push_back(SuccBlock); 242 } 243 } while (!TaintStack.empty()); 244 } 245 246 void DivergenceAnalysis::propagateLoopExitDivergence(const BasicBlock &DivExit, 247 const Loop &InnerDivLoop) { 248 LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n"); 249 250 // Find outer-most loop that does not contain \p DivExit 251 const Loop *DivLoop = &InnerDivLoop; 252 const Loop *OuterDivLoop = DivLoop; 253 const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit); 254 const unsigned LoopExitDepth = 255 ExitLevelLoop ? 
ExitLevelLoop->getLoopDepth() : 0; 256 while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) { 257 DivergentLoops.insert(DivLoop); // all crossed loops are divergent 258 OuterDivLoop = DivLoop; 259 DivLoop = DivLoop->getParentLoop(); 260 } 261 LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName() 262 << "\n"); 263 264 analyzeLoopExitDivergence(DivExit, *OuterDivLoop); 265 } 266 267 // this is a divergent join point - mark all phi nodes as divergent and push 268 // them onto the stack. 269 void DivergenceAnalysis::taintAndPushPhiNodes(const BasicBlock &JoinBlock) { 270 LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName() 271 << "\n"); 272 273 // ignore divergence outside the region 274 if (!inRegion(JoinBlock)) { 275 return; 276 } 277 278 // push non-divergent phi nodes in JoinBlock to the worklist 279 for (const auto &Phi : JoinBlock.phis()) { 280 if (isDivergent(Phi)) 281 continue; 282 // FIXME Theoretically ,the 'undef' value could be replaced by any other 283 // value causing spurious divergence. 284 if (Phi.hasConstantOrUndefValue()) 285 continue; 286 if (markDivergent(Phi)) 287 Worklist.push_back(&Phi); 288 } 289 } 290 291 void DivergenceAnalysis::analyzeControlDivergence(const Instruction &Term) { 292 LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName() 293 << "\n"); 294 295 // Don't propagate divergence from unreachable blocks. 
296 if (!DT.isReachableFromEntry(Term.getParent())) 297 return; 298 299 const auto *BranchLoop = LI.getLoopFor(Term.getParent()); 300 301 const auto &DivDesc = SDA.getJoinBlocks(Term); 302 303 // Iterate over all blocks now reachable by a disjoint path join 304 for (const auto *JoinBlock : DivDesc.JoinDivBlocks) { 305 taintAndPushPhiNodes(*JoinBlock); 306 } 307 308 assert(DivDesc.LoopDivBlocks.empty() || BranchLoop); 309 for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) { 310 propagateLoopExitDivergence(*DivExitBlock, *BranchLoop); 311 } 312 } 313 314 void DivergenceAnalysis::compute() { 315 // Initialize worklist. 316 auto DivValuesCopy = DivergentValues; 317 for (const auto *DivVal : DivValuesCopy) { 318 assert(isDivergent(*DivVal) && "Worklist invariant violated!"); 319 pushUsers(*DivVal); 320 } 321 322 // All values on the Worklist are divergent. 323 // Their users may not have been updated yed. 324 while (!Worklist.empty()) { 325 const Instruction &I = *Worklist.back(); 326 Worklist.pop_back(); 327 328 // propagate value divergence to users 329 assert(isDivergent(I) && "Worklist invariant violated!"); 330 pushUsers(I); 331 } 332 } 333 334 bool DivergenceAnalysis::isAlwaysUniform(const Value &V) const { 335 return UniformOverrides.find(&V) != UniformOverrides.end(); 336 } 337 338 bool DivergenceAnalysis::isDivergent(const Value &V) const { 339 return DivergentValues.find(&V) != DivergentValues.end(); 340 } 341 342 bool DivergenceAnalysis::isDivergentUse(const Use &U) const { 343 Value &V = *U.get(); 344 Instruction &I = *cast<Instruction>(U.getUser()); 345 return isDivergent(V) || isTemporalDivergent(*I.getParent(), V); 346 } 347 348 void DivergenceAnalysis::print(raw_ostream &OS, const Module *) const { 349 if (DivergentValues.empty()) 350 return; 351 // iterate instructions using instructions() to ensure a deterministic order. 
352 for (auto &I : instructions(F)) { 353 if (isDivergent(I)) 354 OS << "DIVERGENT:" << I << '\n'; 355 } 356 } 357 358 // class GPUDivergenceAnalysis 359 GPUDivergenceAnalysis::GPUDivergenceAnalysis(Function &F, 360 const DominatorTree &DT, 361 const PostDominatorTree &PDT, 362 const LoopInfo &LI, 363 const TargetTransformInfo &TTI) 364 : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, /* LCSSA */ false) { 365 for (auto &I : instructions(F)) { 366 if (TTI.isSourceOfDivergence(&I)) { 367 DA.markDivergent(I); 368 } else if (TTI.isAlwaysUniform(&I)) { 369 DA.addUniformOverride(I); 370 } 371 } 372 for (auto &Arg : F.args()) { 373 if (TTI.isSourceOfDivergence(&Arg)) { 374 DA.markDivergent(Arg); 375 } 376 } 377 378 DA.compute(); 379 } 380 381 bool GPUDivergenceAnalysis::isDivergent(const Value &val) const { 382 return DA.isDivergent(val); 383 } 384 385 bool GPUDivergenceAnalysis::isDivergentUse(const Use &use) const { 386 return DA.isDivergentUse(use); 387 } 388 389 void GPUDivergenceAnalysis::print(raw_ostream &OS, const Module *mod) const { 390 OS << "Divergence of kernel " << DA.getFunction().getName() << " {\n"; 391 DA.print(OS, mod); 392 OS << "}\n"; 393 } 394