1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h" 10 #include "llvm/ADT/Sequence.h" 11 #include "llvm/Analysis/LoopAccessAnalysis.h" 12 #include "llvm/Analysis/LoopAnalysisManager.h" 13 #include "llvm/Analysis/LoopInfo.h" 14 #include "llvm/Analysis/LoopIterator.h" 15 #include "llvm/Analysis/LoopPass.h" 16 #include "llvm/Analysis/MemorySSA.h" 17 #include "llvm/Analysis/MemorySSAUpdater.h" 18 #include "llvm/Analysis/ScalarEvolution.h" 19 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 20 #include "llvm/IR/PatternMatch.h" 21 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 22 #include "llvm/Transforms/Utils/Cloning.h" 23 #include "llvm/Transforms/Utils/LoopSimplify.h" 24 #include "llvm/Transforms/Utils/LoopUtils.h" 25 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" 26 27 #define DEBUG_TYPE "loop-bound-split" 28 29 namespace llvm { 30 31 using namespace PatternMatch; 32 33 namespace { 34 struct ConditionInfo { 35 /// Branch instruction with this condition 36 BranchInst *BI = nullptr; 37 /// ICmp instruction with this condition 38 ICmpInst *ICmp = nullptr; 39 /// Preciate info 40 ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; 41 /// AddRec llvm value 42 Value *AddRecValue = nullptr; 43 /// Non PHI AddRec llvm value 44 Value *NonPHIAddRecValue; 45 /// Bound llvm value 46 Value *BoundValue = nullptr; 47 /// AddRec SCEV 48 const SCEVAddRecExpr *AddRecSCEV = nullptr; 49 /// Bound SCEV 50 const SCEV *BoundSCEV = nullptr; 51 52 ConditionInfo() = default; 53 }; 54 } // namespace 55 56 static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp, 57 ConditionInfo &Cond, const Loop &L) { 58 Cond.ICmp = ICmp; 59 if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue), 60 m_Value(Cond.BoundValue)))) { 61 const SCEV *AddRecSCEV = SE.getSCEV(Cond.AddRecValue); 62 const SCEV *BoundSCEV = SE.getSCEV(Cond.BoundValue); 63 const SCEVAddRecExpr *LHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV); 64 const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(BoundSCEV); 65 // Locate AddRec in LHSSCEV and Bound in RHSSCEV. 66 if (!LHSAddRecSCEV && RHSAddRecSCEV) { 67 std::swap(Cond.AddRecValue, Cond.BoundValue); 68 std::swap(AddRecSCEV, BoundSCEV); 69 Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred); 70 } 71 72 Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV); 73 Cond.BoundSCEV = BoundSCEV; 74 Cond.NonPHIAddRecValue = Cond.AddRecValue; 75 76 // If the Cond.AddRecValue is PHI node, update Cond.NonPHIAddRecValue with 77 // value from backedge. 78 if (Cond.AddRecSCEV && isa<PHINode>(Cond.AddRecValue)) { 79 PHINode *PN = cast<PHINode>(Cond.AddRecValue); 80 Cond.NonPHIAddRecValue = PN->getIncomingValueForBlock(L.getLoopLatch()); 81 } 82 } 83 } 84 85 static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE, 86 ConditionInfo &Cond, bool IsExitCond) { 87 if (IsExitCond) { 88 const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent()); 89 if (isa<SCEVCouldNotCompute>(ExitCount)) 90 return false; 91 92 Cond.BoundSCEV = ExitCount; 93 return true; 94 } 95 96 // For non-exit condtion, if pred is LT, keep existing bound. 97 if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT) 98 return true; 99 100 // For non-exit condition, if pre is LE, try to convert it to LT. 101 // Range Range 102 // AddRec <= Bound --> AddRec < Bound + 1 103 if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE) 104 return false; 105 106 if (IntegerType *BoundSCEVIntType = 107 dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) { 108 unsigned BitWidth = BoundSCEVIntType->getBitWidth(); 109 APInt Max = ICmpInst::isSigned(Cond.Pred) 110 ? APInt::getSignedMaxValue(BitWidth) 111 : APInt::getMaxValue(BitWidth); 112 const SCEV *MaxSCEV = SE.getConstant(Max); 113 // Check Bound < INT_MAX 114 ICmpInst::Predicate Pred = 115 ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; 116 if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) { 117 const SCEV *BoundPlusOneSCEV = 118 SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType)); 119 Cond.BoundSCEV = BoundPlusOneSCEV; 120 Cond.Pred = Pred; 121 return true; 122 } 123 } 124 125 // ToDo: Support ICMP_NE/EQ. 126 127 return false; 128 } 129 130 static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE, 131 ICmpInst *ICmp, ConditionInfo &Cond, 132 bool IsExitCond) { 133 analyzeICmp(SE, ICmp, Cond, L); 134 135 // The BoundSCEV should be evaluated at loop entry. 136 if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L)) 137 return false; 138 139 // Allowed AddRec as induction variable. 140 if (!Cond.AddRecSCEV) 141 return false; 142 143 if (!Cond.AddRecSCEV->isAffine()) 144 return false; 145 146 const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE); 147 // Allowed constant step. 148 if (!isa<SCEVConstant>(StepRecSCEV)) 149 return false; 150 151 ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue(); 152 // Allowed positive step for now. 153 // TODO: Support negative step. 154 if (StepCI->isNegative() || StepCI->isZero()) 155 return false; 156 157 // Calculate upper bound. 158 if (!calculateUpperBound(L, SE, Cond, IsExitCond)) 159 return false; 160 161 return true; 162 } 163 164 static bool isProcessableCondBI(const ScalarEvolution &SE, 165 const BranchInst *BI) { 166 BasicBlock *TrueSucc = nullptr; 167 BasicBlock *FalseSucc = nullptr; 168 ICmpInst::Predicate Pred; 169 Value *LHS, *RHS; 170 if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), 171 m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) 172 return false; 173 174 if (!SE.isSCEVable(LHS->getType())) 175 return false; 176 assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable"); 177 178 if (TrueSucc == FalseSucc) 179 return false; 180 181 return true; 182 } 183 184 static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT, 185 ScalarEvolution &SE, ConditionInfo &Cond) { 186 // Skip function with optsize. 187 if (L.getHeader()->getParent()->hasOptSize()) 188 return false; 189 190 // Split only innermost loop. 191 if (!L.isInnermost()) 192 return false; 193 194 // Check loop is in simplified form. 195 if (!L.isLoopSimplifyForm()) 196 return false; 197 198 // Check loop is in LCSSA form. 199 if (!L.isLCSSAForm(DT)) 200 return false; 201 202 // Skip loop that cannot be cloned. 203 if (!L.isSafeToClone()) 204 return false; 205 206 BasicBlock *ExitingBB = L.getExitingBlock(); 207 // Assumed only one exiting block. 208 if (!ExitingBB) 209 return false; 210 211 BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); 212 if (!ExitingBI) 213 return false; 214 215 // Allowed only conditional branch with ICmp. 216 if (!isProcessableCondBI(SE, ExitingBI)) 217 return false; 218 219 // Check the condition is processable. 220 ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition()); 221 if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true)) 222 return false; 223 224 Cond.BI = ExitingBI; 225 return true; 226 } 227 228 static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) { 229 // If the conditional branch splits a loop into two halves, we could 230 // generally say it is profitable. 231 // 232 // ToDo: Add more profitable cases here. 233 234 // Check this branch causes diamond CFG. 235 BasicBlock *Succ0 = BI->getSuccessor(0); 236 BasicBlock *Succ1 = BI->getSuccessor(1); 237 238 BasicBlock *Succ0Succ = Succ0->getSingleSuccessor(); 239 BasicBlock *Succ1Succ = Succ1->getSingleSuccessor(); 240 if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ) 241 return false; 242 243 // ToDo: Calculate each successor's instruction cost. 244 245 return true; 246 } 247 248 static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE, 249 ConditionInfo &ExitingCond, 250 ConditionInfo &SplitCandidateCond) { 251 for (auto *BB : L.blocks()) { 252 // Skip condition of backedge. 253 if (L.getLoopLatch() == BB) 254 continue; 255 256 auto *BI = dyn_cast<BranchInst>(BB->getTerminator()); 257 if (!BI) 258 continue; 259 260 // Check conditional branch with ICmp. 261 if (!isProcessableCondBI(SE, BI)) 262 continue; 263 264 // Skip loop invariant condition. 265 if (L.isLoopInvariant(BI->getCondition())) 266 continue; 267 268 // Check the condition is processable. 269 ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition()); 270 if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond, 271 /*IsExitCond*/ false)) 272 continue; 273 274 if (ExitingCond.BoundSCEV->getType() != 275 SplitCandidateCond.BoundSCEV->getType()) 276 continue; 277 278 // After transformation, we assume the split condition of the pre-loop is 279 // always true. In order to guarantee it, we need to check the start value 280 // of the split cond AddRec satisfies the split condition. 281 if (!SE.isLoopEntryGuardedByCond(&L, SplitCandidateCond.Pred, 282 SplitCandidateCond.AddRecSCEV->getStart(), 283 SplitCandidateCond.BoundSCEV)) 284 continue; 285 286 SplitCandidateCond.BI = BI; 287 return BI; 288 } 289 290 return nullptr; 291 } 292 293 static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI, 294 ScalarEvolution &SE, LPMUpdater &U) { 295 ConditionInfo SplitCandidateCond; 296 ConditionInfo ExitingCond; 297 298 // Check we can split this loop's bound. 299 if (!canSplitLoopBound(L, DT, SE, ExitingCond)) 300 return false; 301 302 if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond)) 303 return false; 304 305 if (!isProfitableToTransform(L, SplitCandidateCond.BI)) 306 return false; 307 308 // Now, we have a split candidate. Let's build a form as below. 309 // +--------------------+ 310 // | preheader | 311 // | set up newbound | 312 // +--------------------+ 313 // | /----------------\ 314 // +--------v----v------+ | 315 // | header |---\ | 316 // | with true condition| | | 317 // +--------------------+ | | 318 // | | | 319 // +--------v-----------+ | | 320 // | if.then.BB | | | 321 // +--------------------+ | | 322 // | | | 323 // +--------v-----------<---/ | 324 // | latch >----------/ 325 // | with newbound | 326 // +--------------------+ 327 // | 328 // +--------v-----------+ 329 // | preheader2 |--------------\ 330 // | if (AddRec i != | | 331 // | org bound) | | 332 // +--------------------+ | 333 // | /----------------\ | 334 // +--------v----v------+ | | 335 // | header2 |---\ | | 336 // | conditional branch | | | | 337 // |with false condition| | | | 338 // +--------------------+ | | | 339 // | | | | 340 // +--------v-----------+ | | | 341 // | if.then.BB2 | | | | 342 // +--------------------+ | | | 343 // | | | | 344 // +--------v-----------<---/ | | 345 // | latch2 >----------/ | 346 // | with org bound | | 347 // +--------v-----------+ | 348 // | | 349 // | +---------------+ | 350 // +--> exit <-------/ 351 // +---------------+ 352 353 // Let's create post loop. 354 SmallVector<BasicBlock *, 8> PostLoopBlocks; 355 Loop *PostLoop; 356 ValueToValueMapTy VMap; 357 BasicBlock *PreHeader = L.getLoopPreheader(); 358 BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI); 359 PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap, 360 ".split", &LI, &DT, PostLoopBlocks); 361 remapInstructionsInBlocks(PostLoopBlocks, VMap); 362 363 BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader(); 364 IRBuilder<> Builder(&PostLoopPreHeader->front()); 365 366 // Update phi nodes in header of post-loop. 367 bool isExitingLatch = 368 (L.getExitingBlock() == L.getLoopLatch()) ? true : false; 369 Value *ExitingCondLCSSAPhi = nullptr; 370 for (PHINode &PN : L.getHeader()->phis()) { 371 // Create LCSSA phi node in preheader of post-loop. 372 PHINode *LCSSAPhi = 373 Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa"); 374 LCSSAPhi->setDebugLoc(PN.getDebugLoc()); 375 // If the exiting block is loop latch, the phi does not have the update at 376 // last iteration. In this case, update lcssa phi with value from backedge. 377 LCSSAPhi->addIncoming( 378 isExitingLatch ? PN.getIncomingValueForBlock(L.getLoopLatch()) : &PN, 379 L.getExitingBlock()); 380 381 // Update the start value of phi node in post-loop with the LCSSA phi node. 382 PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]); 383 PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, LCSSAPhi); 384 385 // Find PHI with exiting condition from pre-loop. The PHI should be 386 // SCEVAddRecExpr and have same incoming value from backedge with 387 // ExitingCond. 388 if (!SE.isSCEVable(PN.getType())) 389 continue; 390 391 const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN)); 392 if (PhiSCEV && ExitingCond.NonPHIAddRecValue == 393 PN.getIncomingValueForBlock(L.getLoopLatch())) 394 ExitingCondLCSSAPhi = LCSSAPhi; 395 } 396 397 // Add conditional branch to check we can skip post-loop in its preheader. 398 Instruction *OrigBI = PostLoopPreHeader->getTerminator(); 399 ICmpInst::Predicate Pred = ICmpInst::ICMP_NE; 400 Value *Cond = 401 Builder.CreateICmp(Pred, ExitingCondLCSSAPhi, ExitingCond.BoundValue); 402 Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock()); 403 OrigBI->eraseFromParent(); 404 405 // Create new loop bound and add it into preheader of pre-loop. 406 const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV; 407 const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV; 408 NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred) 409 ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV) 410 : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV); 411 412 SCEVExpander Expander( 413 SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split"); 414 Instruction *InsertPt = SplitLoopPH->getTerminator(); 415 Value *NewBoundValue = 416 Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt); 417 NewBoundValue->setName("new.bound"); 418 419 // Replace exiting bound value of pre-loop NewBound. 420 ExitingCond.ICmp->setOperand(1, NewBoundValue); 421 422 // Replace SplitCandidateCond.BI's condition of pre-loop by True. 423 LLVMContext &Context = PreHeader->getContext(); 424 SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context)); 425 426 // Replace cloned SplitCandidateCond.BI's condition in post-loop by False. 427 BranchInst *ClonedSplitCandidateBI = 428 cast<BranchInst>(VMap[SplitCandidateCond.BI]); 429 ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context)); 430 431 // Replace exit branch target of pre-loop by post-loop's preheader. 432 if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0)) 433 ExitingCond.BI->setSuccessor(0, PostLoopPreHeader); 434 else 435 ExitingCond.BI->setSuccessor(1, PostLoopPreHeader); 436 437 // Update phi node in exit block of post-loop. 438 Builder.SetInsertPoint(&PostLoopPreHeader->front()); 439 for (PHINode &PN : PostLoop->getExitBlock()->phis()) { 440 for (auto i : seq<int>(0, PN.getNumOperands())) { 441 // Check incoming block is pre-loop's exiting block. 442 if (PN.getIncomingBlock(i) == L.getExitingBlock()) { 443 Value *IncomingValue = PN.getIncomingValue(i); 444 445 // Create LCSSA phi node for incoming value. 446 PHINode *LCSSAPhi = 447 Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa"); 448 LCSSAPhi->setDebugLoc(PN.getDebugLoc()); 449 LCSSAPhi->addIncoming(IncomingValue, PN.getIncomingBlock(i)); 450 451 // Replace pre-loop's exiting block by post-loop's preheader. 452 PN.setIncomingBlock(i, PostLoopPreHeader); 453 // Replace incoming value by LCSSAPhi. 454 PN.setIncomingValue(i, LCSSAPhi); 455 // Add a new incoming value with post-loop's exiting block. 456 PN.addIncoming(VMap[IncomingValue], PostLoop->getExitingBlock()); 457 } 458 } 459 } 460 461 // Update dominator tree. 462 DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock()); 463 DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader); 464 465 // Invalidate cached SE information. 466 SE.forgetLoop(&L); 467 468 // Canonicalize loops. 469 simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true); 470 simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true); 471 472 // Add new post-loop to loop pass manager. 473 U.addSiblingLoops(PostLoop); 474 475 return true; 476 } 477 478 PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM, 479 LoopStandardAnalysisResults &AR, 480 LPMUpdater &U) { 481 Function &F = *L.getHeader()->getParent(); 482 (void)F; 483 484 LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L 485 << "\n"); 486 487 if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U)) 488 return PreservedAnalyses::all(); 489 490 assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast)); 491 AR.LI.verify(AR.DT); 492 493 return getLoopPassPreservedAnalyses(); 494 } 495 496 } // end namespace llvm 497