//===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a general divergence analysis for loop vectorization
// and GPU programs. It determines which branches and values in a loop or GPU
// program are divergent. It can help branch optimizations such as jump
// threading and loop unswitching to make better decisions.
//
// GPU programs typically use the SIMD execution model, where multiple threads
// in the same execution group have to execute in lock-step. Therefore, if the
// code contains divergent branches (i.e., threads in a group do not agree on
// which path of the branch to take), the group of threads has to execute all
// the paths from that branch with different subsets of threads enabled until
// they re-converge.
//
// Due to this execution model, some optimizations such as jump
// threading and loop unswitching can interfere with thread re-convergence.
// Therefore, an analysis that computes which branches in a GPU program are
// divergent can help the compiler to selectively run these optimizations.
//
// This implementation is derived from the Vectorization Analysis of the
// Region Vectorizer (RV). That implementation in turn is based on the approach
// described in
//
//   Improving Performance of OpenCL on CPUs
//   Ralf Karrenberg and Sebastian Hack
//   CC '12
//
// This DivergenceAnalysis implementation is generic in the sense that it does
// not itself identify original sources of divergence.
// Instead, specialized adapter classes (LoopDivergenceAnalysis for loops and
// GPUDivergenceAnalysis for GPU programs) identify the sources of divergence
// (e.g., special variables that hold the thread ID or the iteration variable).
//
// The generic implementation propagates divergence to variables that are data
// or sync dependent on a source of divergence.
//
// While data dependency is a well-known concept, the notion of sync dependency
// is worth more explanation. Sync dependence characterizes the control flow
// aspect of the propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
// The sync dependence detection (which branch induces divergence in which join
// points) is implemented in the SyncDependenceAnalysis.
//
// The current DivergenceAnalysis implementation has the following limitations:
// 1. intra-procedural. It conservatively considers the arguments of a
//    non-kernel-entry function and the return value of a function call as
//    divergent.
// 2. memory as black box. It conservatively considers values loaded from
//    generic or local addresses as divergent. This can be improved by
//    leveraging pointer analysis and/or by modelling non-escaping memory
//    objects in SSA as done in RV.
72 // 73 //===----------------------------------------------------------------------===// 74 75 #include "llvm/Analysis/DivergenceAnalysis.h" 76 #include "llvm/Analysis/LoopInfo.h" 77 #include "llvm/Analysis/Passes.h" 78 #include "llvm/Analysis/PostDominators.h" 79 #include "llvm/Analysis/TargetTransformInfo.h" 80 #include "llvm/IR/Dominators.h" 81 #include "llvm/IR/InstIterator.h" 82 #include "llvm/IR/Instructions.h" 83 #include "llvm/IR/IntrinsicInst.h" 84 #include "llvm/IR/Value.h" 85 #include "llvm/Support/Debug.h" 86 #include "llvm/Support/raw_ostream.h" 87 #include <vector> 88 89 using namespace llvm; 90 91 #define DEBUG_TYPE "divergence-analysis" 92 93 // class DivergenceAnalysis 94 DivergenceAnalysis::DivergenceAnalysis( 95 const Function &F, const Loop *RegionLoop, const DominatorTree &DT, 96 const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm) 97 : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA), 98 IsLCSSAForm(IsLCSSAForm) {} 99 100 bool DivergenceAnalysis::markDivergent(const Value &DivVal) { 101 if (isAlwaysUniform(DivVal)) 102 return false; 103 assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal)); 104 assert(!isAlwaysUniform(DivVal) && "cannot be a divergent"); 105 return DivergentValues.insert(&DivVal).second; 106 } 107 108 void DivergenceAnalysis::addUniformOverride(const Value &UniVal) { 109 UniformOverrides.insert(&UniVal); 110 } 111 112 bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock, 113 const Value &Val) const { 114 const auto *Inst = dyn_cast<const Instruction>(&Val); 115 if (!Inst) 116 return false; 117 // check whether any divergent loop carrying Val terminates before control 118 // proceeds to ObservingBlock 119 for (const auto *Loop = LI.getLoopFor(Inst->getParent()); 120 Loop != RegionLoop && !Loop->contains(&ObservingBlock); 121 Loop = Loop->getParentLoop()) { 122 if (DivergentLoops.find(Loop) != DivergentLoops.end()) 123 return true; 124 } 125 126 return false; 127 } 128 129 bool 
DivergenceAnalysis::inRegion(const Instruction &I) const { 130 return I.getParent() && inRegion(*I.getParent()); 131 } 132 133 bool DivergenceAnalysis::inRegion(const BasicBlock &BB) const { 134 return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB); 135 } 136 137 void DivergenceAnalysis::pushUsers(const Value &V) { 138 const auto *I = dyn_cast<const Instruction>(&V); 139 140 if (I && I->isTerminator()) { 141 analyzeControlDivergence(*I); 142 return; 143 } 144 145 for (const auto *User : V.users()) { 146 const auto *UserInst = dyn_cast<const Instruction>(User); 147 if (!UserInst) 148 continue; 149 150 // only compute divergent inside loop 151 if (!inRegion(*UserInst)) 152 continue; 153 154 // All users of divergent values are immediate divergent 155 if (markDivergent(*UserInst)) 156 Worklist.push_back(UserInst); 157 } 158 } 159 160 static const Instruction *getIfCarriedInstruction(const Use &U, 161 const Loop &DivLoop) { 162 const auto *I = dyn_cast<const Instruction>(&U); 163 if (!I) 164 return nullptr; 165 if (!DivLoop.contains(I)) 166 return nullptr; 167 return I; 168 } 169 170 void DivergenceAnalysis::analyzeTemporalDivergence(const Instruction &I, 171 const Loop &OuterDivLoop) { 172 if (isAlwaysUniform(I)) 173 return; 174 if (isDivergent(I)) 175 return; 176 177 LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n"); 178 assert((isa<PHINode>(I) || !IsLCSSAForm) && 179 "In LCSSA form all users of loop-exiting defs are Phi nodes."); 180 for (const Use &Op : I.operands()) { 181 const auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop); 182 if (!OpInst) 183 continue; 184 if (markDivergent(I)) 185 pushUsers(I); 186 return; 187 } 188 } 189 190 // marks all users of loop-carried values of the loop headed by LoopHeader as 191 // divergent 192 void DivergenceAnalysis::analyzeLoopExitDivergence(const BasicBlock &DivExit, 193 const Loop &OuterDivLoop) { 194 // All users are in immediate exit blocks 195 if (IsLCSSAForm) { 196 
for (const auto &Phi : DivExit.phis()) { 197 analyzeTemporalDivergence(Phi, OuterDivLoop); 198 } 199 return; 200 } 201 202 // For non-LCSSA we have to follow all live out edges wherever they may lead. 203 const BasicBlock &LoopHeader = *OuterDivLoop.getHeader(); 204 SmallVector<const BasicBlock *, 8> TaintStack; 205 TaintStack.push_back(&DivExit); 206 207 // Otherwise potential users of loop-carried values could be anywhere in the 208 // dominance region of DivLoop (including its fringes for phi nodes) 209 DenseSet<const BasicBlock *> Visited; 210 Visited.insert(&DivExit); 211 212 do { 213 auto *UserBlock = TaintStack.back(); 214 TaintStack.pop_back(); 215 216 // don't spread divergence beyond the region 217 if (!inRegion(*UserBlock)) 218 continue; 219 220 assert(!OuterDivLoop.contains(UserBlock) && 221 "irreducible control flow detected"); 222 223 // phi nodes at the fringes of the dominance region 224 if (!DT.dominates(&LoopHeader, UserBlock)) { 225 // all PHI nodes of UserBlock become divergent 226 for (auto &Phi : UserBlock->phis()) { 227 analyzeTemporalDivergence(Phi, OuterDivLoop); 228 } 229 continue; 230 } 231 232 // Taint outside users of values carried by OuterDivLoop. 233 for (auto &I : *UserBlock) { 234 analyzeTemporalDivergence(I, OuterDivLoop); 235 } 236 237 // visit all blocks in the dominance region 238 for (auto *SuccBlock : successors(UserBlock)) { 239 if (!Visited.insert(SuccBlock).second) { 240 continue; 241 } 242 TaintStack.push_back(SuccBlock); 243 } 244 } while (!TaintStack.empty()); 245 } 246 247 void DivergenceAnalysis::propagateLoopExitDivergence(const BasicBlock &DivExit, 248 const Loop &InnerDivLoop) { 249 LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n"); 250 251 // Find outer-most loop that does not contain \p DivExit 252 const Loop *DivLoop = &InnerDivLoop; 253 const Loop *OuterDivLoop = DivLoop; 254 const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit); 255 const unsigned LoopExitDepth = 256 ExitLevelLoop ? 
ExitLevelLoop->getLoopDepth() : 0; 257 while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) { 258 DivergentLoops.insert(DivLoop); // all crossed loops are divergent 259 OuterDivLoop = DivLoop; 260 DivLoop = DivLoop->getParentLoop(); 261 } 262 LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName() 263 << "\n"); 264 265 analyzeLoopExitDivergence(DivExit, *OuterDivLoop); 266 } 267 268 // this is a divergent join point - mark all phi nodes as divergent and push 269 // them onto the stack. 270 void DivergenceAnalysis::taintAndPushPhiNodes(const BasicBlock &JoinBlock) { 271 LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName() 272 << "\n"); 273 274 // ignore divergence outside the region 275 if (!inRegion(JoinBlock)) { 276 return; 277 } 278 279 // push non-divergent phi nodes in JoinBlock to the worklist 280 for (const auto &Phi : JoinBlock.phis()) { 281 if (isDivergent(Phi)) 282 continue; 283 // FIXME Theoretically ,the 'undef' value could be replaced by any other 284 // value causing spurious divergence. 285 if (Phi.hasConstantOrUndefValue()) 286 continue; 287 if (markDivergent(Phi)) 288 Worklist.push_back(&Phi); 289 } 290 } 291 292 void DivergenceAnalysis::analyzeControlDivergence(const Instruction &Term) { 293 LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName() 294 << "\n"); 295 296 // Don't propagate divergence from unreachable blocks. 
297 if (!DT.isReachableFromEntry(Term.getParent())) 298 return; 299 300 const auto *BranchLoop = LI.getLoopFor(Term.getParent()); 301 302 const auto &DivDesc = SDA.getJoinBlocks(Term); 303 304 // Iterate over all blocks now reachable by a disjoint path join 305 for (const auto *JoinBlock : DivDesc.JoinDivBlocks) { 306 taintAndPushPhiNodes(*JoinBlock); 307 } 308 309 assert(DivDesc.LoopDivBlocks.empty() || BranchLoop); 310 for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) { 311 propagateLoopExitDivergence(*DivExitBlock, *BranchLoop); 312 } 313 } 314 315 void DivergenceAnalysis::compute() { 316 // Initialize worklist. 317 auto DivValuesCopy = DivergentValues; 318 for (const auto *DivVal : DivValuesCopy) { 319 assert(isDivergent(*DivVal) && "Worklist invariant violated!"); 320 pushUsers(*DivVal); 321 } 322 323 // All values on the Worklist are divergent. 324 // Their users may not have been updated yed. 325 while (!Worklist.empty()) { 326 const Instruction &I = *Worklist.back(); 327 Worklist.pop_back(); 328 329 // propagate value divergence to users 330 assert(isDivergent(I) && "Worklist invariant violated!"); 331 pushUsers(I); 332 } 333 } 334 335 bool DivergenceAnalysis::isAlwaysUniform(const Value &V) const { 336 return UniformOverrides.find(&V) != UniformOverrides.end(); 337 } 338 339 bool DivergenceAnalysis::isDivergent(const Value &V) const { 340 return DivergentValues.find(&V) != DivergentValues.end(); 341 } 342 343 bool DivergenceAnalysis::isDivergentUse(const Use &U) const { 344 Value &V = *U.get(); 345 Instruction &I = *cast<Instruction>(U.getUser()); 346 return isDivergent(V) || isTemporalDivergent(*I.getParent(), V); 347 } 348 349 void DivergenceAnalysis::print(raw_ostream &OS, const Module *) const { 350 if (DivergentValues.empty()) 351 return; 352 // iterate instructions using instructions() to ensure a deterministic order. 
353 for (auto &I : instructions(F)) { 354 if (isDivergent(I)) 355 OS << "DIVERGENT:" << I << '\n'; 356 } 357 } 358 359 // class GPUDivergenceAnalysis 360 GPUDivergenceAnalysis::GPUDivergenceAnalysis(Function &F, 361 const DominatorTree &DT, 362 const PostDominatorTree &PDT, 363 const LoopInfo &LI, 364 const TargetTransformInfo &TTI) 365 : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, /* LCSSA */ false) { 366 for (auto &I : instructions(F)) { 367 if (TTI.isSourceOfDivergence(&I)) { 368 DA.markDivergent(I); 369 } else if (TTI.isAlwaysUniform(&I)) { 370 DA.addUniformOverride(I); 371 } 372 } 373 for (auto &Arg : F.args()) { 374 if (TTI.isSourceOfDivergence(&Arg)) { 375 DA.markDivergent(Arg); 376 } 377 } 378 379 DA.compute(); 380 } 381 382 bool GPUDivergenceAnalysis::isDivergent(const Value &val) const { 383 return DA.isDivergent(val); 384 } 385 386 bool GPUDivergenceAnalysis::isDivergentUse(const Use &use) const { 387 return DA.isDivergentUse(use); 388 } 389 390 void GPUDivergenceAnalysis::print(raw_ostream &OS, const Module *mod) const { 391 OS << "Divergence of kernel " << DA.getFunction().getName() << " {\n"; 392 DA.print(OS, mod); 393 OS << "}\n"; 394 } 395