1 //===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an algorithm that returns for a divergent branch
10 // the set of basic blocks whose phi nodes become divergent due to divergent
// control. These are the blocks that are reachable by two disjoint paths from
// the branch, or loop exits that have a reaching path that is disjoint from a
// path to the loop latch.
14 //
15 // The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
16 // control-induced divergence in phi nodes.
17 //
18 //
19 // -- Reference --
20 // The algorithm is presented in Section 5 of
21 //
22 //   An abstract interpretation for SPMD divergence
23 //       on reducible control flow graphs.
24 //   Julian Rosemann, Simon Moll and Sebastian Hack
25 //   POPL '21
26 //
27 //
28 // -- Sync dependence --
29 // Sync dependence characterizes the control flow aspect of the
30 // propagation of branch divergence. For example,
31 //
32 //   %cond = icmp slt i32 %tid, 10
33 //   br i1 %cond, label %then, label %else
34 // then:
35 //   br label %merge
36 // else:
37 //   br label %merge
38 // merge:
39 //   %a = phi i32 [ 0, %then ], [ 1, %else ]
40 //
41 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
42 // because %tid is not on its use-def chains, %a is sync dependent on %tid
43 // because the branch "br i1 %cond" depends on %tid and affects which value %a
44 // is assigned to.
45 //
46 //
47 // -- Reduction to SSA construction --
// There are two disjoint paths from A to X if a certain variant of SSA
// construction places a phi node in X under the following set-up scheme.
50 //
51 // This variant of SSA construction ignores incoming undef values.
// That is, paths from the entry without a definition do not result in
// phi nodes.
54 //
55 //       entry
56 //     /      \
57 //    A        \
58 //  /   \       Y
59 // B     C     /
60 //  \   /  \  /
61 //    D     E
62 //     \   /
63 //       F
64 //
65 // Assume that A contains a divergent branch. We are interested
66 // in the set of all blocks where each block is reachable from A
67 // via two disjoint paths. This would be the set {D, F} in this
68 // case.
69 // To generally reduce this query to SSA construction we introduce
70 // a virtual variable x and assign to x different values in each
71 // successor block of A.
72 //
73 //           entry
74 //         /      \
75 //        A        \
76 //      /   \       Y
77 // x = 0   x = 1   /
78 //      \  /   \  /
79 //        D     E
80 //         \   /
81 //           F
82 //
83 // Our flavor of SSA construction for x will construct the following
84 //
85 //            entry
86 //          /      \
87 //         A        \
88 //       /   \       Y
89 // x0 = 0   x1 = 1  /
90 //       \   /   \ /
91 //     x2 = phi   E
92 //         \     /
93 //         x3 = phi
94 //
// The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
97 //
98 // -- Remarks --
99 // * In case of loop exits we need to check the disjoint path criterion for loops.
100 //   To this end, we check whether the definition of x differs between the
101 //   loop exit and the loop header (_after_ SSA construction).
102 //
103 // -- Known Limitations & Future Work --
104 // * The algorithm requires reducible loops because the implementation
105 //   implicitly performs a single iteration of the underlying data flow analysis.
106 //   This was done for pragmatism, simplicity and speed.
107 //
108 //   Relevant related work for extending the algorithm to irreducible control:
109 //     A simple algorithm for global data flow analysis problems.
110 //     Matthew S. Hecht and Jeffrey D. Ullman.
111 //     SIAM Journal on Computing, 4(4):519–532, December 1975.
112 //
113 // * Another reason for requiring reducible loops is that points of
114 //   synchronization in irreducible loops aren't 'obvious' - there is no unique
115 //   header where threads 'should' synchronize when entering or coming back
116 //   around from the latch.
117 //
118 //===----------------------------------------------------------------------===//
119 #include "llvm/Analysis/SyncDependenceAnalysis.h"
120 #include "llvm/ADT/PostOrderIterator.h"
121 #include "llvm/ADT/SmallPtrSet.h"
122 #include "llvm/Analysis/PostDominators.h"
123 #include "llvm/IR/BasicBlock.h"
124 #include "llvm/IR/CFG.h"
125 #include "llvm/IR/Dominators.h"
126 #include "llvm/IR/Function.h"
127 
128 #include <functional>
129 #include <stack>
130 #include <unordered_set>
131 
132 #define DEBUG_TYPE "sync-dependence"
133 
134 // The SDA algorithm operates on a modified CFG - we modify the edges leaving
135 // loop headers as follows:
136 //
137 // * We remove all edges leaving all loop headers.
138 // * We add additional edges from the loop headers to their exit blocks.
139 //
140 // The modification is virtual, that is whenever we visit a loop header we
141 // pretend it had different successors.
142 namespace {
143 using namespace llvm;
144 
145 // Custom Post-Order Traveral
146 //
147 // We cannot use the vanilla (R)PO computation of LLVM because:
148 // * We (virtually) modify the CFG.
149 // * We want a loop-compact block enumeration, that is the numbers assigned to
150 //   blocks of a loop form an interval
151 //
152 using POCB = std::function<void(const BasicBlock &)>;
153 using VisitedSet = std::set<const BasicBlock *>;
154 using BlockStack = std::vector<const BasicBlock *>;
155 
156 // forward
157 static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
158                           VisitedSet &Finalized);
159 
160 // for a nested region (top-level loop or nested loop)
161 static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop,
162                            POCB CallBack, VisitedSet &Finalized) {
163   const auto *LoopHeader = Loop ? Loop->getHeader() : nullptr;
164   while (!Stack.empty()) {
165     const auto *NextBB = Stack.back();
166 
167     auto *NestedLoop = LI.getLoopFor(NextBB);
168     bool IsNestedLoop = NestedLoop != Loop;
169 
170     // Treat the loop as a node
171     if (IsNestedLoop) {
172       SmallVector<BasicBlock *, 3> NestedExits;
173       NestedLoop->getUniqueExitBlocks(NestedExits);
174       bool PushedNodes = false;
175       for (const auto *NestedExitBB : NestedExits) {
176         if (NestedExitBB == LoopHeader)
177           continue;
178         if (Loop && !Loop->contains(NestedExitBB))
179           continue;
180         if (Finalized.count(NestedExitBB))
181           continue;
182         PushedNodes = true;
183         Stack.push_back(NestedExitBB);
184       }
185       if (!PushedNodes) {
186         // All loop exits finalized -> finish this node
187         Stack.pop_back();
188         computeLoopPO(LI, *NestedLoop, CallBack, Finalized);
189       }
190       continue;
191     }
192 
193     // DAG-style
194     bool PushedNodes = false;
195     for (const auto *SuccBB : successors(NextBB)) {
196       if (SuccBB == LoopHeader)
197         continue;
198       if (Loop && !Loop->contains(SuccBB))
199         continue;
200       if (Finalized.count(SuccBB))
201         continue;
202       PushedNodes = true;
203       Stack.push_back(SuccBB);
204     }
205     if (!PushedNodes) {
206       // Never push nodes twice
207       Stack.pop_back();
208       if (!Finalized.insert(NextBB).second)
209         continue;
210       CallBack(*NextBB);
211     }
212   }
213 }
214 
215 static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) {
216   VisitedSet Finalized;
217   BlockStack Stack;
218   Stack.reserve(24); // FIXME made-up number
219   Stack.push_back(&F.getEntryBlock());
220   computeStackPO(Stack, LI, nullptr, CallBack, Finalized);
221 }
222 
223 static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack,
224                           VisitedSet &Finalized) {
225   /// Call CallBack on all loop blocks.
226   std::vector<const BasicBlock *> Stack;
227   const auto *LoopHeader = Loop.getHeader();
228 
229   // Visit the header last
230   Finalized.insert(LoopHeader);
231   CallBack(*LoopHeader);
232 
233   // Initialize with immediate successors
234   for (const auto *BB : successors(LoopHeader)) {
235     if (!Loop.contains(BB))
236       continue;
237     if (BB == LoopHeader)
238       continue;
239     Stack.push_back(BB);
240   }
241 
242   // Compute PO inside region
243   computeStackPO(Stack, LI, &Loop, CallBack, Finalized);
244 }
245 
246 } // namespace
247 
248 namespace llvm {
249 
// Shared, default-constructed descriptor that getJoinBlocks returns for
// terminators with at most one successor.
ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc;
251 
252 SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT,
253                                                const PostDominatorTree &PDT,
254                                                const LoopInfo &LI)
255     : DT(DT), PDT(PDT), LI(LI) {
256   computeTopLevelPO(*DT.getRoot()->getParent(), LI,
257                     [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); });
258 }
259 
260 SyncDependenceAnalysis::~SyncDependenceAnalysis() {}
261 
262 // divergence propagator for reducible CFGs
263 struct DivergencePropagator {
264   const ModifiedPO &LoopPOT;
265   const DominatorTree &DT;
266   const PostDominatorTree &PDT;
267   const LoopInfo &LI;
268   const BasicBlock &DivTermBlock;
269 
270   // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at
271   //   block B
272   // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet
273   // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths
274   // from X or B is an immediate successor of X (initial value).
275   using BlockLabelVec = std::vector<const BasicBlock *>;
276   BlockLabelVec BlockLabels;
277   // divergent join and loop exit descriptor.
278   std::unique_ptr<ControlDivergenceDesc> DivDesc;
279 
280   DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT,
281                        const PostDominatorTree &PDT, const LoopInfo &LI,
282                        const BasicBlock &DivTermBlock)
283       : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock),
284         BlockLabels(LoopPOT.size(), nullptr),
285         DivDesc(new ControlDivergenceDesc) {}
286 
287   void printDefs(raw_ostream &Out) {
288     Out << "Propagator::BlockLabels {\n";
289     for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) {
290       const auto *Label = BlockLabels[BlockIdx];
291       Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx
292           << ") : ";
293       if (!Label) {
294         Out << "<null>\n";
295       } else {
296         Out << Label->getName() << "\n";
297       }
298     }
299     Out << "}\n";
300   }
301 
302   // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
303   // causes a divergent join.
304   bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) {
305     auto SuccIdx = LoopPOT.getIndexOf(SuccBlock);
306 
307     // unset or same reaching label
308     const auto *OldLabel = BlockLabels[SuccIdx];
309     if (!OldLabel || (OldLabel == &PushedLabel)) {
310       BlockLabels[SuccIdx] = &PushedLabel;
311       return false;
312     }
313 
314     // Update the definition
315     BlockLabels[SuccIdx] = &SuccBlock;
316     return true;
317   }
318 
319   // visiting a virtual loop exit edge from the loop header --> temporal
320   // divergence on join
321   bool visitLoopExitEdge(const BasicBlock &ExitBlock,
322                          const BasicBlock &DefBlock, bool FromParentLoop) {
323     // Pushing from a non-parent loop cannot cause temporal divergence.
324     if (!FromParentLoop)
325       return visitEdge(ExitBlock, DefBlock);
326 
327     if (!computeJoin(ExitBlock, DefBlock))
328       return false;
329 
330     // Identified a divergent loop exit
331     DivDesc->LoopDivBlocks.insert(&ExitBlock);
332     LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName()
333                       << "\n");
334     return true;
335   }
336 
337   // process \p SuccBlock with reaching definition \p DefBlock
338   bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) {
339     if (!computeJoin(SuccBlock, DefBlock))
340       return false;
341 
342     // Divergent, disjoint paths join.
343     DivDesc->JoinDivBlocks.insert(&SuccBlock);
344     LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName());
345     return true;
346   }
347 
348   std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() {
349     assert(DivDesc);
350 
351     LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName()
352                       << "\n");
353 
354     const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock);
355 
356     // Early stopping criterion
357     int FloorIdx = LoopPOT.size() - 1;
358     const BasicBlock *FloorLabel = nullptr;
359 
360     // bootstrap with branch targets
361     int BlockIdx = 0;
362 
363     for (const auto *SuccBlock : successors(&DivTermBlock)) {
364       auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock);
365       BlockLabels[SuccIdx] = SuccBlock;
366 
367       // Find the successor with the highest index to start with
368       BlockIdx = std::max<int>(BlockIdx, SuccIdx);
369       FloorIdx = std::min<int>(FloorIdx, SuccIdx);
370 
371       // Identify immediate divergent loop exits
372       if (!DivBlockLoop)
373         continue;
374 
375       const auto *BlockLoop = LI.getLoopFor(SuccBlock);
376       if (BlockLoop && DivBlockLoop->contains(BlockLoop))
377         continue;
378       DivDesc->LoopDivBlocks.insert(SuccBlock);
379       LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: "
380                         << SuccBlock->getName() << "\n");
381     }
382 
383     // propagate definitions at the immediate successors of the node in RPO
384     for (; BlockIdx >= FloorIdx; --BlockIdx) {
385       LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs()));
386 
387       // Any label available here
388       const auto *Label = BlockLabels[BlockIdx];
389       if (!Label)
390         continue;
391 
392       // Ok. Get the block
393       const auto *Block = LoopPOT.getBlockAt(BlockIdx);
394       LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");
395 
396       auto *BlockLoop = LI.getLoopFor(Block);
397       bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block;
398       bool CausedJoin = false;
399       int LoweredFloorIdx = FloorIdx;
400       if (IsLoopHeader) {
401         // Disconnect from immediate successors and propagate directly to loop
402         // exits.
403         SmallVector<BasicBlock *, 4> BlockLoopExits;
404         BlockLoop->getExitBlocks(BlockLoopExits);
405 
406         bool IsParentLoop = BlockLoop->contains(&DivTermBlock);
407         for (const auto *BlockLoopExit : BlockLoopExits) {
408           CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop);
409           LoweredFloorIdx = std::min<int>(LoweredFloorIdx,
410                                           LoopPOT.getIndexOf(*BlockLoopExit));
411         }
412       } else {
413         // Acyclic successor case
414         for (const auto *SuccBlock : successors(Block)) {
415           CausedJoin |= visitEdge(*SuccBlock, *Label);
416           LoweredFloorIdx =
417               std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock));
418         }
419       }
420 
421       // Floor update
422       if (CausedJoin) {
423         // 1. Different labels pushed to successors
424         FloorIdx = LoweredFloorIdx;
425       } else if (FloorLabel != Label) {
426         // 2. No join caused BUT we pushed a label that is different than the
427         // last pushed label
428         FloorIdx = LoweredFloorIdx;
429         FloorLabel = Label;
430       }
431     }
432 
433     LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
434 
435     return std::move(DivDesc);
436   }
437 };
438 
439 #ifndef NDEBUG
440 static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) {
441   Out << "[";
442   ListSeparator LS;
443   for (const auto *BB : Blocks)
444     Out << LS << BB->getName();
445   Out << "]";
446 }
447 #endif
448 
449 const ControlDivergenceDesc &
450 SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) {
451   // trivial case
452   if (Term.getNumSuccessors() <= 1) {
453     return EmptyDivergenceDesc;
454   }
455 
456   // already available in cache?
457   auto ItCached = CachedControlDivDescs.find(&Term);
458   if (ItCached != CachedControlDivDescs.end())
459     return *ItCached->second;
460 
461   // compute all join points
462   // Special handling of divergent loop exits is not needed for LCSSA
463   const auto &TermBlock = *Term.getParent();
464   DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock);
465   auto DivDesc = Propagator.computeJoinPoints();
466 
467   LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n";
468              dbgs() << "JoinDivBlocks: ";
469              printBlockSet(DivDesc->JoinDivBlocks, dbgs());
470              dbgs() << "\nLoopDivBlocks: ";
471              printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";);
472 
473   auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc));
474   assert(ItInserted.second);
475   return *ItInserted.first->second;
476 }
477 
478 } // namespace llvm
479