//===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a general divergence analysis for loop vectorization
// and GPU programs. It determines which branches and values in a loop or GPU
// program are divergent. It can help branch optimizations such as jump
// threading and loop unswitching to make better decisions.
//
// GPU programs typically use the SIMD execution model, where multiple threads
// in the same execution group have to execute in lock-step. Therefore, if the
// code contains divergent branches (i.e., threads in a group do not agree on
// which path of the branch to take), the group of threads has to execute all
// the paths from that branch with different subsets of threads enabled until
// they re-converge.
//
// Due to this execution model, some optimizations such as jump
// threading and loop unswitching can interfere with thread re-convergence.
// Therefore, an analysis that computes which branches in a GPU program are
// divergent can help the compiler to selectively run these optimizations.
//
// This implementation is derived from the Vectorization Analysis of the
// Region Vectorizer (RV). That implementation in turn is based on the approach
// described in
//
//   Improving Performance of OpenCL on CPUs
//   Ralf Karrenberg and Sebastian Hack
//   CC '12
//
// This DivergenceAnalysis implementation is generic in the sense that it does
// not itself identify original sources of divergence.
// Instead, specialized adapter classes, LoopDivergenceAnalysis for loops and
// GPUDivergenceAnalysis for GPU programs, identify the sources of divergence
// (e.g., special variables that hold the thread ID or the loop's iteration
// variable).
//
// The generic implementation propagates divergence to variables that are data
// or sync dependent on a source of divergence.
//
// While data dependency is a well-known concept, the notion of sync dependency
// is worth more explanation. Sync dependence characterizes the control flow
// aspect of the propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
// The sync dependence detection (which branch induces divergence in which join
// points) is implemented in the SyncDependenceAnalysis.
//
// The current DivergenceAnalysis implementation has the following limitations:
// 1. intra-procedural. It conservatively considers the arguments of a
//    non-kernel-entry function and the return value of a function call as
//    divergent.
// 2. memory as black box. It conservatively considers values loaded from
//    generic or local address space as divergent (see the example below).
//    This can be improved by leveraging pointer analysis and/or by modelling
//    non-escaping memory objects in SSA as done in RV.
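//
//    For example (illustrative IR; addrspace(3) stands in for the target's
//    local address space), %v below is conservatively treated as divergent,
//    even if all threads happen to load the same value through %p:
//
//      %v = load i32, i32 addrspace(3)* %p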
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/DivergenceAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/Passes.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <vector>

using namespace llvm;

#define DEBUG_TYPE "divergence-analysis"

// class DivergenceAnalysis
DivergenceAnalysis::DivergenceAnalysis(
    const Function &F, const Loop *RegionLoop, const DominatorTree &DT,
    const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm)
    : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA),
      IsLCSSAForm(IsLCSSAForm) {}

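// Mark \p DivVal as divergent unless it is overridden as always-uniform.
// Returns true if \p DivVal was not already marked divergent.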
bool DivergenceAnalysis::markDivergent(const Value &DivVal) {
  if (isAlwaysUniform(DivVal))
    return false;
  assert((isa<Instruction>(DivVal) || isa<Argument>(DivVal)) &&
         "only instructions and arguments can be marked divergent");
  return DivergentValues.insert(&DivVal).second;
}

void DivergenceAnalysis::addUniformOverride(const Value &UniVal) {
  UniformOverrides.insert(&UniVal);
}

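// A value defined inside a loop is "temporally divergent" with respect to a
// use outside of that loop if the loop is divergent: threads leave the loop
// in different iterations and thus observe different values of the same
// definition. An illustrative sketch (hypothetical IR; %tid holds the thread
// ID):
//
// loop:
//   %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
//   %iv.next = add i32 %iv, 1
//   %done = icmp sge i32 %iv.next, %tid
//   br i1 %done, label %exit, label %loop
// exit:
//
// Within any single iteration, %iv.next is uniform. However, the exit
// condition %done is divergent, so threads leave the loop in different
// iterations; observed in %exit, %iv.next holds a different value in each
// thread and is therefore divergent there.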
bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock,
                                             const Value &Val) const {
  const auto *Inst = dyn_cast<const Instruction>(&Val);
  if (!Inst)
    return false;
  // Check whether any divergent loop carrying Val terminates before control
  // proceeds to ObservingBlock.
  for (const auto *Loop = LI.getLoopFor(Inst->getParent());
       Loop != RegionLoop && !Loop->contains(&ObservingBlock);
       Loop = Loop->getParentLoop()) {
    if (DivergentLoops.find(Loop) != DivergentLoops.end())
      return true;
  }

  return false;
}

bool DivergenceAnalysis::inRegion(const Instruction &I) const {
  return I.getParent() && inRegion(*I.getParent());
}

bool DivergenceAnalysis::inRegion(const BasicBlock &BB) const {
  return RegionLoop ? RegionLoop->contains(&BB) : BB.getParent() == &F;
}

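// If \p V is a (divergent) terminator, hand its control-flow effect to
// analyzeControlDivergence; otherwise mark all in-region users of \p V as
// divergent and push them onto the worklist.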
void DivergenceAnalysis::pushUsers(const Value &V) {
  const auto *I = dyn_cast<const Instruction>(&V);

  if (I && I->isTerminator()) {
    analyzeControlDivergence(*I);
    return;
  }

  for (const auto *User : V.users()) {
    const auto *UserInst = dyn_cast<const Instruction>(User);
    if (!UserInst)
      continue;

    // Only compute divergence inside the region of interest.
    if (!inRegion(*UserInst))
      continue;

    // All users of divergent values are immediately divergent.
    if (markDivergent(*UserInst))
      Worklist.push_back(UserInst);
  }
}

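// Return the operand \p U as an instruction if it is defined inside
// \p DivLoop (i.e., it is a value carried by that loop), nullptr otherwise.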
static const Instruction *getIfCarriedInstruction(const Use &U,
                                                  const Loop &DivLoop) {
  const auto *I = dyn_cast<const Instruction>(&U);
  if (!I)
    return nullptr;
  if (!DivLoop.contains(I))
    return nullptr;
  return I;
}

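// Mark \p I as divergent if any of its operands is carried by the divergent
// loop \p OuterDivLoop: \p I observes those operands only after the loop has
// finished and is therefore temporally divergent.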
void DivergenceAnalysis::analyzeTemporalDivergence(const Instruction &I,
                                                   const Loop &OuterDivLoop) {
  if (isAlwaysUniform(I))
    return;
  if (isDivergent(I))
    return;

  LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n");
  assert((isa<PHINode>(I) || !IsLCSSAForm) &&
         "In LCSSA form all users of loop-exiting defs are Phi nodes.");
  for (const Use &Op : I.operands()) {
    const auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop);
    if (!OpInst)
      continue;
    if (markDivergent(I))
      pushUsers(I);
    return;
  }
}

// Marks all users of loop-carried values of \p OuterDivLoop that are
// reachable through the loop exit \p DivExit as divergent.
void DivergenceAnalysis::analyzeLoopExitDivergence(const BasicBlock &DivExit,
                                                   const Loop &OuterDivLoop) {
  // In LCSSA form, all outside users are phi nodes in the immediate exit
  // blocks.
  if (IsLCSSAForm) {
    for (const auto &Phi : DivExit.phis()) {
      analyzeTemporalDivergence(Phi, OuterDivLoop);
    }
    return;
  }

  // For non-LCSSA, we have to follow all live-out edges wherever they may
  // lead.
  const BasicBlock &LoopHeader = *OuterDivLoop.getHeader();
  SmallVector<const BasicBlock *, 8> TaintStack;
  TaintStack.push_back(&DivExit);

  // Potential users of loop-carried values could be anywhere in the dominance
  // region of OuterDivLoop (including its fringes for phi nodes).
  DenseSet<const BasicBlock *> Visited;
  Visited.insert(&DivExit);

  do {
    auto *UserBlock = TaintStack.back();
    TaintStack.pop_back();

    // Don't spread divergence beyond the region.
    if (!inRegion(*UserBlock))
      continue;

    assert(!OuterDivLoop.contains(UserBlock) &&
           "irreducible control flow detected");

    // Phi nodes at the fringes of the dominance region.
    if (!DT.dominates(&LoopHeader, UserBlock)) {
      // All phi nodes of UserBlock become divergent.
      for (auto &Phi : UserBlock->phis()) {
        analyzeTemporalDivergence(Phi, OuterDivLoop);
      }
      continue;
    }

    // Taint outside users of values carried by OuterDivLoop.
    for (auto &I : *UserBlock) {
      analyzeTemporalDivergence(I, OuterDivLoop);
    }

    // Visit all blocks in the dominance region.
    for (auto *SuccBlock : successors(UserBlock)) {
      if (!Visited.insert(SuccBlock).second) {
        continue;
      }
      TaintStack.push_back(SuccBlock);
    }
  } while (!TaintStack.empty());
}

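// Mark all loops left through the divergent exit \p DivExit, starting at
// \p InnerDivLoop, as divergent, and taint the outside users of values
// carried by the outer-most of these loops.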
void DivergenceAnalysis::propagateLoopExitDivergence(const BasicBlock &DivExit,
                                                     const Loop &InnerDivLoop) {
  LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n");

  // Find the outer-most loop that does not contain \p DivExit.
  const Loop *DivLoop = &InnerDivLoop;
  const Loop *OuterDivLoop = DivLoop;
  const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit);
  const unsigned LoopExitDepth =
      ExitLevelLoop ? ExitLevelLoop->getLoopDepth() : 0;
  while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) {
    DivergentLoops.insert(DivLoop); // all crossed loops are divergent
    OuterDivLoop = DivLoop;
    DivLoop = DivLoop->getParentLoop();
  }
  LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName()
                    << "\n");

  analyzeLoopExitDivergence(DivExit, *OuterDivLoop);
}

// \p JoinBlock is a divergent join point: mark all of its phi nodes as
// divergent and push them onto the worklist.
void DivergenceAnalysis::taintAndPushPhiNodes(const BasicBlock &JoinBlock) {
  LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName()
                    << "\n");

  // Ignore divergence outside the region.
  if (!inRegion(JoinBlock)) {
    return;
  }

  // Push non-divergent phi nodes in JoinBlock to the worklist.
  for (const auto &Phi : JoinBlock.phis()) {
    if (isDivergent(Phi))
      continue;
    // FIXME: Theoretically, the 'undef' value could be replaced by any other
    // value, causing spurious divergence.
    if (Phi.hasConstantOrUndefValue())
      continue;
    if (markDivergent(Phi))
      Worklist.push_back(&Phi);
  }
}

void DivergenceAnalysis::analyzeControlDivergence(const Instruction &Term) {
  LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName()
                    << "\n");

  // Don't propagate divergence from unreachable blocks.
  if (!DT.isReachableFromEntry(Term.getParent()))
    return;

  const auto *BranchLoop = LI.getLoopFor(Term.getParent());

  const auto &DivDesc = SDA.getJoinBlocks(Term);

  // Taint the phi nodes of all blocks that are join points of disjoint paths
  // from Term.
  for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
    taintAndPushPhiNodes(*JoinBlock);
  }

  assert(DivDesc.LoopDivBlocks.empty() || BranchLoop);
  for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) {
    propagateLoopExitDivergence(*DivExitBlock, *BranchLoop);
  }
}

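// Propagate divergence from the seed values to a fixed point: every in-region
// user of a divergent value becomes divergent, and divergent terminators
// additionally taint their join points and divergent loop exits.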
void DivergenceAnalysis::compute() {
  // Initialize the worklist. Iterate over a copy, since pushUsers below may
  // insert new values into DivergentValues.
  auto DivValuesCopy = DivergentValues;
  for (const auto *DivVal : DivValuesCopy) {
    assert(isDivergent(*DivVal) && "Worklist invariant violated!");
    pushUsers(*DivVal);
  }

  // All values on the Worklist are divergent.
  // Their users may not have been updated yet.
  while (!Worklist.empty()) {
    const Instruction &I = *Worklist.back();
    Worklist.pop_back();

    // Propagate value divergence to users.
    assert(isDivergent(I) && "Worklist invariant violated!");
    pushUsers(I);
  }
}

bool DivergenceAnalysis::isAlwaysUniform(const Value &V) const {
  return UniformOverrides.find(&V) != UniformOverrides.end();
}

bool DivergenceAnalysis::isDivergent(const Value &V) const {
  return DivergentValues.find(&V) != DivergentValues.end();
}

bool DivergenceAnalysis::isDivergentUse(const Use &U) const {
  Value &V = *U.get();
  Instruction &I = *cast<Instruction>(U.getUser());
  return isDivergent(V) || isTemporalDivergent(*I.getParent(), V);
}

void DivergenceAnalysis::print(raw_ostream &OS, const Module *) const {
  if (DivergentValues.empty())
    return;
  // Iterate over instructions using instructions() to ensure a deterministic
  // order.
  for (auto &I : instructions(F)) {
    if (isDivergent(I))
      OS << "DIVERGENT:" << I << '\n';
  }
}

// class GPUDivergenceAnalysis
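// GPUDivergenceAnalysis runs eagerly on construction: it seeds divergence
// from TTI (thread IDs and other target-specific sources of divergence, plus
// always-uniform overrides) and then computes the fixed point over the whole
// function.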
GPUDivergenceAnalysis::GPUDivergenceAnalysis(Function &F,
                                             const DominatorTree &DT,
                                             const PostDominatorTree &PDT,
                                             const LoopInfo &LI,
                                             const TargetTransformInfo &TTI)
    : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, /* LCSSA */ false) {
  for (auto &I : instructions(F)) {
    if (TTI.isSourceOfDivergence(&I)) {
      DA.markDivergent(I);
    } else if (TTI.isAlwaysUniform(&I)) {
      DA.addUniformOverride(I);
    }
  }
  for (auto &Arg : F.args()) {
    if (TTI.isSourceOfDivergence(&Arg)) {
      DA.markDivergent(Arg);
    }
  }

  DA.compute();
}

bool GPUDivergenceAnalysis::isDivergent(const Value &val) const {
  return DA.isDivergent(val);
}

bool GPUDivergenceAnalysis::isDivergentUse(const Use &use) const {
  return DA.isDivergentUse(use);
}

void GPUDivergenceAnalysis::print(raw_ostream &OS, const Module *mod) const {
  OS << "Divergence of kernel " << DA.getFunction().getName() << " {\n";
  DA.print(OS, mod);
  OS << "}\n";
}