1 //===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an algorithm that returns for a divergent branch
10 // the set of basic blocks whose phi nodes become divergent due to divergent
11 // control. These are the blocks that are reachable by two disjoint paths from
12 // the branch or loop exits that have a reaching path that is disjoint from a
13 // path to the loop latch.
14 //
15 // The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
16 // control-induced divergence in phi nodes.
17 //
18 //
19 // -- Reference --
20 // The algorithm is presented in Section 5 of
21 //
22 // An abstract interpretation for SPMD divergence
23 // on reducible control flow graphs.
24 // Julian Rosemann, Simon Moll and Sebastian Hack
25 // POPL '21
26 //
27 //
28 // -- Sync dependence --
29 // Sync dependence characterizes the control flow aspect of the
30 // propagation of branch divergence. For example,
31 //
32 // %cond = icmp slt i32 %tid, 10
33 // br i1 %cond, label %then, label %else
34 // then:
35 // br label %merge
36 // else:
37 // br label %merge
38 // merge:
39 // %a = phi i32 [ 0, %then ], [ 1, %else ]
40 //
41 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
42 // because %tid is not on its use-def chains, %a is sync dependent on %tid
43 // because the branch "br i1 %cond" depends on %tid and affects which value %a
44 // is assigned to.
45 //
46 //
47 // -- Reduction to SSA construction --
48 // There are two disjoint paths from A to X, if a certain variant of SSA
49 // construction places a phi node in X under the following set-up scheme.
50 //
51 // This variant of SSA construction ignores incoming undef values.
// That is, paths from the entry without a definition do not result in
53 // phi nodes.
54 //
//       entry
//     /      \
//    A        \
//  /   \       Y
// B     C     /
//  \   /  \  /
//    D     E
//     \   /
//       F
64 //
65 // Assume that A contains a divergent branch. We are interested
66 // in the set of all blocks where each block is reachable from A
67 // via two disjoint paths. This would be the set {D, F} in this
68 // case.
69 // To generally reduce this query to SSA construction we introduce
70 // a virtual variable x and assign to x different values in each
71 // successor block of A.
72 //
//            entry
//          /      \
//         A        \
//       /   \       Y
//  x = 0   x = 1   /
//       \  /   \  /
//         D     E
//          \   /
//            F
82 //
83 // Our flavor of SSA construction for x will construct the following
84 //
//            entry
//          /      \
//         A        \
//       /   \       Y
// x0 = 0   x1 = 1  /
//       \   /   \ /
//     x2 = phi   E
//         \     /
//         x3 = phi
94 //
95 // The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
97 //
98 // -- Remarks --
99 // * In case of loop exits we need to check the disjoint path criterion for loops.
100 // To this end, we check whether the definition of x differs between the
101 // loop exit and the loop header (_after_ SSA construction).
102 //
103 // -- Known Limitations & Future Work --
104 // * The algorithm requires reducible loops because the implementation
105 // implicitly performs a single iteration of the underlying data flow analysis.
106 // This was done for pragmatism, simplicity and speed.
107 //
108 // Relevant related work for extending the algorithm to irreducible control:
109 // A simple algorithm for global data flow analysis problems.
110 // Matthew S. Hecht and Jeffrey D. Ullman.
111 // SIAM Journal on Computing, 4(4):519–532, December 1975.
112 //
113 // * Another reason for requiring reducible loops is that points of
114 // synchronization in irreducible loops aren't 'obvious' - there is no unique
115 // header where threads 'should' synchronize when entering or coming back
116 // around from the latch.
117 //
118 //===----------------------------------------------------------------------===//
119
120 #include "llvm/Analysis/SyncDependenceAnalysis.h"
121 #include "llvm/ADT/SmallPtrSet.h"
122 #include "llvm/Analysis/LoopInfo.h"
123 #include "llvm/IR/BasicBlock.h"
124 #include "llvm/IR/CFG.h"
125 #include "llvm/IR/Dominators.h"
126 #include "llvm/IR/Function.h"
127
128 #include <functional>
129
130 #define DEBUG_TYPE "sync-dependence"
131
132 // The SDA algorithm operates on a modified CFG - we modify the edges leaving
133 // loop headers as follows:
134 //
135 // * We remove all edges leaving all loop headers.
136 // * We add additional edges from the loop headers to their exit blocks.
137 //
138 // The modification is virtual, that is whenever we visit a loop header we
139 // pretend it had different successors.
140 namespace {
141 using namespace llvm;
142
// Custom Post-Order Traversal
144 //
145 // We cannot use the vanilla (R)PO computation of LLVM because:
146 // * We (virtually) modify the CFG.
147 // * We want a loop-compact block enumeration, that is the numbers assigned to
148 // blocks of a loop form an interval
149 //
// Callback that receives each block as the traversal reports it.
using POCB = std::function<void(const BasicBlock &)>;
// Set of blocks the traversal has already finalized.
using VisitedSet = std::set<const BasicBlock *>;
// Explicit work stack of blocks used by the iterative traversal.
using BlockStack = std::vector<const BasicBlock *>;

// Forward declaration: computeStackPO calls computeLoopPO for nested loops,
// and computeLoopPO in turn calls computeStackPO for the loop body.
static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
                          VisitedSet &Finalized);
