//===- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation
//--===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file implements an algorithm that returns for a divergent branch
// the set of basic blocks whose phi nodes become divergent due to divergent
// control. These are the blocks that are reachable by two disjoint paths from
// the branch or loop exits that have a reaching path that is disjoint from a
// path to the loop latch.
//
// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
// control-induced divergence in phi nodes.
//
// -- Summary --
// The SyncDependenceAnalysis lazily computes sync dependences [3].
// The analysis evaluates the disjoint path criterion [2] by a reduction
// to SSA construction. The SSA construction algorithm is implemented as
// a simple data-flow analysis [1].
//
// [1] "A Simple, Fast Dominance Algorithm", SPI '01, Cooper, Harvey and Kennedy
// [2] "Efficiently Computing Static Single Assignment Form
//     and the Control Dependence Graph", TOPLAS '91,
//     Cytron, Ferrante, Rosen, Wegman and Zadeck
// [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack
// [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira
//
// -- Sync dependence --
// Sync dependence [4] characterizes the control flow aspect of the
// propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
// -- Reduction to SSA construction --
// There are two disjoint paths from A to X, if a certain variant of SSA
// construction places a phi node in X under the following set-up scheme [2].
//
// This variant of SSA construction ignores incoming undef values.
// That is, paths from the entry without a definition do not result in
// phi nodes.
//
//       entry
//     /      \
//    A        \
//  /   \       Y
// B     C     /
//  \   /  \  /
//    D     E
//     \   /
//       F
// Assume that A contains a divergent branch. We are interested
// in the set of all blocks where each block is reachable from A
// via two disjoint paths. This would be the set {D, F} in this
// case.
// To generally reduce this query to SSA construction we introduce
// a virtual variable x and assign to x different values in each
// successor block of A.
//           entry
//         /      \
//        A        \
//      /   \       Y
// x = 0   x = 1   /
//      \  /   \  /
//        D     E
//         \   /
//           F
// Our flavor of SSA construction for x will construct the following
//            entry
//          /      \
//         A        \
//       /   \       Y
// x0 = 0   x1 = 1  /
//       \   /   \ /
//      x2=phi    E
//         \     /
//          x3=phi
// The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
//
// -- Remarks --
// In case of loop exits we need to check the disjoint path criterion for loops
// [2]. To this end, we check whether the definition of x differs between the
// loop exit and the loop header (_after_ SSA construction).
101 // 102 //===----------------------------------------------------------------------===// 103 #include "llvm/ADT/PostOrderIterator.h" 104 #include "llvm/ADT/SmallPtrSet.h" 105 #include "llvm/Analysis/PostDominators.h" 106 #include "llvm/Analysis/SyncDependenceAnalysis.h" 107 #include "llvm/IR/BasicBlock.h" 108 #include "llvm/IR/CFG.h" 109 #include "llvm/IR/Dominators.h" 110 #include "llvm/IR/Function.h" 111 112 #include <stack> 113 #include <unordered_set> 114 115 #define DEBUG_TYPE "sync-dependence" 116 117 namespace llvm { 118 119 ConstBlockSet SyncDependenceAnalysis::EmptyBlockSet; 120 121 SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, 122 const PostDominatorTree &PDT, 123 const LoopInfo &LI) 124 : FuncRPOT(DT.getRoot()->getParent()), DT(DT), PDT(PDT), LI(LI) {} 125 126 SyncDependenceAnalysis::~SyncDependenceAnalysis() {} 127 128 using FunctionRPOT = ReversePostOrderTraversal<const Function *>; 129 130 // divergence propagator for reducible CFGs 131 struct DivergencePropagator { 132 const FunctionRPOT &FuncRPOT; 133 const DominatorTree &DT; 134 const PostDominatorTree &PDT; 135 const LoopInfo &LI; 136 137 // identified join points 138 std::unique_ptr<ConstBlockSet> JoinBlocks; 139 140 // reached loop exits (by a path disjoint to a path to the loop header) 141 SmallPtrSet<const BasicBlock *, 4> ReachedLoopExits; 142 143 // if DefMap[B] == C then C is the dominating definition at block B 144 // if DefMap[B] ~ undef then we haven't seen B yet 145 // if DefMap[B] == B then B is a join point of disjoint paths from X or B is 146 // an immediate successor of X (initial value). 
147 using DefiningBlockMap = std::map<const BasicBlock *, const BasicBlock *>; 148 DefiningBlockMap DefMap; 149 150 // all blocks with pending visits 151 std::unordered_set<const BasicBlock *> PendingUpdates; 152 153 DivergencePropagator(const FunctionRPOT &FuncRPOT, const DominatorTree &DT, 154 const PostDominatorTree &PDT, const LoopInfo &LI) 155 : FuncRPOT(FuncRPOT), DT(DT), PDT(PDT), LI(LI), 156 JoinBlocks(new ConstBlockSet) {} 157 158 // set the definition at @block and mark @block as pending for a visit 159 void addPending(const BasicBlock &Block, const BasicBlock &DefBlock) { 160 bool WasAdded = DefMap.emplace(&Block, &DefBlock).second; 161 if (WasAdded) 162 PendingUpdates.insert(&Block); 163 } 164 165 void printDefs(raw_ostream &Out) { 166 Out << "Propagator::DefMap {\n"; 167 for (const auto *Block : FuncRPOT) { 168 auto It = DefMap.find(Block); 169 Out << Block->getName() << " : "; 170 if (It == DefMap.end()) { 171 Out << "\n"; 172 } else { 173 const auto *DefBlock = It->second; 174 Out << (DefBlock ? DefBlock->getName() : "<null>") << "\n"; 175 } 176 } 177 Out << "}\n"; 178 } 179 180 // process @succBlock with reaching definition @defBlock 181 // the original divergent branch was in @parentLoop (if any) 182 void visitSuccessor(const BasicBlock &SuccBlock, const Loop *ParentLoop, 183 const BasicBlock &DefBlock) { 184 185 // @succBlock is a loop exit 186 if (ParentLoop && !ParentLoop->contains(&SuccBlock)) { 187 DefMap.emplace(&SuccBlock, &DefBlock); 188 ReachedLoopExits.insert(&SuccBlock); 189 return; 190 } 191 192 // first reaching def? 193 auto ItLastDef = DefMap.find(&SuccBlock); 194 if (ItLastDef == DefMap.end()) { 195 addPending(SuccBlock, DefBlock); 196 return; 197 } 198 199 // a join of at least two definitions 200 if (ItLastDef->second != &DefBlock) { 201 // do we know this join already? 
202 if (!JoinBlocks->insert(&SuccBlock).second) 203 return; 204 205 // update the definition 206 addPending(SuccBlock, SuccBlock); 207 } 208 } 209 210 // find all blocks reachable by two disjoint paths from @rootTerm. 211 // This method works for both divergent terminators and loops with 212 // divergent exits. 213 // @rootBlock is either the block containing the branch or the header of the 214 // divergent loop. 215 // @nodeSuccessors is the set of successors of the node (Loop or Terminator) 216 // headed by @rootBlock. 217 // @parentLoop is the parent loop of the Loop or the loop that contains the 218 // Terminator. 219 template <typename SuccessorIterable> 220 std::unique_ptr<ConstBlockSet> 221 computeJoinPoints(const BasicBlock &RootBlock, 222 SuccessorIterable NodeSuccessors, const Loop *ParentLoop) { 223 assert(JoinBlocks); 224 225 // immediate post dominator (no join block beyond that block) 226 const auto *PdNode = PDT.getNode(const_cast<BasicBlock *>(&RootBlock)); 227 const auto *IpdNode = PdNode->getIDom(); 228 const auto *PdBoundBlock = IpdNode ? IpdNode->getBlock() : nullptr; 229 230 // bootstrap with branch targets 231 for (const auto *SuccBlock : NodeSuccessors) { 232 DefMap.emplace(SuccBlock, SuccBlock); 233 234 if (ParentLoop && !ParentLoop->contains(SuccBlock)) { 235 // immediate loop exit from node. 
236 ReachedLoopExits.insert(SuccBlock); 237 continue; 238 } else { 239 // regular successor 240 PendingUpdates.insert(SuccBlock); 241 } 242 } 243 244 auto ItBeginRPO = FuncRPOT.begin(); 245 246 // skip until term (TODO RPOT won't let us start at @term directly) 247 for (; *ItBeginRPO != &RootBlock; ++ItBeginRPO) {} 248 249 auto ItEndRPO = FuncRPOT.end(); 250 assert(ItBeginRPO != ItEndRPO); 251 252 // propagate definitions at the immediate successors of the node in RPO 253 auto ItBlockRPO = ItBeginRPO; 254 while (++ItBlockRPO != ItEndRPO && *ItBlockRPO != PdBoundBlock) { 255 const auto *Block = *ItBlockRPO; 256 257 // skip @block if not pending update 258 auto ItPending = PendingUpdates.find(Block); 259 if (ItPending == PendingUpdates.end()) 260 continue; 261 PendingUpdates.erase(ItPending); 262 263 // propagate definition at @block to its successors 264 auto ItDef = DefMap.find(Block); 265 const auto *DefBlock = ItDef->second; 266 assert(DefBlock); 267 268 auto *BlockLoop = LI.getLoopFor(Block); 269 if (ParentLoop && 270 (ParentLoop != BlockLoop && ParentLoop->contains(BlockLoop))) { 271 // if the successor is the header of a nested loop pretend its a 272 // single node with the loop's exits as successors 273 SmallVector<BasicBlock *, 4> BlockLoopExits; 274 BlockLoop->getExitBlocks(BlockLoopExits); 275 for (const auto *BlockLoopExit : BlockLoopExits) { 276 visitSuccessor(*BlockLoopExit, ParentLoop, *DefBlock); 277 } 278 279 } else { 280 // the successors are either on the same loop level or loop exits 281 for (const auto *SuccBlock : successors(Block)) { 282 visitSuccessor(*SuccBlock, ParentLoop, *DefBlock); 283 } 284 } 285 } 286 287 // We need to know the definition at the parent loop header to decide 288 // whether the definition at the header is different from the definition at 289 // the loop exits, which would indicate a divergent loop exits. 
290 // 291 // A // loop header 292 // | 293 // B // nested loop header 294 // | 295 // C -> X (exit from B loop) -..-> (A latch) 296 // | 297 // D -> back to B (B latch) 298 // | 299 // proper exit from both loops 300 // 301 // D post-dominates B as it is the only proper exit from the "A loop". 302 // If C has a divergent branch, propagation will therefore stop at D. 303 // That implies that B will never receive a definition. 304 // But that definition can only be the same as at D (D itself in thise case) 305 // because all paths to anywhere have to pass through D. 306 // 307 const BasicBlock *ParentLoopHeader = 308 ParentLoop ? ParentLoop->getHeader() : nullptr; 309 if (ParentLoop && ParentLoop->contains(PdBoundBlock)) { 310 DefMap[ParentLoopHeader] = DefMap[PdBoundBlock]; 311 } 312 313 // analyze reached loop exits 314 if (!ReachedLoopExits.empty()) { 315 assert(ParentLoop); 316 const auto *HeaderDefBlock = DefMap[ParentLoopHeader]; 317 LLVM_DEBUG(printDefs(dbgs())); 318 assert(HeaderDefBlock && "no definition in header of carrying loop"); 319 320 for (const auto *ExitBlock : ReachedLoopExits) { 321 auto ItExitDef = DefMap.find(ExitBlock); 322 assert((ItExitDef != DefMap.end()) && 323 "no reaching def at reachable loop exit"); 324 if (ItExitDef->second != HeaderDefBlock) { 325 JoinBlocks->insert(ExitBlock); 326 } 327 } 328 } 329 330 return std::move(JoinBlocks); 331 } 332 }; 333 334 const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) { 335 using LoopExitVec = SmallVector<BasicBlock *, 4>; 336 LoopExitVec LoopExits; 337 Loop.getExitBlocks(LoopExits); 338 if (LoopExits.size() < 1) { 339 return EmptyBlockSet; 340 } 341 342 // already available in cache? 
343 auto ItCached = CachedLoopExitJoins.find(&Loop); 344 if (ItCached != CachedLoopExitJoins.end()) 345 return *ItCached->second; 346 347 // compute all join points 348 DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; 349 auto JoinBlocks = Propagator.computeJoinPoints<const LoopExitVec &>( 350 *Loop.getHeader(), LoopExits, Loop.getParentLoop()); 351 352 auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks)); 353 assert(ItInserted.second); 354 return *ItInserted.first->second; 355 } 356 357 const ConstBlockSet & 358 SyncDependenceAnalysis::join_blocks(const Instruction &Term) { 359 // trivial case 360 if (Term.getNumSuccessors() < 1) { 361 return EmptyBlockSet; 362 } 363 364 // already available in cache? 365 auto ItCached = CachedBranchJoins.find(&Term); 366 if (ItCached != CachedBranchJoins.end()) 367 return *ItCached->second; 368 369 // compute all join points 370 DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; 371 const auto &TermBlock = *Term.getParent(); 372 auto JoinBlocks = Propagator.computeJoinPoints<succ_const_range>( 373 TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock)); 374 375 auto ItInserted = CachedBranchJoins.emplace(&Term, std::move(JoinBlocks)); 376 assert(ItInserted.second); 377 return *ItInserted.first->second; 378 } 379 380 } // namespace llvm 381