1 //===- TLSVariableHoist.cpp -------- Remove Redundant TLS Loads ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass identifies and eliminates redundant TLS loads if the related
// option is set.
// For an example, see the comment at the head of TLSVariableHoist.h.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/IR/BasicBlock.h"
16 #include "llvm/IR/Dominators.h"
17 #include "llvm/IR/Function.h"
18 #include "llvm/IR/InstrTypes.h"
19 #include "llvm/IR/Instruction.h"
20 #include "llvm/IR/Instructions.h"
21 #include "llvm/IR/IntrinsicInst.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/Value.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include "llvm/Transforms/Scalar.h"
30 #include "llvm/Transforms/Scalar/TLSVariableHoist.h"
31 #include <algorithm>
32 #include <cassert>
33 #include <cstdint>
34 #include <iterator>
35 #include <tuple>
36 #include <utility>
37 
38 using namespace llvm;
39 using namespace tlshoist;
40 
41 #define DEBUG_TYPE "tlshoist"
42 
43 // TODO: Support "strict" model if we need to strictly load TLS address,
44 // because "non-optimize" may also do some optimization in other passes.
45 static cl::opt<std::string> TLSLoadHoist(
46     "tls-load-hoist",
47     cl::desc(
48         "hoist the TLS loads in PIC model: "
49         "tls-load-hoist=optimize: Eleminate redundant TLS load(s)."
50         "tls-load-hoist=strict: Strictly load TLS address before every use."
51         "tls-load-hoist=non-optimize: Generally load TLS before use(s)."),
52     cl::init("non-optimize"), cl::Hidden);
53 
54 namespace {
55 
56 /// The TLS Variable hoist pass.
57 class TLSVariableHoistLegacyPass : public FunctionPass {
58 public:
59   static char ID; // Pass identification, replacement for typeid
60 
61   TLSVariableHoistLegacyPass() : FunctionPass(ID) {
62     initializeTLSVariableHoistLegacyPassPass(*PassRegistry::getPassRegistry());
63   }
64 
65   bool runOnFunction(Function &Fn) override;
66 
67   StringRef getPassName() const override { return "TLS Variable Hoist"; }
68 
69   void getAnalysisUsage(AnalysisUsage &AU) const override {
70     AU.setPreservesCFG();
71     AU.addRequired<DominatorTreeWrapperPass>();
72     AU.addRequired<LoopInfoWrapperPass>();
73   }
74 
75 private:
76   TLSVariableHoistPass Impl;
77 };
78 
79 } // end anonymous namespace
80 
// Out-of-class definition of the static pass identifier declared in
// TLSVariableHoistLegacyPass.
char TLSVariableHoistLegacyPass::ID = 0;

