//===- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation
//--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an algorithm that returns, for a divergent branch,
// the set of basic blocks whose phi nodes become divergent due to divergent
// control. These are the blocks that are reachable by two disjoint paths from
// the branch, as well as the loop exits that have a reaching path that is
// disjoint from a path to the loop latch.
//
// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
// control-induced divergence in phi nodes.
//
// -- Summary --
// The SyncDependenceAnalysis lazily computes sync dependences [3].
// The analysis evaluates the disjoint path criterion [2] by a reduction
// to SSA construction. The SSA construction algorithm is implemented as
// a simple data-flow analysis [1].
//
// [1] "A Simple, Fast Dominance Algorithm", SP&E '01,
//     Cooper, Harvey and Kennedy
// [2] "Efficiently Computing Static Single Assignment Form
//     and the Control Dependence Graph", TOPLAS '91,
//     Cytron, Ferrante, Rosen, Wegman and Zadeck
// [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack
// [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira
//
// -- Sync dependence --
// Sync dependence [4] characterizes the control flow aspect of the
// propagation of branch divergence. For example,
//
//     %cond = icmp slt i32 %tid, 10
//     br i1 %cond, label %then, label %else
//   then:
//     br label %merge
//   else:
//     br label %merge
//   merge:
//     %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID.
// Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
// -- Reduction to SSA construction --
// There are two disjoint paths from A to X if a certain variant of SSA
// construction places a phi node in X under the following set-up scheme [2].
//
// This variant of SSA construction ignores incoming undef values.
// That is, paths from the entry without a definition do not result in
// phi nodes.
//
//         entry
//        /     |
//       A      |
//      / \     Y
//     B   C   /
//      \ / \ /
//       D   E
//        \ /
//         F
//
// Assume that A contains a divergent branch. We are interested
// in the set of all blocks where each block is reachable from A
// via two disjoint paths. This would be the set {D, F} in this
// case.
// To generally reduce this query to SSA construction we introduce
// a virtual variable x and assign to x different values in each
// successor block of A:
//
//         entry
//        /     |
//       A      |
//      / \     Y
//     B   C   /        B: x = 0
//      \ / \ /         C: x = 1
//       D   E
//        \ /
//         F
//
// Our flavor of SSA construction for x will construct the following:
//
//         entry
//        /     |
//       A      |
//      / \     Y
//     B   C   /        B: x0 = 0
//      \ / \ /         C: x1 = 1
//       D   E          D: x2 = phi(x0, x1)
//        \ /           E: no phi (the path from Y carries no definition)
//         F            F: x3 = phi(x2, x1)
//
// The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
//
// -- Remarks --
// In case of loop exits we need to check the disjoint path criterion for loops
// [2]. To this end, we check whether the definition of x differs between the
// loop exit and the loop header (_after_ SSA construction).
100 // 101 //===----------------------------------------------------------------------===// 102 #include "llvm/ADT/PostOrderIterator.h" 103 #include "llvm/ADT/SmallPtrSet.h" 104 #include "llvm/Analysis/PostDominators.h" 105 #include "llvm/Analysis/SyncDependenceAnalysis.h" 106 #include "llvm/IR/BasicBlock.h" 107 #include "llvm/IR/CFG.h" 108 #include "llvm/IR/Dominators.h" 109 #include "llvm/IR/Function.h" 110 111 #include <stack> 112 #include <unordered_set> 113 114 #define DEBUG_TYPE "sync-dependence" 115 116 namespace llvm { 117 118 ConstBlockSet SyncDependenceAnalysis::EmptyBlockSet; 119 120 SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, 121 const PostDominatorTree &PDT, 122 const LoopInfo &LI) 123 : FuncRPOT(DT.getRoot()->getParent()), DT(DT), PDT(PDT), LI(LI) {} 124 125 SyncDependenceAnalysis::~SyncDependenceAnalysis() {} 126 127 using FunctionRPOT = ReversePostOrderTraversal<const Function *>; 128 129 // divergence propagator for reducible CFGs 130 struct DivergencePropagator { 131 const FunctionRPOT &FuncRPOT; 132 const DominatorTree &DT; 133 const PostDominatorTree &PDT; 134 const LoopInfo &LI; 135 136 // identified join points 137 std::unique_ptr<ConstBlockSet> JoinBlocks; 138 139 // reached loop exits (by a path disjoint to a path to the loop header) 140 SmallPtrSet<const BasicBlock *, 4> ReachedLoopExits; 141 142 // if DefMap[B] == C then C is the dominating definition at block B 143 // if DefMap[B] ~ undef then we haven't seen B yet 144 // if DefMap[B] == B then B is a join point of disjoint paths from X or B is 145 // an immediate successor of X (initial value). 
146 using DefiningBlockMap = std::map<const BasicBlock *, const BasicBlock *>; 147 DefiningBlockMap DefMap; 148 149 // all blocks with pending visits 150 std::unordered_set<const BasicBlock *> PendingUpdates; 151 152 DivergencePropagator(const FunctionRPOT &FuncRPOT, const DominatorTree &DT, 153 const PostDominatorTree &PDT, const LoopInfo &LI) 154 : FuncRPOT(FuncRPOT), DT(DT), PDT(PDT), LI(LI), 155 JoinBlocks(new ConstBlockSet) {} 156 157 // set the definition at @block and mark @block as pending for a visit 158 void addPending(const BasicBlock &Block, const BasicBlock &DefBlock) { 159 bool WasAdded = DefMap.emplace(&Block, &DefBlock).second; 160 if (WasAdded) 161 PendingUpdates.insert(&Block); 162 } 163 164 void printDefs(raw_ostream &Out) { 165 Out << "Propagator::DefMap {\n"; 166 for (const auto *Block : FuncRPOT) { 167 auto It = DefMap.find(Block); 168 Out << Block->getName() << " : "; 169 if (It == DefMap.end()) { 170 Out << "\n"; 171 } else { 172 const auto *DefBlock = It->second; 173 Out << (DefBlock ? DefBlock->getName() : "<null>") << "\n"; 174 } 175 } 176 Out << "}\n"; 177 } 178 179 // process @succBlock with reaching definition @defBlock 180 // the original divergent branch was in @parentLoop (if any) 181 void visitSuccessor(const BasicBlock &SuccBlock, const Loop *ParentLoop, 182 const BasicBlock &DefBlock) { 183 184 // @succBlock is a loop exit 185 if (ParentLoop && !ParentLoop->contains(&SuccBlock)) { 186 DefMap.emplace(&SuccBlock, &DefBlock); 187 ReachedLoopExits.insert(&SuccBlock); 188 return; 189 } 190 191 // first reaching def? 192 auto ItLastDef = DefMap.find(&SuccBlock); 193 if (ItLastDef == DefMap.end()) { 194 addPending(SuccBlock, DefBlock); 195 return; 196 } 197 198 // a join of at least two definitions 199 if (ItLastDef->second != &DefBlock) { 200 // do we know this join already? 
201 if (!JoinBlocks->insert(&SuccBlock).second) 202 return; 203 204 // update the definition 205 addPending(SuccBlock, SuccBlock); 206 } 207 } 208 209 // find all blocks reachable by two disjoint paths from @rootTerm. 210 // This method works for both divergent terminators and loops with 211 // divergent exits. 212 // @rootBlock is either the block containing the branch or the header of the 213 // divergent loop. 214 // @nodeSuccessors is the set of successors of the node (Loop or Terminator) 215 // headed by @rootBlock. 216 // @parentLoop is the parent loop of the Loop or the loop that contains the 217 // Terminator. 218 template <typename SuccessorIterable> 219 std::unique_ptr<ConstBlockSet> 220 computeJoinPoints(const BasicBlock &RootBlock, 221 SuccessorIterable NodeSuccessors, const Loop *ParentLoop, const BasicBlock * PdBoundBlock) { 222 assert(JoinBlocks); 223 224 // bootstrap with branch targets 225 for (const auto *SuccBlock : NodeSuccessors) { 226 DefMap.emplace(SuccBlock, SuccBlock); 227 228 if (ParentLoop && !ParentLoop->contains(SuccBlock)) { 229 // immediate loop exit from node. 
230 ReachedLoopExits.insert(SuccBlock); 231 continue; 232 } else { 233 // regular successor 234 PendingUpdates.insert(SuccBlock); 235 } 236 } 237 238 auto ItBeginRPO = FuncRPOT.begin(); 239 240 // skip until term (TODO RPOT won't let us start at @term directly) 241 for (; *ItBeginRPO != &RootBlock; ++ItBeginRPO) {} 242 243 auto ItEndRPO = FuncRPOT.end(); 244 assert(ItBeginRPO != ItEndRPO); 245 246 // propagate definitions at the immediate successors of the node in RPO 247 auto ItBlockRPO = ItBeginRPO; 248 while (++ItBlockRPO != ItEndRPO && *ItBlockRPO != PdBoundBlock) { 249 const auto *Block = *ItBlockRPO; 250 251 // skip @block if not pending update 252 auto ItPending = PendingUpdates.find(Block); 253 if (ItPending == PendingUpdates.end()) 254 continue; 255 PendingUpdates.erase(ItPending); 256 257 // propagate definition at @block to its successors 258 auto ItDef = DefMap.find(Block); 259 const auto *DefBlock = ItDef->second; 260 assert(DefBlock); 261 262 auto *BlockLoop = LI.getLoopFor(Block); 263 if (ParentLoop && 264 (ParentLoop != BlockLoop && ParentLoop->contains(BlockLoop))) { 265 // if the successor is the header of a nested loop pretend its a 266 // single node with the loop's exits as successors 267 SmallVector<BasicBlock *, 4> BlockLoopExits; 268 BlockLoop->getExitBlocks(BlockLoopExits); 269 for (const auto *BlockLoopExit : BlockLoopExits) { 270 visitSuccessor(*BlockLoopExit, ParentLoop, *DefBlock); 271 } 272 273 } else { 274 // the successors are either on the same loop level or loop exits 275 for (const auto *SuccBlock : successors(Block)) { 276 visitSuccessor(*SuccBlock, ParentLoop, *DefBlock); 277 } 278 } 279 } 280 281 // We need to know the definition at the parent loop header to decide 282 // whether the definition at the header is different from the definition at 283 // the loop exits, which would indicate a divergent loop exits. 
284 // 285 // A // loop header 286 // | 287 // B // nested loop header 288 // | 289 // C -> X (exit from B loop) -..-> (A latch) 290 // | 291 // D -> back to B (B latch) 292 // | 293 // proper exit from both loops 294 // 295 // D post-dominates B as it is the only proper exit from the "A loop". 296 // If C has a divergent branch, propagation will therefore stop at D. 297 // That implies that B will never receive a definition. 298 // But that definition can only be the same as at D (D itself in thise case) 299 // because all paths to anywhere have to pass through D. 300 // 301 const BasicBlock *ParentLoopHeader = 302 ParentLoop ? ParentLoop->getHeader() : nullptr; 303 if (ParentLoop && ParentLoop->contains(PdBoundBlock)) { 304 DefMap[ParentLoopHeader] = DefMap[PdBoundBlock]; 305 } 306 307 // analyze reached loop exits 308 if (!ReachedLoopExits.empty()) { 309 assert(ParentLoop); 310 const auto *HeaderDefBlock = DefMap[ParentLoopHeader]; 311 LLVM_DEBUG(printDefs(dbgs())); 312 assert(HeaderDefBlock && "no definition in header of carrying loop"); 313 314 for (const auto *ExitBlock : ReachedLoopExits) { 315 auto ItExitDef = DefMap.find(ExitBlock); 316 assert((ItExitDef != DefMap.end()) && 317 "no reaching def at reachable loop exit"); 318 if (ItExitDef->second != HeaderDefBlock) { 319 JoinBlocks->insert(ExitBlock); 320 } 321 } 322 } 323 324 return std::move(JoinBlocks); 325 } 326 }; 327 328 const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) { 329 using LoopExitVec = SmallVector<BasicBlock *, 4>; 330 LoopExitVec LoopExits; 331 Loop.getExitBlocks(LoopExits); 332 if (LoopExits.size() < 1) { 333 return EmptyBlockSet; 334 } 335 336 // already available in cache? 
337 auto ItCached = CachedLoopExitJoins.find(&Loop); 338 if (ItCached != CachedLoopExitJoins.end()) { 339 return *ItCached->second; 340 } 341 342 // dont propagte beyond the immediate post dom of the loop 343 const auto *PdNode = PDT.getNode(const_cast<BasicBlock *>(Loop.getHeader())); 344 const auto *IpdNode = PdNode->getIDom(); 345 const auto *PdBoundBlock = IpdNode ? IpdNode->getBlock() : nullptr; 346 while (PdBoundBlock && Loop.contains(PdBoundBlock)) { 347 IpdNode = IpdNode->getIDom(); 348 PdBoundBlock = IpdNode ? IpdNode->getBlock() : nullptr; 349 } 350 351 // compute all join points 352 DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; 353 auto JoinBlocks = Propagator.computeJoinPoints<const LoopExitVec &>( 354 *Loop.getHeader(), LoopExits, Loop.getParentLoop(), PdBoundBlock); 355 356 auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks)); 357 assert(ItInserted.second); 358 return *ItInserted.first->second; 359 } 360 361 const ConstBlockSet & 362 SyncDependenceAnalysis::join_blocks(const Instruction &Term) { 363 // trivial case 364 if (Term.getNumSuccessors() < 1) { 365 return EmptyBlockSet; 366 } 367 368 // already available in cache? 369 auto ItCached = CachedBranchJoins.find(&Term); 370 if (ItCached != CachedBranchJoins.end()) 371 return *ItCached->second; 372 373 // dont propagate beyond the immediate post dominator of the branch 374 const auto *PdNode = PDT.getNode(const_cast<BasicBlock *>(Term.getParent())); 375 const auto *IpdNode = PdNode->getIDom(); 376 const auto *PdBoundBlock = IpdNode ? 
IpdNode->getBlock() : nullptr; 377 378 // compute all join points 379 DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; 380 const auto &TermBlock = *Term.getParent(); 381 auto JoinBlocks = Propagator.computeJoinPoints<succ_const_range>( 382 TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock), PdBoundBlock); 383 384 auto ItInserted = CachedBranchJoins.emplace(&Term, std::move(JoinBlocks)); 385 assert(ItInserted.second); 386 return *ItInserted.first->second; 387 } 388 389 } // namespace llvm 390