//==- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation -==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an algorithm that returns for a divergent branch
// the set of basic blocks whose phi nodes become divergent due to divergent
// control. These are the blocks that are reachable by two disjoint paths from
// the branch or loop exits that have a reaching path that is disjoint from a
// path to the loop latch.
//
// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
// control-induced divergence in phi nodes.
//
// -- Summary --
// The SyncDependenceAnalysis lazily computes sync dependences [3].
// The analysis evaluates the disjoint path criterion [2] by a reduction
// to SSA construction. The SSA construction algorithm is implemented as
// a simple data-flow analysis [1].
//
// [1] "A Simple, Fast Dominance Algorithm", SPI '01, Cooper, Harvey and Kennedy
// [2] "Efficiently Computing Static Single Assignment Form
//     and the Control Dependence Graph", TOPLAS '91,
//     Cytron, Ferrante, Rosen, Wegman and Zadeck
// [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack
// [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira
//
// -- Sync dependence --
// Sync dependence [4] characterizes the control flow aspect of the
// propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
// -- Reduction to SSA construction --
// There are two disjoint paths from A to X, if a certain variant of SSA
// construction places a phi node in X under the following set-up scheme [2].
//
// This variant of SSA construction ignores incoming undef values.
// That is, paths from the entry without a definition do not result in
// phi nodes.
//
//               entry
//             /      \
//            A        \
//          /   \       Y
//         B     C     /
//          \   /  \  /
//            D     E
//             \   /
//               F
// Assume that A contains a divergent branch. We are interested
// in the set of all blocks where each block is reachable from A
// via two disjoint paths. This would be the set {D, F} in this
// case.
// To generally reduce this query to SSA construction we introduce
// a virtual variable x and assign to x different values in each
// successor block of A.
//               entry
//             /      \
//            A        \
//          /   \       Y
//       x = 0   x = 1 /
//          \   /  \  /
//            D     E
//             \   /
//               F
// Our flavor of SSA construction for x will construct the following
//               entry
//             /      \
//            A        \
//          /   \       Y
//     x0 = 0   x1 = 1 /
//          \   /  \  /
//         x2=phi   E
//             \   /
//            x3=phi
// The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
//
// -- Remarks --
// In case of loop exits we need to check the disjoint path criterion for loops
// [2]. To this end, we check whether the definition of x differs between the
// loop exit and the loop header (_after_ SSA construction).
99 // 100 //===----------------------------------------------------------------------===// 101 #include "llvm/Analysis/SyncDependenceAnalysis.h" 102 #include "llvm/ADT/PostOrderIterator.h" 103 #include "llvm/ADT/SmallPtrSet.h" 104 #include "llvm/Analysis/PostDominators.h" 105 #include "llvm/IR/BasicBlock.h" 106 #include "llvm/IR/CFG.h" 107 #include "llvm/IR/Dominators.h" 108 #include "llvm/IR/Function.h" 109 110 #include <stack> 111 #include <unordered_set> 112 113 #define DEBUG_TYPE "sync-dependence" 114 115 namespace llvm { 116 117 ConstBlockSet SyncDependenceAnalysis::EmptyBlockSet; 118 119 SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, 120 const PostDominatorTree &PDT, 121 const LoopInfo &LI) 122 : FuncRPOT(DT.getRoot()->getParent()), DT(DT), PDT(PDT), LI(LI) {} 123 124 SyncDependenceAnalysis::~SyncDependenceAnalysis() {} 125 126 using FunctionRPOT = ReversePostOrderTraversal<const Function *>; 127 128 // divergence propagator for reducible CFGs 129 struct DivergencePropagator { 130 const FunctionRPOT &FuncRPOT; 131 const DominatorTree &DT; 132 const PostDominatorTree &PDT; 133 const LoopInfo &LI; 134 135 // identified join points 136 std::unique_ptr<ConstBlockSet> JoinBlocks; 137 138 // reached loop exits (by a path disjoint to a path to the loop header) 139 SmallPtrSet<const BasicBlock *, 4> ReachedLoopExits; 140 141 // if DefMap[B] == C then C is the dominating definition at block B 142 // if DefMap[B] ~ undef then we haven't seen B yet 143 // if DefMap[B] == B then B is a join point of disjoint paths from X or B is 144 // an immediate successor of X (initial value). 
145 using DefiningBlockMap = std::map<const BasicBlock *, const BasicBlock *>; 146 DefiningBlockMap DefMap; 147 148 // all blocks with pending visits 149 std::unordered_set<const BasicBlock *> PendingUpdates; 150 151 DivergencePropagator(const FunctionRPOT &FuncRPOT, const DominatorTree &DT, 152 const PostDominatorTree &PDT, const LoopInfo &LI) 153 : FuncRPOT(FuncRPOT), DT(DT), PDT(PDT), LI(LI), 154 JoinBlocks(new ConstBlockSet) {} 155 156 // set the definition at @block and mark @block as pending for a visit 157 void addPending(const BasicBlock &Block, const BasicBlock &DefBlock) { 158 bool WasAdded = DefMap.emplace(&Block, &DefBlock).second; 159 if (WasAdded) 160 PendingUpdates.insert(&Block); 161 } 162 163 void printDefs(raw_ostream &Out) { 164 Out << "Propagator::DefMap {\n"; 165 for (const auto *Block : FuncRPOT) { 166 auto It = DefMap.find(Block); 167 Out << Block->getName() << " : "; 168 if (It == DefMap.end()) { 169 Out << "\n"; 170 } else { 171 const auto *DefBlock = It->second; 172 Out << (DefBlock ? DefBlock->getName() : "<null>") << "\n"; 173 } 174 } 175 Out << "}\n"; 176 } 177 178 // process @succBlock with reaching definition @defBlock 179 // the original divergent branch was in @parentLoop (if any) 180 void visitSuccessor(const BasicBlock &SuccBlock, const Loop *ParentLoop, 181 const BasicBlock &DefBlock) { 182 183 // @succBlock is a loop exit 184 if (ParentLoop && !ParentLoop->contains(&SuccBlock)) { 185 DefMap.emplace(&SuccBlock, &DefBlock); 186 ReachedLoopExits.insert(&SuccBlock); 187 return; 188 } 189 190 // first reaching def? 191 auto ItLastDef = DefMap.find(&SuccBlock); 192 if (ItLastDef == DefMap.end()) { 193 addPending(SuccBlock, DefBlock); 194 return; 195 } 196 197 // a join of at least two definitions 198 if (ItLastDef->second != &DefBlock) { 199 // do we know this join already? 
200 if (!JoinBlocks->insert(&SuccBlock).second) 201 return; 202 203 // update the definition 204 addPending(SuccBlock, SuccBlock); 205 } 206 } 207 208 // find all blocks reachable by two disjoint paths from @rootTerm. 209 // This method works for both divergent terminators and loops with 210 // divergent exits. 211 // @rootBlock is either the block containing the branch or the header of the 212 // divergent loop. 213 // @nodeSuccessors is the set of successors of the node (Loop or Terminator) 214 // headed by @rootBlock. 215 // @parentLoop is the parent loop of the Loop or the loop that contains the 216 // Terminator. 217 template <typename SuccessorIterable> 218 std::unique_ptr<ConstBlockSet> 219 computeJoinPoints(const BasicBlock &RootBlock, 220 SuccessorIterable NodeSuccessors, const Loop *ParentLoop) { 221 assert(JoinBlocks); 222 223 LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints. Parent loop: " 224 << (ParentLoop ? ParentLoop->getName() : "<null>") 225 << "\n"); 226 227 // bootstrap with branch targets 228 for (const auto *SuccBlock : NodeSuccessors) { 229 DefMap.emplace(SuccBlock, SuccBlock); 230 231 if (ParentLoop && !ParentLoop->contains(SuccBlock)) { 232 // immediate loop exit from node. 
233 ReachedLoopExits.insert(SuccBlock); 234 } else { 235 // regular successor 236 PendingUpdates.insert(SuccBlock); 237 } 238 } 239 240 LLVM_DEBUG(dbgs() << "SDA: rpo order:\n"; for (const auto *RpoBlock 241 : FuncRPOT) { 242 dbgs() << "- " << RpoBlock->getName() << "\n"; 243 }); 244 245 auto ItBeginRPO = FuncRPOT.begin(); 246 auto ItEndRPO = FuncRPOT.end(); 247 248 // skip until term (TODO RPOT won't let us start at @term directly) 249 for (; *ItBeginRPO != &RootBlock; ++ItBeginRPO) { 250 assert(ItBeginRPO != ItEndRPO && "Unable to find RootBlock"); 251 } 252 253 // propagate definitions at the immediate successors of the node in RPO 254 auto ItBlockRPO = ItBeginRPO; 255 while ((++ItBlockRPO != ItEndRPO) && !PendingUpdates.empty()) { 256 const auto *Block = *ItBlockRPO; 257 LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n"); 258 259 // skip Block if not pending update 260 auto ItPending = PendingUpdates.find(Block); 261 if (ItPending == PendingUpdates.end()) 262 continue; 263 PendingUpdates.erase(ItPending); 264 265 // propagate definition at Block to its successors 266 auto ItDef = DefMap.find(Block); 267 const auto *DefBlock = ItDef->second; 268 assert(DefBlock); 269 270 auto *BlockLoop = LI.getLoopFor(Block); 271 if (ParentLoop && 272 (ParentLoop != BlockLoop && ParentLoop->contains(BlockLoop))) { 273 // if the successor is the header of a nested loop pretend its a 274 // single node with the loop's exits as successors 275 SmallVector<BasicBlock *, 4> BlockLoopExits; 276 BlockLoop->getExitBlocks(BlockLoopExits); 277 for (const auto *BlockLoopExit : BlockLoopExits) { 278 visitSuccessor(*BlockLoopExit, ParentLoop, *DefBlock); 279 } 280 281 } else { 282 // the successors are either on the same loop level or loop exits 283 for (const auto *SuccBlock : successors(Block)) { 284 visitSuccessor(*SuccBlock, ParentLoop, *DefBlock); 285 } 286 } 287 } 288 289 LLVM_DEBUG(dbgs() << "SDA::joins. 
After propagation:\n"; printDefs(dbgs())); 290 291 // We need to know the definition at the parent loop header to decide 292 // whether the definition at the header is different from the definition at 293 // the loop exits, which would indicate a divergent loop exits. 294 // 295 // A // loop header 296 // | 297 // B // nested loop header 298 // | 299 // C -> X (exit from B loop) -..-> (A latch) 300 // | 301 // D -> back to B (B latch) 302 // | 303 // proper exit from both loops 304 // 305 // analyze reached loop exits 306 if (!ReachedLoopExits.empty()) { 307 const BasicBlock *ParentLoopHeader = 308 ParentLoop ? ParentLoop->getHeader() : nullptr; 309 310 assert(ParentLoop); 311 auto ItHeaderDef = DefMap.find(ParentLoopHeader); 312 const auto *HeaderDefBlock = 313 (ItHeaderDef == DefMap.end()) ? nullptr : ItHeaderDef->second; 314 315 LLVM_DEBUG(printDefs(dbgs())); 316 assert(HeaderDefBlock && "no definition at header of carrying loop"); 317 318 for (const auto *ExitBlock : ReachedLoopExits) { 319 auto ItExitDef = DefMap.find(ExitBlock); 320 assert((ItExitDef != DefMap.end()) && 321 "no reaching def at reachable loop exit"); 322 if (ItExitDef->second != HeaderDefBlock) { 323 JoinBlocks->insert(ExitBlock); 324 } 325 } 326 } 327 328 return std::move(JoinBlocks); 329 } 330 }; 331 332 const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) { 333 using LoopExitVec = SmallVector<BasicBlock *, 4>; 334 LoopExitVec LoopExits; 335 Loop.getExitBlocks(LoopExits); 336 if (LoopExits.size() < 1) { 337 return EmptyBlockSet; 338 } 339 340 // already available in cache? 
341 auto ItCached = CachedLoopExitJoins.find(&Loop); 342 if (ItCached != CachedLoopExitJoins.end()) { 343 return *ItCached->second; 344 } 345 346 // compute all join points 347 DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; 348 auto JoinBlocks = Propagator.computeJoinPoints<const LoopExitVec &>( 349 *Loop.getHeader(), LoopExits, Loop.getParentLoop()); 350 351 auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks)); 352 assert(ItInserted.second); 353 return *ItInserted.first->second; 354 } 355 356 const ConstBlockSet & 357 SyncDependenceAnalysis::join_blocks(const Instruction &Term) { 358 // trivial case 359 if (Term.getNumSuccessors() < 1) { 360 return EmptyBlockSet; 361 } 362 363 // already available in cache? 364 auto ItCached = CachedBranchJoins.find(&Term); 365 if (ItCached != CachedBranchJoins.end()) 366 return *ItCached->second; 367 368 // compute all join points 369 DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; 370 const auto &TermBlock = *Term.getParent(); 371 auto JoinBlocks = Propagator.computeJoinPoints<const_succ_range>( 372 TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock)); 373 374 auto ItInserted = CachedBranchJoins.emplace(&Term, std::move(JoinBlocks)); 375 assert(ItInserted.second); 376 return *ItInserted.first->second; 377 } 378 379 } // namespace llvm 380