1 //==- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation -==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an algorithm that returns for a divergent branch
10 // the set of basic blocks whose phi nodes become divergent due to divergent
11 // control. These are the blocks that are reachable by two disjoint paths from
12 // the branch or loop exits that have a reaching path that is disjoint from a
13 // path to the loop latch.
14 //
15 // The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
16 // control-induced divergence in phi nodes.
17 //
18 // -- Summary --
19 // The SyncDependenceAnalysis lazily computes sync dependences [3].
20 // The analysis evaluates the disjoint path criterion [2] by a reduction
21 // to SSA construction. The SSA construction algorithm is implemented as
22 // a simple data-flow analysis [1].
23 //
// [1] "A Simple, Fast Dominance Algorithm", SPE '01, Cooper, Harvey and Kennedy
25 // [2] "Efficiently Computing Static Single Assignment Form
26 //     and the Control Dependence Graph", TOPLAS '91,
27 //           Cytron, Ferrante, Rosen, Wegman and Zadeck
28 // [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack
29 // [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira
30 //
31 // -- Sync dependence --
32 // Sync dependence [4] characterizes the control flow aspect of the
33 // propagation of branch divergence. For example,
34 //
35 //   %cond = icmp slt i32 %tid, 10
36 //   br i1 %cond, label %then, label %else
37 // then:
38 //   br label %merge
39 // else:
40 //   br label %merge
41 // merge:
42 //   %a = phi i32 [ 0, %then ], [ 1, %else ]
43 //
44 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
45 // because %tid is not on its use-def chains, %a is sync dependent on %tid
46 // because the branch "br i1 %cond" depends on %tid and affects which value %a
47 // is assigned to.
48 //
49 // -- Reduction to SSA construction --
// There are two disjoint paths from A to X if a certain variant of SSA
51 // construction places a phi node in X under the following set-up scheme [2].
52 //
53 // This variant of SSA construction ignores incoming undef values.
// That is, paths from the entry without a definition do not result in
55 // phi nodes.
56 //
57 //       entry
58 //     /      \
59 //    A        \
60 //  /   \       Y
61 // B     C     /
62 //  \   /  \  /
63 //    D     E
64 //     \   /
65 //       F
66 // Assume that A contains a divergent branch. We are interested
67 // in the set of all blocks where each block is reachable from A
68 // via two disjoint paths. This would be the set {D, F} in this
69 // case.
70 // To generally reduce this query to SSA construction we introduce
71 // a virtual variable x and assign to x different values in each
72 // successor block of A.
73 //           entry
74 //         /      \
75 //        A        \
76 //      /   \       Y
77 // x = 0   x = 1   /
78 //      \  /   \  /
79 //        D     E
80 //         \   /
81 //           F
82 // Our flavor of SSA construction for x will construct the following
83 //            entry
84 //          /      \
85 //         A        \
86 //       /   \       Y
87 // x0 = 0   x1 = 1  /
88 //       \   /   \ /
89 //      x2=phi    E
90 //         \     /
91 //          x3=phi
92 // The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
94 //
95 // -- Remarks --
96 // In case of loop exits we need to check the disjoint path criterion for loops
97 // [2]. To this end, we check whether the definition of x differs between the
98 // loop exit and the loop header (_after_ SSA construction).
99 //
100 //===----------------------------------------------------------------------===//
101 #include "llvm/Analysis/SyncDependenceAnalysis.h"
102 #include "llvm/ADT/PostOrderIterator.h"
103 #include "llvm/ADT/SmallPtrSet.h"
104 #include "llvm/Analysis/PostDominators.h"
105 #include "llvm/IR/BasicBlock.h"
106 #include "llvm/IR/CFG.h"
107 #include "llvm/IR/Dominators.h"
108 #include "llvm/IR/Function.h"
109 
110 #include <stack>
111 #include <unordered_set>
112 
113 #define DEBUG_TYPE "sync-dependence"
114 
115 namespace llvm {
116 
117 ConstBlockSet SyncDependenceAnalysis::EmptyBlockSet;
118 
119 SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT,
120                                                const PostDominatorTree &PDT,
121                                                const LoopInfo &LI)
122     : FuncRPOT(DT.getRoot()->getParent()), DT(DT), PDT(PDT), LI(LI) {}
123 
124 SyncDependenceAnalysis::~SyncDependenceAnalysis() {}
125 
126 using FunctionRPOT = ReversePostOrderTraversal<const Function *>;
127 
128 // divergence propagator for reducible CFGs
129 struct DivergencePropagator {
130   const FunctionRPOT &FuncRPOT;
131   const DominatorTree &DT;
132   const PostDominatorTree &PDT;
133   const LoopInfo &LI;
134 
135   // identified join points
136   std::unique_ptr<ConstBlockSet> JoinBlocks;
137 
138   // reached loop exits (by a path disjoint to a path to the loop header)
139   SmallPtrSet<const BasicBlock *, 4> ReachedLoopExits;
140 
141   // if DefMap[B] == C then C is the dominating definition at block B
142   // if DefMap[B] ~ undef then we haven't seen B yet
143   // if DefMap[B] == B then B is a join point of disjoint paths from X or B is
144   // an immediate successor of X (initial value).
145   using DefiningBlockMap = std::map<const BasicBlock *, const BasicBlock *>;
146   DefiningBlockMap DefMap;
147 
148   // all blocks with pending visits
149   std::unordered_set<const BasicBlock *> PendingUpdates;
150 
151   DivergencePropagator(const FunctionRPOT &FuncRPOT, const DominatorTree &DT,
152                        const PostDominatorTree &PDT, const LoopInfo &LI)
153       : FuncRPOT(FuncRPOT), DT(DT), PDT(PDT), LI(LI),
154         JoinBlocks(new ConstBlockSet) {}
155 
156   // set the definition at @block and mark @block as pending for a visit
157   void addPending(const BasicBlock &Block, const BasicBlock &DefBlock) {
158     bool WasAdded = DefMap.emplace(&Block, &DefBlock).second;
159     if (WasAdded)
160       PendingUpdates.insert(&Block);
161   }
162 
163   void printDefs(raw_ostream &Out) {
164     Out << "Propagator::DefMap {\n";
165     for (const auto *Block : FuncRPOT) {
166       auto It = DefMap.find(Block);
167       Out << Block->getName() << " : ";
168       if (It == DefMap.end()) {
169         Out << "\n";
170       } else {
171         const auto *DefBlock = It->second;
172         Out << (DefBlock ? DefBlock->getName() : "<null>") << "\n";
173       }
174     }
175     Out << "}\n";
176   }
177 
178   // process @succBlock with reaching definition @defBlock
179   // the original divergent branch was in @parentLoop (if any)
180   void visitSuccessor(const BasicBlock &SuccBlock, const Loop *ParentLoop,
181                       const BasicBlock &DefBlock) {
182 
183     // @succBlock is a loop exit
184     if (ParentLoop && !ParentLoop->contains(&SuccBlock)) {
185       DefMap.emplace(&SuccBlock, &DefBlock);
186       ReachedLoopExits.insert(&SuccBlock);
187       return;
188     }
189 
190     // first reaching def?
191     auto ItLastDef = DefMap.find(&SuccBlock);
192     if (ItLastDef == DefMap.end()) {
193       addPending(SuccBlock, DefBlock);
194       return;
195     }
196 
197     // a join of at least two definitions
198     if (ItLastDef->second != &DefBlock) {
199       // do we know this join already?
200       if (!JoinBlocks->insert(&SuccBlock).second)
201         return;
202 
203       // update the definition
204       addPending(SuccBlock, SuccBlock);
205     }
206   }
207 
208   // find all blocks reachable by two disjoint paths from @rootTerm.
209   // This method works for both divergent terminators and loops with
210   // divergent exits.
211   // @rootBlock is either the block containing the branch or the header of the
212   // divergent loop.
213   // @nodeSuccessors is the set of successors of the node (Loop or Terminator)
214   // headed by @rootBlock.
215   // @parentLoop is the parent loop of the Loop or the loop that contains the
216   // Terminator.
217   template <typename SuccessorIterable>
218   std::unique_ptr<ConstBlockSet>
219   computeJoinPoints(const BasicBlock &RootBlock,
220                     SuccessorIterable NodeSuccessors, const Loop *ParentLoop) {
221     assert(JoinBlocks);
222 
223     LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints. Parent loop: "
224                       << (ParentLoop ? ParentLoop->getName() : "<null>")
225                       << "\n");
226 
227     // bootstrap with branch targets
228     for (const auto *SuccBlock : NodeSuccessors) {
229       DefMap.emplace(SuccBlock, SuccBlock);
230 
231       if (ParentLoop && !ParentLoop->contains(SuccBlock)) {
232         // immediate loop exit from node.
233         ReachedLoopExits.insert(SuccBlock);
234       } else {
235         // regular successor
236         PendingUpdates.insert(SuccBlock);
237       }
238     }
239 
240     LLVM_DEBUG(dbgs() << "SDA: rpo order:\n"; for (const auto *RpoBlock
241                                                    : FuncRPOT) {
242       dbgs() << "- " << RpoBlock->getName() << "\n";
243     });
244 
245     auto ItBeginRPO = FuncRPOT.begin();
246     auto ItEndRPO = FuncRPOT.end();
247 
248     // skip until term (TODO RPOT won't let us start at @term directly)
249     for (; *ItBeginRPO != &RootBlock; ++ItBeginRPO) {
250       assert(ItBeginRPO != ItEndRPO && "Unable to find RootBlock");
251     }
252 
253     // propagate definitions at the immediate successors of the node in RPO
254     auto ItBlockRPO = ItBeginRPO;
255     while ((++ItBlockRPO != ItEndRPO) && !PendingUpdates.empty()) {
256       const auto *Block = *ItBlockRPO;
257       LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");
258 
259       // skip Block if not pending update
260       auto ItPending = PendingUpdates.find(Block);
261       if (ItPending == PendingUpdates.end())
262         continue;
263       PendingUpdates.erase(ItPending);
264 
265       // propagate definition at Block to its successors
266       auto ItDef = DefMap.find(Block);
267       const auto *DefBlock = ItDef->second;
268       assert(DefBlock);
269 
270       auto *BlockLoop = LI.getLoopFor(Block);
271       if (ParentLoop &&
272           (ParentLoop != BlockLoop && ParentLoop->contains(BlockLoop))) {
273         // if the successor is the header of a nested loop pretend its a
274         // single node with the loop's exits as successors
275         SmallVector<BasicBlock *, 4> BlockLoopExits;
276         BlockLoop->getExitBlocks(BlockLoopExits);
277         for (const auto *BlockLoopExit : BlockLoopExits) {
278           visitSuccessor(*BlockLoopExit, ParentLoop, *DefBlock);
279         }
280 
281       } else {
282         // the successors are either on the same loop level or loop exits
283         for (const auto *SuccBlock : successors(Block)) {
284           visitSuccessor(*SuccBlock, ParentLoop, *DefBlock);
285         }
286       }
287     }
288 
289     LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
290 
291     // We need to know the definition at the parent loop header to decide
292     // whether the definition at the header is different from the definition at
293     // the loop exits, which would indicate a divergent loop exits.
294     //
295     // A // loop header
296     // |
297     // B // nested loop header
298     // |
299     // C -> X (exit from B loop) -..-> (A latch)
300     // |
301     // D -> back to B (B latch)
302     // |
303     // proper exit from both loops
304     //
305     // analyze reached loop exits
306     if (!ReachedLoopExits.empty()) {
307       const BasicBlock *ParentLoopHeader =
308           ParentLoop ? ParentLoop->getHeader() : nullptr;
309 
310       assert(ParentLoop);
311       auto ItHeaderDef = DefMap.find(ParentLoopHeader);
312       const auto *HeaderDefBlock =
313           (ItHeaderDef == DefMap.end()) ? nullptr : ItHeaderDef->second;
314 
315       LLVM_DEBUG(printDefs(dbgs()));
316       assert(HeaderDefBlock && "no definition at header of carrying loop");
317 
318       for (const auto *ExitBlock : ReachedLoopExits) {
319         auto ItExitDef = DefMap.find(ExitBlock);
320         assert((ItExitDef != DefMap.end()) &&
321                "no reaching def at reachable loop exit");
322         if (ItExitDef->second != HeaderDefBlock) {
323           JoinBlocks->insert(ExitBlock);
324         }
325       }
326     }
327 
328     return std::move(JoinBlocks);
329   }
330 };
331 
332 const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) {
333   using LoopExitVec = SmallVector<BasicBlock *, 4>;
334   LoopExitVec LoopExits;
335   Loop.getExitBlocks(LoopExits);
336   if (LoopExits.size() < 1) {
337     return EmptyBlockSet;
338   }
339 
340   // already available in cache?
341   auto ItCached = CachedLoopExitJoins.find(&Loop);
342   if (ItCached != CachedLoopExitJoins.end()) {
343     return *ItCached->second;
344   }
345 
346   // compute all join points
347   DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
348   auto JoinBlocks = Propagator.computeJoinPoints<const LoopExitVec &>(
349       *Loop.getHeader(), LoopExits, Loop.getParentLoop());
350 
351   auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks));
352   assert(ItInserted.second);
353   return *ItInserted.first->second;
354 }
355 
356 const ConstBlockSet &
357 SyncDependenceAnalysis::join_blocks(const Instruction &Term) {
358   // trivial case
359   if (Term.getNumSuccessors() < 1) {
360     return EmptyBlockSet;
361   }
362 
363   // already available in cache?
364   auto ItCached = CachedBranchJoins.find(&Term);
365   if (ItCached != CachedBranchJoins.end())
366     return *ItCached->second;
367 
368   // compute all join points
369   DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
370   const auto &TermBlock = *Term.getParent();
371   auto JoinBlocks = Propagator.computeJoinPoints<const_succ_range>(
372       TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock));
373 
374   auto ItInserted = CachedBranchJoins.emplace(&Term, std::move(JoinBlocks));
375   assert(ItInserted.second);
376   return *ItInserted.first->second;
377 }
378 
379 } // namespace llvm
380