// Register the legacy pass under the name "tlshoist" with the description
// "TLS Variable Hoist". The listed dependencies mirror the addRequired<> calls
// in TLSVariableHoistLegacyPass::getAnalysisUsage().
INITIALIZE_PASS_BEGIN(TLSVariableHoistLegacyPass, "tlshoist",
                      "TLS Variable Hoist", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(TLSVariableHoistLegacyPass, "tlshoist",
                    "TLS Variable Hoist", false, false)
89 
90 FunctionPass *llvm::createTLSVariableHoistPass() {
91   return new TLSVariableHoistLegacyPass();
92 }
93 
94 /// Perform the TLS Variable Hoist optimization for the given function.
95 bool TLSVariableHoistLegacyPass::runOnFunction(Function &Fn) {
96   if (skipFunction(Fn))
97     return false;
98 
99   LLVM_DEBUG(dbgs() << "********** Begin TLS Variable Hoist **********\n");
100   LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n');
101 
102   bool MadeChange =
103       Impl.runImpl(Fn, getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
104                    getAnalysis<LoopInfoWrapperPass>().getLoopInfo());
105 
106   if (MadeChange) {
107     LLVM_DEBUG(dbgs() << "********** Function after TLS Variable Hoist: "
108                       << Fn.getName() << '\n');
109     LLVM_DEBUG(dbgs() << Fn);
110   }
111   LLVM_DEBUG(dbgs() << "********** End TLS Variable Hoist **********\n");
112 
113   return MadeChange;
114 }
115 
116 void TLSVariableHoistPass::collectTLSCandidate(Instruction *Inst) {
117   // Skip all cast instructions. They are visited indirectly later on.
118   if (Inst->isCast())
119     return;
120 
121   // Scan all operands.
122   for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) {
123     auto *GV = dyn_cast<GlobalVariable>(Inst->getOperand(Idx));
124     if (!GV || !GV->isThreadLocal())
125       continue;
126 
127     // Add Candidate to TLSCandMap (GV --> Candidate).
128     TLSCandMap[GV].addUser(Inst, Idx);
129   }
130 }
131 
132 void TLSVariableHoistPass::collectTLSCandidates(Function &Fn) {
133   // First, quickly check if there is TLS Variable.
134   Module *M = Fn.getParent();
135 
136   bool HasTLS = llvm::any_of(
137       M->globals(), [](GlobalVariable &GV) { return GV.isThreadLocal(); });
138 
139   // If non, directly return.
140   if (!HasTLS)
141     return;
142 
143   TLSCandMap.clear();
144 
145   // Then, collect TLS Variable info.
146   for (BasicBlock &BB : Fn) {
147     // Ignore unreachable basic blocks.
148     if (!DT->isReachableFromEntry(&BB))
149       continue;
150 
151     for (Instruction &Inst : BB)
152       collectTLSCandidate(&Inst);
153   }
154 }
155 
156 static bool oneUseOutsideLoop(tlshoist::TLSCandidate &Cand, LoopInfo *LI) {
157   if (Cand.Users.size() != 1)
158     return false;
159 
160   BasicBlock *BB = Cand.Users[0].Inst->getParent();
161   if (LI->getLoopFor(BB))
162     return false;
163 
164   return true;
165 }
166 
167 Instruction *TLSVariableHoistPass::getNearestLoopDomInst(BasicBlock *BB,
168                                                          Loop *L) {
169   assert(L && "Unexcepted Loop status!");
170 
171   // Get the outermost loop.
172   while (Loop *Parent = L->getParentLoop())
173     L = Parent;
174 
175   BasicBlock *PreHeader = L->getLoopPreheader();
176 
177   // There is unique predecessor outside the loop.
178   if (PreHeader)
179     return PreHeader->getTerminator();
180 
181   BasicBlock *Header = L->getHeader();
182   BasicBlock *Dom = Header;
183   for (BasicBlock *PredBB : predecessors(Header))
184     Dom = DT->findNearestCommonDominator(Dom, PredBB);
185 
186   assert(Dom && "Not find dominator BB!");
187   Instruction *Term = Dom->getTerminator();
188 
189   return Term;
190 }
191 
192 Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1,
193                                               Instruction *I2) {
194   if (!I1)
195     return I2;
196   if (DT->dominates(I1, I2))
197     return I1;
198   if (DT->dominates(I2, I1))
199     return I2;
200 
201   // If there is no dominance relation, use common dominator.
202   BasicBlock *DomBB =
203       DT->findNearestCommonDominator(I1->getParent(), I2->getParent());
204 
205   Instruction *Dom = DomBB->getTerminator();
206   assert(Dom && "Common dominator not found!");
207 
208   return Dom;
209 }
210 
211 BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn,
212                                                          GlobalVariable *GV,
213                                                          BasicBlock *&PosBB) {
214   tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
215 
216   // We should hoist the TLS use out of loop, so choose its nearest instruction
217   // which dominate the loop and the outside loops (if exist).
218   Instruction *LastPos = nullptr;
219   for (auto &User : Cand.Users) {
220     BasicBlock *BB = User.Inst->getParent();
221     Instruction *Pos = User.Inst;
222     if (Loop *L = LI->getLoopFor(BB)) {
223       Pos = getNearestLoopDomInst(BB, L);
224       assert(Pos && "Not find insert position out of loop!");
225     }
226     Pos = getDomInst(LastPos, Pos);
227     LastPos = Pos;
228   }
229 
230   assert(LastPos && "Unexpected insert position!");
231   BasicBlock *Parent = LastPos->getParent();
232   PosBB = Parent;
233   return LastPos->getIterator();
234 }
235 
236 // Generate a bitcast (no type change) to replace the uses of TLS Candidate.
237 Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn,
238                                                   GlobalVariable *GV) {
239   BasicBlock *PosBB = &Fn.getEntryBlock();
240   BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB);
241   Type *Ty = GV->getType();
242   auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast");
243   PosBB->getInstList().insert(Iter, CastInst);
244   return CastInst;
245 }
246 
247 bool TLSVariableHoistPass::tryReplaceTLSCandidate(Function &Fn,
248                                                   GlobalVariable *GV) {
249 
250   tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
251 
252   // If only used 1 time and not in loops, we no need to replace it.
253   if (oneUseOutsideLoop(Cand, LI))
254     return false;
255 
256   // Generate a bitcast (no type change)
257   auto *CastInst = genBitCastInst(Fn, GV);
258 
259   // to replace the uses of TLS Candidate
260   for (auto &User : Cand.Users)
261     User.Inst->setOperand(User.OpndIdx, CastInst);
262 
263   return true;
264 }
265 
266 bool TLSVariableHoistPass::tryReplaceTLSCandidates(Function &Fn) {
267   if (TLSCandMap.empty())
268     return false;
269 
270   bool Replaced = false;
271   for (auto &GV2Cand : TLSCandMap) {
272     GlobalVariable *GV = GV2Cand.first;
273     Replaced |= tryReplaceTLSCandidate(Fn, GV);
274   }
275 
276   return Replaced;
277 }
278 
279 /// Optimize expensive TLS variables in the given function.
280 bool TLSVariableHoistPass::runImpl(Function &Fn, DominatorTree &DT,
281                                    LoopInfo &LI) {
282   if (Fn.hasOptNone())
283     return false;
284 
285   if (TLSLoadHoist != "optimize" &&
286       !Fn.getAttributes().hasFnAttr("tls-load-hoist"))
287     return false;
288 
289   this->LI = &LI;
290   this->DT = &DT;
291   assert(this->LI && this->DT && "Unexcepted requirement!");
292 
293   // Collect all TLS variable candidates.
294   collectTLSCandidates(Fn);
295 
296   bool MadeChange = tryReplaceTLSCandidates(Fn);
297 
298   return MadeChange;
299 }
300 
301 PreservedAnalyses TLSVariableHoistPass::run(Function &F,
302                                             FunctionAnalysisManager &AM) {
303 
304   auto &LI = AM.getResult<LoopAnalysis>(F);
305   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
306 
307   if (!runImpl(F, DT, LI))
308     return PreservedAnalyses::all();
309 
310   PreservedAnalyses PA;
311   PA.preserveSet<CFGAnalyses>();
312   return PA;
313 }
314