157
158 // for a nested region (top-level loop or nested loop)
computeStackPO(BlockStack & Stack,const LoopInfo & LI,Loop * Loop,POCB CallBack,VisitedSet & Finalized)159 static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop,
160 POCB CallBack, VisitedSet &Finalized) {
161 const auto *LoopHeader = Loop ? Loop->getHeader() : nullptr;
162 while (!Stack.empty()) {
163 const auto *NextBB = Stack.back();
164
165 auto *NestedLoop = LI.getLoopFor(NextBB);
166 bool IsNestedLoop = NestedLoop != Loop;
167
168 // Treat the loop as a node
169 if (IsNestedLoop) {
170 SmallVector<BasicBlock *, 3> NestedExits;
171 NestedLoop->getUniqueExitBlocks(NestedExits);
172 bool PushedNodes = false;
173 for (const auto *NestedExitBB : NestedExits) {
174 if (NestedExitBB == LoopHeader)
175 continue;
176 if (Loop && !Loop->contains(NestedExitBB))
177 continue;
178 if (Finalized.count(NestedExitBB))
179 continue;
180 PushedNodes = true;
181 Stack.push_back(NestedExitBB);
182 }
183 if (!PushedNodes) {
184 // All loop exits finalized -> finish this node
185 Stack.pop_back();
186 computeLoopPO(LI, *NestedLoop, CallBack, Finalized);
187 }
188 continue;
189 }
190
191 // DAG-style
192 bool PushedNodes = false;
193 for (const auto *SuccBB : successors(NextBB)) {
194 if (SuccBB == LoopHeader)
195 continue;
196 if (Loop && !Loop->contains(SuccBB))
197 continue;
198 if (Finalized.count(SuccBB))
199 continue;
200 PushedNodes = true;
201 Stack.push_back(SuccBB);
202 }
203 if (!PushedNodes) {
204 // Never push nodes twice
205 Stack.pop_back();
206 if (!Finalized.insert(NextBB).second)
207 continue;
208 CallBack(*NextBB);
209 }
210 }
211 }
212
computeTopLevelPO(Function & F,const LoopInfo & LI,POCB CallBack)213 static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) {
214 VisitedSet Finalized;
215 BlockStack Stack;
216 Stack.reserve(24); // FIXME made-up number
217 Stack.push_back(&F.getEntryBlock());
218 computeStackPO(Stack, LI, nullptr, CallBack, Finalized);
219 }
220
computeLoopPO(const LoopInfo & LI,Loop & Loop,POCB CallBack,VisitedSet & Finalized)221 static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
222 VisitedSet &Finalized) {
223 /// Call CallBack on all loop blocks.
224 std::vector<const BasicBlock *> Stack;
225 const auto *LoopHeader = Loop.getHeader();
226
227 // Visit the header last
228 Finalized.insert(LoopHeader);
229 CallBack(*LoopHeader);
230
231 // Initialize with immediate successors
232 for (const auto *BB : successors(LoopHeader)) {
233 if (!Loop.contains(BB))
234 continue;
235 if (BB == LoopHeader)
236 continue;
237 Stack.push_back(BB);
238 }
239
240 // Compute PO inside region
241 computeStackPO(Stack, LI, &Loop, CallBack, Finalized);
242 }
243
244 } // namespace
245
246 namespace llvm {
247
// Definition of the static, default-constructed descriptor that
// getJoinBlocks returns for terminators with at most one successor.
ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc;
249
SyncDependenceAnalysis(const DominatorTree & DT,const PostDominatorTree & PDT,const LoopInfo & LI)250 SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT,
251 const PostDominatorTree &PDT,
252 const LoopInfo &LI)
253 : DT(DT), PDT(PDT), LI(LI) {
254 computeTopLevelPO(*DT.getRoot()->getParent(), LI,
255 [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); });
256 }
257
// Defaulted destructor, defined out of line in this translation unit.
SyncDependenceAnalysis::~SyncDependenceAnalysis() = default;
259
260 // divergence propagator for reducible CFGs
261 struct DivergencePropagator {
262 const ModifiedPO &LoopPOT;
263 const DominatorTree &DT;
264 const PostDominatorTree &PDT;
265 const LoopInfo &LI;
266 const BasicBlock &DivTermBlock;
267
268 // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at
269 // block B
270 // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet
271 // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths
272 // from X or B is an immediate successor of X (initial value).
273 using BlockLabelVec = std::vector<const BasicBlock *>;
274 BlockLabelVec BlockLabels;
275 // divergent join and loop exit descriptor.
276 std::unique_ptr<ControlDivergenceDesc> DivDesc;
277
DivergencePropagatorllvm::DivergencePropagator278 DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT,
279 const PostDominatorTree &PDT, const LoopInfo &LI,
280 const BasicBlock &DivTermBlock)
281 : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock),
282 BlockLabels(LoopPOT.size(), nullptr),
283 DivDesc(new ControlDivergenceDesc) {}
284
printDefsllvm::DivergencePropagator285 void printDefs(raw_ostream &Out) {
286 Out << "Propagator::BlockLabels {\n";
287 for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) {
288 const auto *Label = BlockLabels[BlockIdx];
289 Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx
290 << ") : ";
291 if (!Label) {
292 Out << "<null>\n";
293 } else {
294 Out << Label->getName() << "\n";
295 }
296 }
297 Out << "}\n";
298 }
299
300 // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
301 // causes a divergent join.
computeJoinllvm::DivergencePropagator302 bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) {
303 auto SuccIdx = LoopPOT.getIndexOf(SuccBlock);
304
305 // unset or same reaching label
306 const auto *OldLabel = BlockLabels[SuccIdx];
307 if (!OldLabel || (OldLabel == &PushedLabel)) {
308 BlockLabels[SuccIdx] = &PushedLabel;
309 return false;
310 }
311
312 // Update the definition
313 BlockLabels[SuccIdx] = &SuccBlock;
314 return true;
315 }
316
317 // visiting a virtual loop exit edge from the loop header --> temporal
318 // divergence on join
visitLoopExitEdgellvm::DivergencePropagator319 bool visitLoopExitEdge(const BasicBlock &ExitBlock,
320 const BasicBlock &DefBlock, bool FromParentLoop) {
321 // Pushing from a non-parent loop cannot cause temporal divergence.
322 if (!FromParentLoop)
323 return visitEdge(ExitBlock, DefBlock);
324
325 if (!computeJoin(ExitBlock, DefBlock))
326 return false;
327
328 // Identified a divergent loop exit
329 DivDesc->LoopDivBlocks.insert(&ExitBlock);
330 LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName()
331 << "\n");
332 return true;
333 }
334
335 // process \p SuccBlock with reaching definition \p DefBlock
visitEdgellvm::DivergencePropagator336 bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) {
337 if (!computeJoin(SuccBlock, DefBlock))
338 return false;
339
340 // Divergent, disjoint paths join.
341 DivDesc->JoinDivBlocks.insert(&SuccBlock);
342 LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName());
343 return true;
344 }
345
computeJoinPointsllvm::DivergencePropagator346 std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() {
347 assert(DivDesc);
348
349 LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName()
350 << "\n");
351
352 const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock);
353
354 // Early stopping criterion
355 int FloorIdx = LoopPOT.size() - 1;
356 const BasicBlock *FloorLabel = nullptr;
357
358 // bootstrap with branch targets
359 int BlockIdx = 0;
360
361 for (const auto *SuccBlock : successors(&DivTermBlock)) {
362 auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock);
363 BlockLabels[SuccIdx] = SuccBlock;
364
365 // Find the successor with the highest index to start with
366 BlockIdx = std::max<int>(BlockIdx, SuccIdx);
367 FloorIdx = std::min<int>(FloorIdx, SuccIdx);
368
369 // Identify immediate divergent loop exits
370 if (!DivBlockLoop)
371 continue;
372
373 const auto *BlockLoop = LI.getLoopFor(SuccBlock);
374 if (BlockLoop && DivBlockLoop->contains(BlockLoop))
375 continue;
376 DivDesc->LoopDivBlocks.insert(SuccBlock);
377 LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: "
378 << SuccBlock->getName() << "\n");
379 }
380
381 // propagate definitions at the immediate successors of the node in RPO
382 for (; BlockIdx >= FloorIdx; --BlockIdx) {
383 LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs()));
384
385 // Any label available here
386 const auto *Label = BlockLabels[BlockIdx];
387 if (!Label)
388 continue;
389
390 // Ok. Get the block
391 const auto *Block = LoopPOT.getBlockAt(BlockIdx);
392 LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");
393
394 auto *BlockLoop = LI.getLoopFor(Block);
395 bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block;
396 bool CausedJoin = false;
397 int LoweredFloorIdx = FloorIdx;
398 if (IsLoopHeader) {
399 // Disconnect from immediate successors and propagate directly to loop
400 // exits.
401 SmallVector<BasicBlock *, 4> BlockLoopExits;
402 BlockLoop->getExitBlocks(BlockLoopExits);
403
404 bool IsParentLoop = BlockLoop->contains(&DivTermBlock);
405 for (const auto *BlockLoopExit : BlockLoopExits) {
406 CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop);
407 LoweredFloorIdx = std::min<int>(LoweredFloorIdx,
408 LoopPOT.getIndexOf(*BlockLoopExit));
409 }
410 } else {
411 // Acyclic successor case
412 for (const auto *SuccBlock : successors(Block)) {
413 CausedJoin |= visitEdge(*SuccBlock, *Label);
414 LoweredFloorIdx =
415 std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock));
416 }
417 }
418
419 // Floor update
420 if (CausedJoin) {
421 // 1. Different labels pushed to successors
422 FloorIdx = LoweredFloorIdx;
423 } else if (FloorLabel != Label) {
424 // 2. No join caused BUT we pushed a label that is different than the
425 // last pushed label
426 FloorIdx = LoweredFloorIdx;
427 FloorLabel = Label;
428 }
429 }
430
431 LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
432
433 return std::move(DivDesc);
434 }
435 };
436
437 #ifndef NDEBUG
printBlockSet(ConstBlockSet & Blocks,raw_ostream & Out)438 static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) {
439 Out << "[";
440 ListSeparator LS;
441 for (const auto *BB : Blocks)
442 Out << LS << BB->getName();
443 Out << "]";
444 }
445 #endif
446
447 const ControlDivergenceDesc &
getJoinBlocks(const Instruction & Term)448 SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) {
449 // trivial case
450 if (Term.getNumSuccessors() <= 1) {
451 return EmptyDivergenceDesc;
452 }
453
454 // already available in cache?
455 auto ItCached = CachedControlDivDescs.find(&Term);
456 if (ItCached != CachedControlDivDescs.end())
457 return *ItCached->second;
458
459 // compute all join points
460 // Special handling of divergent loop exits is not needed for LCSSA
461 const auto &TermBlock = *Term.getParent();
462 DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock);
463 auto DivDesc = Propagator.computeJoinPoints();
464
465 LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n";
466 dbgs() << "JoinDivBlocks: ";
467 printBlockSet(DivDesc->JoinDivBlocks, dbgs());
468 dbgs() << "\nLoopDivBlocks: ";
469 printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";);
470
471 auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc));
472 assert(ItInserted.second);
473 return *ItInserted.first->second;
474 }
475
476 } // namespace llvm
477