1 //===- TLSVariableHoist.cpp -------- Remove Redundant TLS Loads ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass identifies/eliminate Redundant TLS Loads if related option is set. 10 // The example: Please refer to the comment at the head of TLSVariableHoist.h. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/ADT/SmallVector.h" 15 #include "llvm/IR/BasicBlock.h" 16 #include "llvm/IR/Dominators.h" 17 #include "llvm/IR/Function.h" 18 #include "llvm/IR/InstrTypes.h" 19 #include "llvm/IR/Instruction.h" 20 #include "llvm/IR/Instructions.h" 21 #include "llvm/IR/IntrinsicInst.h" 22 #include "llvm/IR/Module.h" 23 #include "llvm/IR/Value.h" 24 #include "llvm/InitializePasses.h" 25 #include "llvm/Pass.h" 26 #include "llvm/Support/Casting.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/raw_ostream.h" 29 #include "llvm/Transforms/Scalar.h" 30 #include "llvm/Transforms/Scalar/TLSVariableHoist.h" 31 #include <algorithm> 32 #include <cassert> 33 #include <cstdint> 34 #include <iterator> 35 #include <tuple> 36 #include <utility> 37 38 using namespace llvm; 39 using namespace tlshoist; 40 41 #define DEBUG_TYPE "tlshoist" 42 43 // TODO: Support "strict" model if we need to strictly load TLS address, 44 // because "non-optimize" may also do some optimization in other passes. 45 static cl::opt<std::string> TLSLoadHoist( 46 "tls-load-hoist", 47 cl::desc( 48 "hoist the TLS loads in PIC model: " 49 "tls-load-hoist=optimize: Eleminate redundant TLS load(s)." 50 "tls-load-hoist=strict: Strictly load TLS address before every use." 51 "tls-load-hoist=non-optimize: Generally load TLS before use(s)."), 52 cl::init("non-optimize"), cl::Hidden); 53 54 namespace { 55 56 /// The TLS Variable hoist pass. 57 class TLSVariableHoistLegacyPass : public FunctionPass { 58 public: 59 static char ID; // Pass identification, replacement for typeid 60 61 TLSVariableHoistLegacyPass() : FunctionPass(ID) { 62 initializeTLSVariableHoistLegacyPassPass(*PassRegistry::getPassRegistry()); 63 } 64 65 bool runOnFunction(Function &Fn) override; 66 67 StringRef getPassName() const override { return "TLS Variable Hoist"; } 68 69 void getAnalysisUsage(AnalysisUsage &AU) const override { 70 AU.setPreservesCFG(); 71 AU.addRequired<DominatorTreeWrapperPass>(); 72 AU.addRequired<LoopInfoWrapperPass>(); 73 } 74 75 private: 76 TLSVariableHoistPass Impl; 77 }; 78 79 } // end anonymous namespace 80 81 char TLSVariableHoistLegacyPass::ID = 0; 82 83 INITIALIZE_PASS_BEGIN(TLSVariableHoistLegacyPass, "tlshoist", 84 "TLS Variable Hoist", false, false) 85 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 86 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 87 INITIALIZE_PASS_END(TLSVariableHoistLegacyPass, "tlshoist", 88 "TLS Variable Hoist", false, false) 89 90 FunctionPass *llvm::createTLSVariableHoistPass() { 91 return new TLSVariableHoistLegacyPass(); 92 } 93 94 /// Perform the TLS Variable Hoist optimization for the given function. 95 bool TLSVariableHoistLegacyPass::runOnFunction(Function &Fn) { 96 if (skipFunction(Fn)) 97 return false; 98 99 LLVM_DEBUG(dbgs() << "********** Begin TLS Variable Hoist **********\n"); 100 LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n'); 101 102 bool MadeChange = 103 Impl.runImpl(Fn, getAnalysis<DominatorTreeWrapperPass>().getDomTree(), 104 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()); 105 106 if (MadeChange) { 107 LLVM_DEBUG(dbgs() << "********** Function after TLS Variable Hoist: " 108 << Fn.getName() << '\n'); 109 LLVM_DEBUG(dbgs() << Fn); 110 } 111 LLVM_DEBUG(dbgs() << "********** End TLS Variable Hoist **********\n"); 112 113 return MadeChange; 114 } 115 116 void TLSVariableHoistPass::collectTLSCandidate(Instruction *Inst) { 117 // Skip all cast instructions. They are visited indirectly later on. 118 if (Inst->isCast()) 119 return; 120 121 // Scan all operands. 122 for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) { 123 auto *GV = dyn_cast<GlobalVariable>(Inst->getOperand(Idx)); 124 if (!GV || !GV->isThreadLocal()) 125 continue; 126 127 // Add Candidate to TLSCandMap (GV --> Candidate). 128 TLSCandMap[GV].addUser(Inst, Idx); 129 } 130 } 131 132 void TLSVariableHoistPass::collectTLSCandidates(Function &Fn) { 133 // First, quickly check if there is TLS Variable. 134 Module *M = Fn.getParent(); 135 136 bool HasTLS = llvm::any_of( 137 M->globals(), [](GlobalVariable &GV) { return GV.isThreadLocal(); }); 138 139 // If non, directly return. 140 if (!HasTLS) 141 return; 142 143 TLSCandMap.clear(); 144 145 // Then, collect TLS Variable info. 146 for (BasicBlock &BB : Fn) { 147 // Ignore unreachable basic blocks. 148 if (!DT->isReachableFromEntry(&BB)) 149 continue; 150 151 for (Instruction &Inst : BB) 152 collectTLSCandidate(&Inst); 153 } 154 } 155 156 static bool oneUseOutsideLoop(tlshoist::TLSCandidate &Cand, LoopInfo *LI) { 157 if (Cand.Users.size() != 1) 158 return false; 159 160 BasicBlock *BB = Cand.Users[0].Inst->getParent(); 161 if (LI->getLoopFor(BB)) 162 return false; 163 164 return true; 165 } 166 167 Instruction *TLSVariableHoistPass::getNearestLoopDomInst(BasicBlock *BB, 168 Loop *L) { 169 assert(L && "Unexcepted Loop status!"); 170 171 // Get the outermost loop. 172 while (Loop *Parent = L->getParentLoop()) 173 L = Parent; 174 175 BasicBlock *PreHeader = L->getLoopPreheader(); 176 177 // There is unique predecessor outside the loop. 178 if (PreHeader) 179 return PreHeader->getTerminator(); 180 181 BasicBlock *Header = L->getHeader(); 182 BasicBlock *Dom = Header; 183 for (BasicBlock *PredBB : predecessors(Header)) 184 Dom = DT->findNearestCommonDominator(Dom, PredBB); 185 186 assert(Dom && "Not find dominator BB!"); 187 Instruction *Term = Dom->getTerminator(); 188 189 return Term; 190 } 191 192 Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1, 193 Instruction *I2) { 194 if (!I1) 195 return I2; 196 if (DT->dominates(I1, I2)) 197 return I1; 198 if (DT->dominates(I2, I1)) 199 return I2; 200 201 // If there is no dominance relation, use common dominator. 202 BasicBlock *DomBB = 203 DT->findNearestCommonDominator(I1->getParent(), I2->getParent()); 204 205 Instruction *Dom = DomBB->getTerminator(); 206 assert(Dom && "Common dominator not found!"); 207 208 return Dom; 209 } 210 211 BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn, 212 GlobalVariable *GV, 213 BasicBlock *&PosBB) { 214 tlshoist::TLSCandidate &Cand = TLSCandMap[GV]; 215 216 // We should hoist the TLS use out of loop, so choose its nearest instruction 217 // which dominate the loop and the outside loops (if exist). 218 Instruction *LastPos = nullptr; 219 for (auto &User : Cand.Users) { 220 BasicBlock *BB = User.Inst->getParent(); 221 Instruction *Pos = User.Inst; 222 if (Loop *L = LI->getLoopFor(BB)) { 223 Pos = getNearestLoopDomInst(BB, L); 224 assert(Pos && "Not find insert position out of loop!"); 225 } 226 Pos = getDomInst(LastPos, Pos); 227 LastPos = Pos; 228 } 229 230 assert(LastPos && "Unexpected insert position!"); 231 BasicBlock *Parent = LastPos->getParent(); 232 PosBB = Parent; 233 return LastPos->getIterator(); 234 } 235 236 // Generate a bitcast (no type change) to replace the uses of TLS Candidate. 237 Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn, 238 GlobalVariable *GV) { 239 BasicBlock *PosBB = &Fn.getEntryBlock(); 240 BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB); 241 Type *Ty = GV->getType(); 242 auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast"); 243 PosBB->getInstList().insert(Iter, CastInst); 244 return CastInst; 245 } 246 247 bool TLSVariableHoistPass::tryReplaceTLSCandidate(Function &Fn, 248 GlobalVariable *GV) { 249 250 tlshoist::TLSCandidate &Cand = TLSCandMap[GV]; 251 252 // If only used 1 time and not in loops, we no need to replace it. 253 if (oneUseOutsideLoop(Cand, LI)) 254 return false; 255 256 // Generate a bitcast (no type change) 257 auto *CastInst = genBitCastInst(Fn, GV); 258 259 // to replace the uses of TLS Candidate 260 for (auto &User : Cand.Users) 261 User.Inst->setOperand(User.OpndIdx, CastInst); 262 263 return true; 264 } 265 266 bool TLSVariableHoistPass::tryReplaceTLSCandidates(Function &Fn) { 267 if (TLSCandMap.empty()) 268 return false; 269 270 bool Replaced = false; 271 for (auto &GV2Cand : TLSCandMap) { 272 GlobalVariable *GV = GV2Cand.first; 273 Replaced |= tryReplaceTLSCandidate(Fn, GV); 274 } 275 276 return Replaced; 277 } 278 279 /// Optimize expensive TLS variables in the given function. 280 bool TLSVariableHoistPass::runImpl(Function &Fn, DominatorTree &DT, 281 LoopInfo &LI) { 282 if (Fn.hasOptNone()) 283 return false; 284 285 if (TLSLoadHoist != "optimize" && 286 !Fn.getAttributes().hasFnAttr("tls-load-hoist")) 287 return false; 288 289 this->LI = &LI; 290 this->DT = &DT; 291 assert(this->LI && this->DT && "Unexcepted requirement!"); 292 293 // Collect all TLS variable candidates. 294 collectTLSCandidates(Fn); 295 296 bool MadeChange = tryReplaceTLSCandidates(Fn); 297 298 return MadeChange; 299 } 300 301 PreservedAnalyses TLSVariableHoistPass::run(Function &F, 302 FunctionAnalysisManager &AM) { 303 304 auto &LI = AM.getResult<LoopAnalysis>(F); 305 auto &DT = AM.getResult<DominatorTreeAnalysis>(F); 306 307 if (!runImpl(F, DT, LI)) 308 return PreservedAnalyses::all(); 309 310 PreservedAnalyses PA; 311 PA.preserveSet<CFGAnalyses>(); 312 return PA; 313 } 314