//===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an algorithm that, for a divergent branch, returns
// the set of basic blocks whose phi nodes become divergent due to divergent
// control. These are the blocks that are reachable by two disjoint paths from
// the branch, and the loop exits that have a reaching path that is disjoint
// from a path to the loop latch.
//
// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
// control-induced divergence in phi nodes.
//
//
// -- Reference --
// The algorithm is presented in Section 5 of
//
//   An abstract interpretation for SPMD divergence
//       on reducible control flow graphs.
//   Julian Rosemann, Simon Moll and Sebastian Hack
//   POPL '21
//
//
// -- Sync dependence --
// Sync dependence characterizes the control flow aspect of the
// propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
//
// -- Reduction to SSA construction --
// There are two disjoint paths from A to X if a certain variant of SSA
// construction places a phi node in X under the following set-up scheme.
//
// This variant of SSA construction ignores incoming undef values.
// That is, paths from the entry without a definition do not result in
// phi nodes.
//
//       entry
//     /      \
//    A        \
//  /   \       Y
// B     C     /
//  \   /  \  /
//    D     E
//     \   /
//       F
//
// Assume that A contains a divergent branch. We are interested
// in the set of all blocks where each block is reachable from A
// via two disjoint paths. This would be the set {D, F} in this
// case.
// To generally reduce this query to SSA construction we introduce
// a virtual variable x and assign to x different values in each
// successor block of A.
//
//           entry
//         /      \
//        A        \
//      /   \       Y
// x = 0   x = 1   /
//      \  /   \  /
//        D     E
//         \   /
//           F
//
// Our flavor of SSA construction for x will construct the following
//
//            entry
//          /      \
//         A        \
//       /   \       Y
// x0 = 0   x1 = 1  /
//       \   /   \ /
//     x2 = phi   E
//         \     /
//         x3 = phi
//
// The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
//
// -- Remarks --
// * In case of loop exits we need to check the disjoint path criterion for
//   loops. To this end, we check whether the definition of x differs between
//   the loop exit and the loop header (_after_ SSA construction).
//
// -- Known Limitations & Future Work --
// * The algorithm requires reducible loops because the implementation
//   implicitly performs a single iteration of the underlying data flow analysis.
//   This was done for pragmatism, simplicity and speed.
//
//   Relevant related work for extending the algorithm to irreducible control:
//     A simple algorithm for global data flow analysis problems.
//     Matthew S. Hecht and Jeffrey D. Ullman.
//     SIAM Journal on Computing, 4(4):519–532, December 1975.
112 // 113 // * Another reason for requiring reducible loops is that points of 114 // synchronization in irreducible loops aren't 'obvious' - there is no unique 115 // header where threads 'should' synchronize when entering or coming back 116 // around from the latch. 117 // 118 //===----------------------------------------------------------------------===// 119 #include "llvm/Analysis/SyncDependenceAnalysis.h" 120 #include "llvm/ADT/PostOrderIterator.h" 121 #include "llvm/ADT/SmallPtrSet.h" 122 #include "llvm/Analysis/PostDominators.h" 123 #include "llvm/IR/BasicBlock.h" 124 #include "llvm/IR/CFG.h" 125 #include "llvm/IR/Dominators.h" 126 #include "llvm/IR/Function.h" 127 128 #include <functional> 129 #include <stack> 130 #include <unordered_set> 131 132 #define DEBUG_TYPE "sync-dependence" 133 134 // The SDA algorithm operates on a modified CFG - we modify the edges leaving 135 // loop headers as follows: 136 // 137 // * We remove all edges leaving all loop headers. 138 // * We add additional edges from the loop headers to their exit blocks. 139 // 140 // The modification is virtual, that is whenever we visit a loop header we 141 // pretend it had different successors. 142 namespace { 143 using namespace llvm; 144 145 // Custom Post-Order Traveral 146 // 147 // We cannot use the vanilla (R)PO computation of LLVM because: 148 // * We (virtually) modify the CFG. 
149 // * We want a loop-compact block enumeration, that is the numbers assigned to 150 // blocks of a loop form an interval 151 // 152 using POCB = std::function<void(const BasicBlock &)>; 153 using VisitedSet = std::set<const BasicBlock *>; 154 using BlockStack = std::vector<const BasicBlock *>; 155 156 // forward 157 static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, 158 VisitedSet &Finalized); 159 160 // for a nested region (top-level loop or nested loop) 161 static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop, 162 POCB CallBack, VisitedSet &Finalized) { 163 const auto *LoopHeader = Loop ? Loop->getHeader() : nullptr; 164 while (!Stack.empty()) { 165 const auto *NextBB = Stack.back(); 166 167 auto *NestedLoop = LI.getLoopFor(NextBB); 168 bool IsNestedLoop = NestedLoop != Loop; 169 170 // Treat the loop as a node 171 if (IsNestedLoop) { 172 SmallVector<BasicBlock *, 3> NestedExits; 173 NestedLoop->getUniqueExitBlocks(NestedExits); 174 bool PushedNodes = false; 175 for (const auto *NestedExitBB : NestedExits) { 176 if (NestedExitBB == LoopHeader) 177 continue; 178 if (Loop && !Loop->contains(NestedExitBB)) 179 continue; 180 if (Finalized.count(NestedExitBB)) 181 continue; 182 PushedNodes = true; 183 Stack.push_back(NestedExitBB); 184 } 185 if (!PushedNodes) { 186 // All loop exits finalized -> finish this node 187 Stack.pop_back(); 188 computeLoopPO(LI, *NestedLoop, CallBack, Finalized); 189 } 190 continue; 191 } 192 193 // DAG-style 194 bool PushedNodes = false; 195 for (const auto *SuccBB : successors(NextBB)) { 196 if (SuccBB == LoopHeader) 197 continue; 198 if (Loop && !Loop->contains(SuccBB)) 199 continue; 200 if (Finalized.count(SuccBB)) 201 continue; 202 PushedNodes = true; 203 Stack.push_back(SuccBB); 204 } 205 if (!PushedNodes) { 206 // Never push nodes twice 207 Stack.pop_back(); 208 if (!Finalized.insert(NextBB).second) 209 continue; 210 CallBack(*NextBB); 211 } 212 } 213 } 214 215 static void 
computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) { 216 VisitedSet Finalized; 217 BlockStack Stack; 218 Stack.reserve(24); // FIXME made-up number 219 Stack.push_back(&F.getEntryBlock()); 220 computeStackPO(Stack, LI, nullptr, CallBack, Finalized); 221 } 222 223 static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, 224 VisitedSet &Finalized) { 225 /// Call CallBack on all loop blocks. 226 std::vector<const BasicBlock *> Stack; 227 const auto *LoopHeader = Loop.getHeader(); 228 229 // Visit the header last 230 Finalized.insert(LoopHeader); 231 CallBack(*LoopHeader); 232 233 // Initialize with immediate successors 234 for (const auto *BB : successors(LoopHeader)) { 235 if (!Loop.contains(BB)) 236 continue; 237 if (BB == LoopHeader) 238 continue; 239 Stack.push_back(BB); 240 } 241 242 // Compute PO inside region 243 computeStackPO(Stack, LI, &Loop, CallBack, Finalized); 244 } 245 246 } // namespace 247 248 namespace llvm { 249 250 ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc; 251 252 SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, 253 const PostDominatorTree &PDT, 254 const LoopInfo &LI) 255 : DT(DT), PDT(PDT), LI(LI) { 256 computeTopLevelPO(*DT.getRoot()->getParent(), LI, 257 [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); }); 258 } 259 260 SyncDependenceAnalysis::~SyncDependenceAnalysis() {} 261 262 // divergence propagator for reducible CFGs 263 struct DivergencePropagator { 264 const ModifiedPO &LoopPOT; 265 const DominatorTree &DT; 266 const PostDominatorTree &PDT; 267 const LoopInfo &LI; 268 const BasicBlock &DivTermBlock; 269 270 // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at 271 // block B 272 // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet 273 // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths 274 // from X or B is an immediate successor of X (initial value). 
275 using BlockLabelVec = std::vector<const BasicBlock *>; 276 BlockLabelVec BlockLabels; 277 // divergent join and loop exit descriptor. 278 std::unique_ptr<ControlDivergenceDesc> DivDesc; 279 280 DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT, 281 const PostDominatorTree &PDT, const LoopInfo &LI, 282 const BasicBlock &DivTermBlock) 283 : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock), 284 BlockLabels(LoopPOT.size(), nullptr), 285 DivDesc(new ControlDivergenceDesc) {} 286 287 void printDefs(raw_ostream &Out) { 288 Out << "Propagator::BlockLabels {\n"; 289 for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) { 290 const auto *Label = BlockLabels[BlockIdx]; 291 Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx 292 << ") : "; 293 if (!Label) { 294 Out << "<null>\n"; 295 } else { 296 Out << Label->getName() << "\n"; 297 } 298 } 299 Out << "}\n"; 300 } 301 302 // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this 303 // causes a divergent join. 304 bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) { 305 auto SuccIdx = LoopPOT.getIndexOf(SuccBlock); 306 307 // unset or same reaching label 308 const auto *OldLabel = BlockLabels[SuccIdx]; 309 if (!OldLabel || (OldLabel == &PushedLabel)) { 310 BlockLabels[SuccIdx] = &PushedLabel; 311 return false; 312 } 313 314 // Update the definition 315 BlockLabels[SuccIdx] = &SuccBlock; 316 return true; 317 } 318 319 // visiting a virtual loop exit edge from the loop header --> temporal 320 // divergence on join 321 bool visitLoopExitEdge(const BasicBlock &ExitBlock, 322 const BasicBlock &DefBlock, bool FromParentLoop) { 323 // Pushing from a non-parent loop cannot cause temporal divergence. 
324 if (!FromParentLoop) 325 return visitEdge(ExitBlock, DefBlock); 326 327 if (!computeJoin(ExitBlock, DefBlock)) 328 return false; 329 330 // Identified a divergent loop exit 331 DivDesc->LoopDivBlocks.insert(&ExitBlock); 332 LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName() 333 << "\n"); 334 return true; 335 } 336 337 // process \p SuccBlock with reaching definition \p DefBlock 338 bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) { 339 if (!computeJoin(SuccBlock, DefBlock)) 340 return false; 341 342 // Divergent, disjoint paths join. 343 DivDesc->JoinDivBlocks.insert(&SuccBlock); 344 LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName()); 345 return true; 346 } 347 348 std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() { 349 assert(DivDesc); 350 351 LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName() 352 << "\n"); 353 354 const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock); 355 356 // Early stopping criterion 357 int FloorIdx = LoopPOT.size() - 1; 358 const BasicBlock *FloorLabel = nullptr; 359 360 // bootstrap with branch targets 361 int BlockIdx = 0; 362 363 for (const auto *SuccBlock : successors(&DivTermBlock)) { 364 auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock); 365 BlockLabels[SuccIdx] = SuccBlock; 366 367 // Find the successor with the highest index to start with 368 BlockIdx = std::max<int>(BlockIdx, SuccIdx); 369 FloorIdx = std::min<int>(FloorIdx, SuccIdx); 370 371 // Identify immediate divergent loop exits 372 if (!DivBlockLoop) 373 continue; 374 375 const auto *BlockLoop = LI.getLoopFor(SuccBlock); 376 if (BlockLoop && DivBlockLoop->contains(BlockLoop)) 377 continue; 378 DivDesc->LoopDivBlocks.insert(SuccBlock); 379 LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: " 380 << SuccBlock->getName() << "\n"); 381 } 382 383 // propagate definitions at the immediate successors of the node in RPO 384 for (; BlockIdx >= FloorIdx; --BlockIdx) { 385 LLVM_DEBUG(dbgs() << 
"Before next visit:\n"; printDefs(dbgs())); 386 387 // Any label available here 388 const auto *Label = BlockLabels[BlockIdx]; 389 if (!Label) 390 continue; 391 392 // Ok. Get the block 393 const auto *Block = LoopPOT.getBlockAt(BlockIdx); 394 LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n"); 395 396 auto *BlockLoop = LI.getLoopFor(Block); 397 bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block; 398 bool CausedJoin = false; 399 int LoweredFloorIdx = FloorIdx; 400 if (IsLoopHeader) { 401 // Disconnect from immediate successors and propagate directly to loop 402 // exits. 403 SmallVector<BasicBlock *, 4> BlockLoopExits; 404 BlockLoop->getExitBlocks(BlockLoopExits); 405 406 bool IsParentLoop = BlockLoop->contains(&DivTermBlock); 407 for (const auto *BlockLoopExit : BlockLoopExits) { 408 CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop); 409 LoweredFloorIdx = std::min<int>(LoweredFloorIdx, 410 LoopPOT.getIndexOf(*BlockLoopExit)); 411 } 412 } else { 413 // Acyclic successor case 414 for (const auto *SuccBlock : successors(Block)) { 415 CausedJoin |= visitEdge(*SuccBlock, *Label); 416 LoweredFloorIdx = 417 std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock)); 418 } 419 } 420 421 // Floor update 422 if (CausedJoin) { 423 // 1. Different labels pushed to successors 424 FloorIdx = LoweredFloorIdx; 425 } else if (FloorLabel != Label) { 426 // 2. No join caused BUT we pushed a label that is different than the 427 // last pushed label 428 FloorIdx = LoweredFloorIdx; 429 FloorLabel = Label; 430 } 431 } 432 433 LLVM_DEBUG(dbgs() << "SDA::joins. 
After propagation:\n"; printDefs(dbgs())); 434 435 return std::move(DivDesc); 436 } 437 }; 438 439 #ifndef NDEBUG 440 static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) { 441 Out << "["; 442 ListSeparator LS; 443 for (const auto *BB : Blocks) 444 Out << LS << BB->getName(); 445 Out << "]"; 446 } 447 #endif 448 449 const ControlDivergenceDesc & 450 SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) { 451 // trivial case 452 if (Term.getNumSuccessors() <= 1) { 453 return EmptyDivergenceDesc; 454 } 455 456 // already available in cache? 457 auto ItCached = CachedControlDivDescs.find(&Term); 458 if (ItCached != CachedControlDivDescs.end()) 459 return *ItCached->second; 460 461 // compute all join points 462 // Special handling of divergent loop exits is not needed for LCSSA 463 const auto &TermBlock = *Term.getParent(); 464 DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock); 465 auto DivDesc = Propagator.computeJoinPoints(); 466 467 LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n"; 468 dbgs() << "JoinDivBlocks: "; 469 printBlockSet(DivDesc->JoinDivBlocks, dbgs()); 470 dbgs() << "\nLoopDivBlocks: "; 471 printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";); 472 473 auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc)); 474 assert(ItInserted.second); 475 return *ItInserted.first->second; 476 } 477 478 } // namespace llvm 479