1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h" 10 #include "llvm/ADT/Sequence.h" 11 #include "llvm/Analysis/LoopAccessAnalysis.h" 12 #include "llvm/Analysis/LoopAnalysisManager.h" 13 #include "llvm/Analysis/LoopInfo.h" 14 #include "llvm/Analysis/LoopIterator.h" 15 #include "llvm/Analysis/LoopPass.h" 16 #include "llvm/Analysis/MemorySSA.h" 17 #include "llvm/Analysis/MemorySSAUpdater.h" 18 #include "llvm/Analysis/ScalarEvolution.h" 19 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 20 #include "llvm/IR/PatternMatch.h" 21 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 22 #include "llvm/Transforms/Utils/Cloning.h" 23 #include "llvm/Transforms/Utils/LoopSimplify.h" 24 #include "llvm/Transforms/Utils/LoopUtils.h" 25 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" 26 27 #define DEBUG_TYPE "loop-bound-split" 28 29 namespace llvm { 30 31 using namespace PatternMatch; 32 33 namespace { 34 struct ConditionInfo { 35 /// Branch instruction with this condition 36 BranchInst *BI; 37 /// ICmp instruction with this condition 38 ICmpInst *ICmp; 39 /// Preciate info 40 ICmpInst::Predicate Pred; 41 /// AddRec llvm value 42 Value *AddRecValue; 43 /// Bound llvm value 44 Value *BoundValue; 45 /// AddRec SCEV 46 const SCEV *AddRecSCEV; 47 /// Bound SCEV 48 const SCEV *BoundSCEV; 49 50 ConditionInfo() 51 : BI(nullptr), ICmp(nullptr), Pred(ICmpInst::BAD_ICMP_PREDICATE), 52 AddRecValue(nullptr), BoundValue(nullptr), AddRecSCEV(nullptr), 53 BoundSCEV(nullptr) {} 54 }; 55 } // namespace 56 57 static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp, 58 ConditionInfo &Cond) { 59 Cond.ICmp = ICmp; 60 if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue), 61 m_Value(Cond.BoundValue)))) { 62 Cond.AddRecSCEV = SE.getSCEV(Cond.AddRecValue); 63 Cond.BoundSCEV = SE.getSCEV(Cond.BoundValue); 64 // Locate AddRec in LHSSCEV and Bound in RHSSCEV. 65 if (isa<SCEVAddRecExpr>(Cond.BoundSCEV) && 66 !isa<SCEVAddRecExpr>(Cond.AddRecSCEV)) { 67 std::swap(Cond.AddRecValue, Cond.BoundValue); 68 std::swap(Cond.AddRecSCEV, Cond.BoundSCEV); 69 Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred); 70 } 71 } 72 } 73 74 static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE, 75 ConditionInfo &Cond, bool IsExitCond) { 76 if (IsExitCond) { 77 const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent()); 78 if (isa<SCEVCouldNotCompute>(ExitCount)) 79 return false; 80 81 Cond.BoundSCEV = ExitCount; 82 return true; 83 } 84 85 // For non-exit condtion, if pred is LT, keep existing bound. 86 if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT) 87 return true; 88 89 // For non-exit condition, if pre is LE, try to convert it to LT. 90 // Range Range 91 // AddRec <= Bound --> AddRec < Bound + 1 92 if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE) 93 return false; 94 95 if (IntegerType *BoundSCEVIntType = 96 dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) { 97 unsigned BitWidth = BoundSCEVIntType->getBitWidth(); 98 APInt Max = ICmpInst::isSigned(Cond.Pred) 99 ? APInt::getSignedMaxValue(BitWidth) 100 : APInt::getMaxValue(BitWidth); 101 const SCEV *MaxSCEV = SE.getConstant(Max); 102 // Check Bound < INT_MAX 103 ICmpInst::Predicate Pred = 104 ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; 105 if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) { 106 const SCEV *BoundPlusOneSCEV = 107 SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType)); 108 Cond.BoundSCEV = BoundPlusOneSCEV; 109 Cond.Pred = Pred; 110 return true; 111 } 112 } 113 114 // ToDo: Support ICMP_NE/EQ. 115 116 return false; 117 } 118 119 static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE, 120 ICmpInst *ICmp, ConditionInfo &Cond, 121 bool IsExitCond) { 122 analyzeICmp(SE, ICmp, Cond); 123 124 // The BoundSCEV should be evaluated at loop entry. 125 if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L)) 126 return false; 127 128 const SCEVAddRecExpr *AddRecSCEV = dyn_cast<SCEVAddRecExpr>(Cond.AddRecSCEV); 129 // Allowed AddRec as induction variable. 130 if (!AddRecSCEV) 131 return false; 132 133 if (!AddRecSCEV->isAffine()) 134 return false; 135 136 const SCEV *StepRecSCEV = AddRecSCEV->getStepRecurrence(SE); 137 // Allowed constant step. 138 if (!isa<SCEVConstant>(StepRecSCEV)) 139 return false; 140 141 ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue(); 142 // Allowed positive step for now. 143 // TODO: Support negative step. 144 if (StepCI->isNegative() || StepCI->isZero()) 145 return false; 146 147 // Calculate upper bound. 148 if (!calculateUpperBound(L, SE, Cond, IsExitCond)) 149 return false; 150 151 return true; 152 } 153 154 static bool isProcessableCondBI(const ScalarEvolution &SE, 155 const BranchInst *BI) { 156 BasicBlock *TrueSucc = nullptr; 157 BasicBlock *FalseSucc = nullptr; 158 ICmpInst::Predicate Pred; 159 Value *LHS, *RHS; 160 if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), 161 m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) 162 return false; 163 164 if (!SE.isSCEVable(LHS->getType())) 165 return false; 166 assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable"); 167 168 if (TrueSucc == FalseSucc) 169 return false; 170 171 return true; 172 } 173 174 static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT, 175 ScalarEvolution &SE, ConditionInfo &Cond) { 176 // Skip function with optsize. 177 if (L.getHeader()->getParent()->hasOptSize()) 178 return false; 179 180 // Split only innermost loop. 181 if (!L.isInnermost()) 182 return false; 183 184 // Check loop is in simplified form. 185 if (!L.isLoopSimplifyForm()) 186 return false; 187 188 // Check loop is in LCSSA form. 189 if (!L.isLCSSAForm(DT)) 190 return false; 191 192 // Skip loop that cannot be cloned. 193 if (!L.isSafeToClone()) 194 return false; 195 196 BasicBlock *ExitingBB = L.getExitingBlock(); 197 // Assumed only one exiting block. 198 if (!ExitingBB) 199 return false; 200 201 BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); 202 if (!ExitingBI) 203 return false; 204 205 // Allowed only conditional branch with ICmp. 206 if (!isProcessableCondBI(SE, ExitingBI)) 207 return false; 208 209 // Check the condition is processable. 210 ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition()); 211 if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true)) 212 return false; 213 214 Cond.BI = ExitingBI; 215 return true; 216 } 217 218 static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) { 219 // If the conditional branch splits a loop into two halves, we could 220 // generally say it is profitable. 221 // 222 // ToDo: Add more profitable cases here. 223 224 // Check this branch causes diamond CFG. 225 BasicBlock *Succ0 = BI->getSuccessor(0); 226 BasicBlock *Succ1 = BI->getSuccessor(1); 227 228 BasicBlock *Succ0Succ = Succ0->getSingleSuccessor(); 229 BasicBlock *Succ1Succ = Succ1->getSingleSuccessor(); 230 if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ) 231 return false; 232 233 // ToDo: Calculate each successor's instruction cost. 234 235 return true; 236 } 237 238 static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE, 239 ConditionInfo &ExitingCond, 240 ConditionInfo &SplitCandidateCond) { 241 for (auto *BB : L.blocks()) { 242 // Skip condition of backedge. 243 if (L.getLoopLatch() == BB) 244 continue; 245 246 auto *BI = dyn_cast<BranchInst>(BB->getTerminator()); 247 if (!BI) 248 continue; 249 250 // Check conditional branch with ICmp. 251 if (!isProcessableCondBI(SE, BI)) 252 continue; 253 254 // Skip loop invariant condition. 255 if (L.isLoopInvariant(BI->getCondition())) 256 continue; 257 258 // Check the condition is processable. 259 ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition()); 260 if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond, 261 /*IsExitCond*/ false)) 262 continue; 263 264 if (ExitingCond.BoundSCEV->getType() != 265 SplitCandidateCond.BoundSCEV->getType()) 266 continue; 267 268 SplitCandidateCond.BI = BI; 269 return BI; 270 } 271 272 return nullptr; 273 } 274 275 static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI, 276 ScalarEvolution &SE, LPMUpdater &U) { 277 ConditionInfo SplitCandidateCond; 278 ConditionInfo ExitingCond; 279 280 // Check we can split this loop's bound. 281 if (!canSplitLoopBound(L, DT, SE, ExitingCond)) 282 return false; 283 284 if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond)) 285 return false; 286 287 if (!isProfitableToTransform(L, SplitCandidateCond.BI)) 288 return false; 289 290 // Now, we have a split candidate. Let's build a form as below. 291 // +--------------------+ 292 // | preheader | 293 // | set up newbound | 294 // +--------------------+ 295 // | /----------------\ 296 // +--------v----v------+ | 297 // | header |---\ | 298 // | with true condition| | | 299 // +--------------------+ | | 300 // | | | 301 // +--------v-----------+ | | 302 // | if.then.BB | | | 303 // +--------------------+ | | 304 // | | | 305 // +--------v-----------<---/ | 306 // | latch >----------/ 307 // | with newbound | 308 // +--------------------+ 309 // | 310 // +--------v-----------+ 311 // | preheader2 |--------------\ 312 // | if (AddRec i != | | 313 // | org bound) | | 314 // +--------------------+ | 315 // | /----------------\ | 316 // +--------v----v------+ | | 317 // | header2 |---\ | | 318 // | conditional branch | | | | 319 // |with false condition| | | | 320 // +--------------------+ | | | 321 // | | | | 322 // +--------v-----------+ | | | 323 // | if.then.BB2 | | | | 324 // +--------------------+ | | | 325 // | | | | 326 // +--------v-----------<---/ | | 327 // | latch2 >----------/ | 328 // | with org bound | | 329 // +--------v-----------+ | 330 // | | 331 // | +---------------+ | 332 // +--> exit <-------/ 333 // +---------------+ 334 335 // Let's create post loop. 336 SmallVector<BasicBlock *, 8> PostLoopBlocks; 337 Loop *PostLoop; 338 ValueToValueMapTy VMap; 339 BasicBlock *PreHeader = L.getLoopPreheader(); 340 BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI); 341 PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap, 342 ".split", &LI, &DT, PostLoopBlocks); 343 remapInstructionsInBlocks(PostLoopBlocks, VMap); 344 345 // Add conditional branch to check we can skip post-loop in its preheader. 346 BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader(); 347 IRBuilder<> Builder(PostLoopPreHeader); 348 Instruction *OrigBI = PostLoopPreHeader->getTerminator(); 349 ICmpInst::Predicate Pred = ICmpInst::ICMP_NE; 350 Value *Cond = 351 Builder.CreateICmp(Pred, ExitingCond.AddRecValue, ExitingCond.BoundValue); 352 Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock()); 353 OrigBI->eraseFromParent(); 354 355 // Create new loop bound and add it into preheader of pre-loop. 356 const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV; 357 const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV; 358 NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred) 359 ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV) 360 : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV); 361 362 SCEVExpander Expander( 363 SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split"); 364 Instruction *InsertPt = SplitLoopPH->getTerminator(); 365 Value *NewBoundValue = 366 Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt); 367 NewBoundValue->setName("new.bound"); 368 369 // Replace exiting bound value of pre-loop NewBound. 370 ExitingCond.ICmp->setOperand(1, NewBoundValue); 371 372 // Replace IV's start value of post-loop by NewBound. 373 for (PHINode &PN : L.getHeader()->phis()) { 374 // Find PHI with exiting condition from pre-loop. 375 if (SE.isSCEVable(PN.getType()) && isa<SCEVAddRecExpr>(SE.getSCEV(&PN))) { 376 for (Value *Op : PN.incoming_values()) { 377 if (Op == ExitingCond.AddRecValue) { 378 // Find cloned PHI for post-loop. 379 PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]); 380 PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, 381 NewBoundValue); 382 } 383 } 384 } 385 } 386 387 // Replace SplitCandidateCond.BI's condition of pre-loop by True. 388 LLVMContext &Context = PreHeader->getContext(); 389 SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context)); 390 391 // Replace cloned SplitCandidateCond.BI's condition in post-loop by False. 392 BranchInst *ClonedSplitCandidateBI = 393 cast<BranchInst>(VMap[SplitCandidateCond.BI]); 394 ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context)); 395 396 // Replace exit branch target of pre-loop by post-loop's preheader. 397 if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0)) 398 ExitingCond.BI->setSuccessor(0, PostLoopPreHeader); 399 else 400 ExitingCond.BI->setSuccessor(1, PostLoopPreHeader); 401 402 // Update phi node in exit block of post-loop. 403 for (PHINode &PN : PostLoop->getExitBlock()->phis()) { 404 for (auto i : seq<int>(0, PN.getNumOperands())) { 405 // Check incoming block is pre-loop's exiting block. 406 if (PN.getIncomingBlock(i) == L.getExitingBlock()) { 407 // Replace pre-loop's exiting block by post-loop's preheader. 408 PN.setIncomingBlock(i, PostLoopPreHeader); 409 // Add a new incoming value with post-loop's exiting block. 410 PN.addIncoming(VMap[PN.getIncomingValue(i)], 411 PostLoop->getExitingBlock()); 412 } 413 } 414 } 415 416 // Update dominator tree. 417 DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock()); 418 DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader); 419 420 // Invalidate cached SE information. 421 SE.forgetLoop(&L); 422 423 // Canonicalize loops. 424 // TODO: Try to update LCSSA information according to above change. 425 formLCSSA(L, DT, &LI, &SE); 426 simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true); 427 formLCSSA(*PostLoop, DT, &LI, &SE); 428 simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true); 429 430 // Add new post-loop to loop pass manager. 431 U.addSiblingLoops(PostLoop); 432 433 return true; 434 } 435 436 PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM, 437 LoopStandardAnalysisResults &AR, 438 LPMUpdater &U) { 439 Function &F = *L.getHeader()->getParent(); 440 (void)F; 441 442 LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L 443 << "\n"); 444 445 if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U)) 446 return PreservedAnalyses::all(); 447 448 assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast)); 449 AR.LI.verify(AR.DT); 450 451 return getLoopPassPreservedAnalyses(); 452 } 453 454 } // end namespace llvm 455