//===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an algorithm that returns for a divergent branch
// the set of basic blocks whose phi nodes become divergent due to divergent
// control. These are the blocks that are reachable by two disjoint paths from
// the branch or loop exits that have a reaching path that is disjoint from a
// path to the loop latch.
//
// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
// control-induced divergence in phi nodes.
//
//
// -- Reference --
// The algorithm is presented in Section 5 of
//
//   An abstract interpretation for SPMD divergence
//       on reducible control flow graphs.
//   Julian Rosemann, Simon Moll and Sebastian Hack
//   POPL '21
//
//
// -- Sync dependence --
// Sync dependence characterizes the control flow aspect of the
// propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
//
// -- Reduction to SSA construction --
// There are two disjoint paths from A to X, if a certain variant of SSA
// construction places a phi node in X under the following set-up scheme.
//
// This variant of SSA construction ignores incoming undef values.
// That is, paths from the entry without a definition do not result in
// phi nodes.
//
//       entry
//     /      \
//    A        \
//  /   \       Y
// B     C     /
//  \   /  \  /
//    D     E
//     \   /
//       F
//
// Assume that A contains a divergent branch. We are interested
// in the set of all blocks where each block is reachable from A
// via two disjoint paths. This would be the set {D, F} in this
// case.
// To generally reduce this query to SSA construction we introduce
// a virtual variable x and assign to x different values in each
// successor block of A.
//
//           entry
//         /      \
//        A        \
//      /   \       Y
// x = 0   x = 1   /
//      \  /   \  /
//        D     E
//         \   /
//           F
//
// Our flavor of SSA construction for x will construct the following
//
//            entry
//          /      \
//         A        \
//       /   \       Y
// x0 = 0   x1 = 1  /
//       \   /   \ /
//     x2 = phi   E
//         \     /
//        x3 = phi
//
// The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
//
// -- Remarks --
// * In case of loop exits we need to check the disjoint path criterion for loops.
//   To this end, we check whether the definition of x differs between the
//   loop exit and the loop header (_after_ SSA construction).
//
// -- Known Limitations & Future Work --
// * The algorithm requires reducible loops because the implementation
//   implicitly performs a single iteration of the underlying data flow analysis.
//   This was done for pragmatism, simplicity and speed.
//
//   Relevant related work for extending the algorithm to irreducible control:
//     A simple algorithm for global data flow analysis problems.
//     Matthew S. Hecht and Jeffrey D. Ullman.
//     SIAM Journal on Computing, 4(4):519–532, December 1975.
//
// * Another reason for requiring reducible loops is that points of
//   synchronization in irreducible loops aren't 'obvious' - there is no unique
//   header where threads 'should' synchronize when entering or coming back
//   around from the latch.
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/SyncDependenceAnalysis.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"

#include <functional>
#include <memory>
#include <set>
#include <stack>
#include <unordered_set>
#include <vector>

#define DEBUG_TYPE "sync-dependence"

// The SDA algorithm operates on a modified CFG - we modify the edges leaving
// loop headers as follows:
//
// * We remove all edges leaving all loop headers.
// * We add additional edges from the loop headers to their exit blocks.
//
// The modification is virtual, that is whenever we visit a loop header we
// pretend it had different successors.
namespace {
using namespace llvm;

// Custom Post-Order Traversal
//
// We cannot use the vanilla (R)PO computation of LLVM because:
// * We (virtually) modify the CFG.
151 // * We want a loop-compact block enumeration, that is the numbers assigned to 152 // blocks of a loop form an interval 153 // 154 using POCB = std::function<void(const BasicBlock &)>; 155 using VisitedSet = std::set<const BasicBlock *>; 156 using BlockStack = std::vector<const BasicBlock *>; 157 158 // forward 159 static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, 160 VisitedSet &Finalized); 161 162 // for a nested region (top-level loop or nested loop) 163 static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop, 164 POCB CallBack, VisitedSet &Finalized) { 165 const auto *LoopHeader = Loop ? Loop->getHeader() : nullptr; 166 while (!Stack.empty()) { 167 const auto *NextBB = Stack.back(); 168 169 auto *NestedLoop = LI.getLoopFor(NextBB); 170 bool IsNestedLoop = NestedLoop != Loop; 171 172 // Treat the loop as a node 173 if (IsNestedLoop) { 174 SmallVector<BasicBlock *, 3> NestedExits; 175 NestedLoop->getUniqueExitBlocks(NestedExits); 176 bool PushedNodes = false; 177 for (const auto *NestedExitBB : NestedExits) { 178 if (NestedExitBB == LoopHeader) 179 continue; 180 if (Loop && !Loop->contains(NestedExitBB)) 181 continue; 182 if (Finalized.count(NestedExitBB)) 183 continue; 184 PushedNodes = true; 185 Stack.push_back(NestedExitBB); 186 } 187 if (!PushedNodes) { 188 // All loop exits finalized -> finish this node 189 Stack.pop_back(); 190 computeLoopPO(LI, *NestedLoop, CallBack, Finalized); 191 } 192 continue; 193 } 194 195 // DAG-style 196 bool PushedNodes = false; 197 for (const auto *SuccBB : successors(NextBB)) { 198 if (SuccBB == LoopHeader) 199 continue; 200 if (Loop && !Loop->contains(SuccBB)) 201 continue; 202 if (Finalized.count(SuccBB)) 203 continue; 204 PushedNodes = true; 205 Stack.push_back(SuccBB); 206 } 207 if (!PushedNodes) { 208 // Never push nodes twice 209 Stack.pop_back(); 210 if (!Finalized.insert(NextBB).second) 211 continue; 212 CallBack(*NextBB); 213 } 214 } 215 } 216 217 static void 
computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) { 218 VisitedSet Finalized; 219 BlockStack Stack; 220 Stack.reserve(24); // FIXME made-up number 221 Stack.push_back(&F.getEntryBlock()); 222 computeStackPO(Stack, LI, nullptr, CallBack, Finalized); 223 } 224 225 static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, 226 VisitedSet &Finalized) { 227 /// Call CallBack on all loop blocks. 228 std::vector<const BasicBlock *> Stack; 229 const auto *LoopHeader = Loop.getHeader(); 230 231 // Visit the header last 232 Finalized.insert(LoopHeader); 233 CallBack(*LoopHeader); 234 235 // Initialize with immediate successors 236 for (const auto *BB : successors(LoopHeader)) { 237 if (!Loop.contains(BB)) 238 continue; 239 if (BB == LoopHeader) 240 continue; 241 Stack.push_back(BB); 242 } 243 244 // Compute PO inside region 245 computeStackPO(Stack, LI, &Loop, CallBack, Finalized); 246 } 247 248 } // namespace 249 250 namespace llvm { 251 252 ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc; 253 254 SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, 255 const PostDominatorTree &PDT, 256 const LoopInfo &LI) 257 : DT(DT), PDT(PDT), LI(LI) { 258 computeTopLevelPO(*DT.getRoot()->getParent(), LI, 259 [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); }); 260 } 261 262 SyncDependenceAnalysis::~SyncDependenceAnalysis() = default; 263 264 // divergence propagator for reducible CFGs 265 struct DivergencePropagator { 266 const ModifiedPO &LoopPOT; 267 const DominatorTree &DT; 268 const PostDominatorTree &PDT; 269 const LoopInfo &LI; 270 const BasicBlock &DivTermBlock; 271 272 // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at 273 // block B 274 // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet 275 // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths 276 // from X or B is an immediate successor of X (initial value). 
277 using BlockLabelVec = std::vector<const BasicBlock *>; 278 BlockLabelVec BlockLabels; 279 // divergent join and loop exit descriptor. 280 std::unique_ptr<ControlDivergenceDesc> DivDesc; 281 282 DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT, 283 const PostDominatorTree &PDT, const LoopInfo &LI, 284 const BasicBlock &DivTermBlock) 285 : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock), 286 BlockLabels(LoopPOT.size(), nullptr), 287 DivDesc(new ControlDivergenceDesc) {} 288 289 void printDefs(raw_ostream &Out) { 290 Out << "Propagator::BlockLabels {\n"; 291 for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) { 292 const auto *Label = BlockLabels[BlockIdx]; 293 Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx 294 << ") : "; 295 if (!Label) { 296 Out << "<null>\n"; 297 } else { 298 Out << Label->getName() << "\n"; 299 } 300 } 301 Out << "}\n"; 302 } 303 304 // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this 305 // causes a divergent join. 306 bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) { 307 auto SuccIdx = LoopPOT.getIndexOf(SuccBlock); 308 309 // unset or same reaching label 310 const auto *OldLabel = BlockLabels[SuccIdx]; 311 if (!OldLabel || (OldLabel == &PushedLabel)) { 312 BlockLabels[SuccIdx] = &PushedLabel; 313 return false; 314 } 315 316 // Update the definition 317 BlockLabels[SuccIdx] = &SuccBlock; 318 return true; 319 } 320 321 // visiting a virtual loop exit edge from the loop header --> temporal 322 // divergence on join 323 bool visitLoopExitEdge(const BasicBlock &ExitBlock, 324 const BasicBlock &DefBlock, bool FromParentLoop) { 325 // Pushing from a non-parent loop cannot cause temporal divergence. 
326 if (!FromParentLoop) 327 return visitEdge(ExitBlock, DefBlock); 328 329 if (!computeJoin(ExitBlock, DefBlock)) 330 return false; 331 332 // Identified a divergent loop exit 333 DivDesc->LoopDivBlocks.insert(&ExitBlock); 334 LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName() 335 << "\n"); 336 return true; 337 } 338 339 // process \p SuccBlock with reaching definition \p DefBlock 340 bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) { 341 if (!computeJoin(SuccBlock, DefBlock)) 342 return false; 343 344 // Divergent, disjoint paths join. 345 DivDesc->JoinDivBlocks.insert(&SuccBlock); 346 LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName()); 347 return true; 348 } 349 350 std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() { 351 assert(DivDesc); 352 353 LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName() 354 << "\n"); 355 356 const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock); 357 358 // Early stopping criterion 359 int FloorIdx = LoopPOT.size() - 1; 360 const BasicBlock *FloorLabel = nullptr; 361 362 // bootstrap with branch targets 363 int BlockIdx = 0; 364 365 for (const auto *SuccBlock : successors(&DivTermBlock)) { 366 auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock); 367 BlockLabels[SuccIdx] = SuccBlock; 368 369 // Find the successor with the highest index to start with 370 BlockIdx = std::max<int>(BlockIdx, SuccIdx); 371 FloorIdx = std::min<int>(FloorIdx, SuccIdx); 372 373 // Identify immediate divergent loop exits 374 if (!DivBlockLoop) 375 continue; 376 377 const auto *BlockLoop = LI.getLoopFor(SuccBlock); 378 if (BlockLoop && DivBlockLoop->contains(BlockLoop)) 379 continue; 380 DivDesc->LoopDivBlocks.insert(SuccBlock); 381 LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: " 382 << SuccBlock->getName() << "\n"); 383 } 384 385 // propagate definitions at the immediate successors of the node in RPO 386 for (; BlockIdx >= FloorIdx; --BlockIdx) { 387 LLVM_DEBUG(dbgs() << 
"Before next visit:\n"; printDefs(dbgs())); 388 389 // Any label available here 390 const auto *Label = BlockLabels[BlockIdx]; 391 if (!Label) 392 continue; 393 394 // Ok. Get the block 395 const auto *Block = LoopPOT.getBlockAt(BlockIdx); 396 LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n"); 397 398 auto *BlockLoop = LI.getLoopFor(Block); 399 bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block; 400 bool CausedJoin = false; 401 int LoweredFloorIdx = FloorIdx; 402 if (IsLoopHeader) { 403 // Disconnect from immediate successors and propagate directly to loop 404 // exits. 405 SmallVector<BasicBlock *, 4> BlockLoopExits; 406 BlockLoop->getExitBlocks(BlockLoopExits); 407 408 bool IsParentLoop = BlockLoop->contains(&DivTermBlock); 409 for (const auto *BlockLoopExit : BlockLoopExits) { 410 CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop); 411 LoweredFloorIdx = std::min<int>(LoweredFloorIdx, 412 LoopPOT.getIndexOf(*BlockLoopExit)); 413 } 414 } else { 415 // Acyclic successor case 416 for (const auto *SuccBlock : successors(Block)) { 417 CausedJoin |= visitEdge(*SuccBlock, *Label); 418 LoweredFloorIdx = 419 std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock)); 420 } 421 } 422 423 // Floor update 424 if (CausedJoin) { 425 // 1. Different labels pushed to successors 426 FloorIdx = LoweredFloorIdx; 427 } else if (FloorLabel != Label) { 428 // 2. No join caused BUT we pushed a label that is different than the 429 // last pushed label 430 FloorIdx = LoweredFloorIdx; 431 FloorLabel = Label; 432 } 433 } 434 435 LLVM_DEBUG(dbgs() << "SDA::joins. 
After propagation:\n"; printDefs(dbgs())); 436 437 return std::move(DivDesc); 438 } 439 }; 440 441 #ifndef NDEBUG 442 static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) { 443 Out << "["; 444 ListSeparator LS; 445 for (const auto *BB : Blocks) 446 Out << LS << BB->getName(); 447 Out << "]"; 448 } 449 #endif 450 451 const ControlDivergenceDesc & 452 SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) { 453 // trivial case 454 if (Term.getNumSuccessors() <= 1) { 455 return EmptyDivergenceDesc; 456 } 457 458 // already available in cache? 459 auto ItCached = CachedControlDivDescs.find(&Term); 460 if (ItCached != CachedControlDivDescs.end()) 461 return *ItCached->second; 462 463 // compute all join points 464 // Special handling of divergent loop exits is not needed for LCSSA 465 const auto &TermBlock = *Term.getParent(); 466 DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock); 467 auto DivDesc = Propagator.computeJoinPoints(); 468 469 LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n"; 470 dbgs() << "JoinDivBlocks: "; 471 printBlockSet(DivDesc->JoinDivBlocks, dbgs()); 472 dbgs() << "\nLoopDivBlocks: "; 473 printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";); 474 475 auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc)); 476 assert(ItInserted.second); 477 return *ItInserted.first->second; 478 } 479 480 } // namespace llvm 481