1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h" 10 #include "llvm/ADT/Sequence.h" 11 #include "llvm/Analysis/LoopAnalysisManager.h" 12 #include "llvm/Analysis/LoopInfo.h" 13 #include "llvm/Analysis/MemorySSAUpdater.h" 14 #include "llvm/Analysis/ScalarEvolution.h" 15 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 16 #include "llvm/IR/PatternMatch.h" 17 #include "llvm/Transforms/Scalar/LoopPassManager.h" 18 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 19 #include "llvm/Transforms/Utils/Cloning.h" 20 #include "llvm/Transforms/Utils/LoopSimplify.h" 21 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" 22 23 #define DEBUG_TYPE "loop-bound-split" 24 25 namespace llvm { 26 27 using namespace PatternMatch; 28 29 namespace { 30 struct ConditionInfo { 31 /// Branch instruction with this condition 32 BranchInst *BI = nullptr; 33 /// ICmp instruction with this condition 34 ICmpInst *ICmp = nullptr; 35 /// Preciate info 36 ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; 37 /// AddRec llvm value 38 Value *AddRecValue = nullptr; 39 /// Non PHI AddRec llvm value 40 Value *NonPHIAddRecValue; 41 /// Bound llvm value 42 Value *BoundValue = nullptr; 43 /// AddRec SCEV 44 const SCEVAddRecExpr *AddRecSCEV = nullptr; 45 /// Bound SCEV 46 const SCEV *BoundSCEV = nullptr; 47 48 ConditionInfo() = default; 49 }; 50 } // namespace 51 52 static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp, 53 ConditionInfo &Cond, const Loop &L) { 54 Cond.ICmp = ICmp; 55 if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue), 56 m_Value(Cond.BoundValue)))) { 57 const SCEV *AddRecSCEV = SE.getSCEV(Cond.AddRecValue); 58 const SCEV *BoundSCEV = SE.getSCEV(Cond.BoundValue); 59 const SCEVAddRecExpr *LHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV); 60 const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(BoundSCEV); 61 // Locate AddRec in LHSSCEV and Bound in RHSSCEV. 62 if (!LHSAddRecSCEV && RHSAddRecSCEV) { 63 std::swap(Cond.AddRecValue, Cond.BoundValue); 64 std::swap(AddRecSCEV, BoundSCEV); 65 Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred); 66 } 67 68 Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV); 69 Cond.BoundSCEV = BoundSCEV; 70 Cond.NonPHIAddRecValue = Cond.AddRecValue; 71 72 // If the Cond.AddRecValue is PHI node, update Cond.NonPHIAddRecValue with 73 // value from backedge. 74 if (Cond.AddRecSCEV && isa<PHINode>(Cond.AddRecValue)) { 75 PHINode *PN = cast<PHINode>(Cond.AddRecValue); 76 Cond.NonPHIAddRecValue = PN->getIncomingValueForBlock(L.getLoopLatch()); 77 } 78 } 79 } 80 81 static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE, 82 ConditionInfo &Cond, bool IsExitCond) { 83 if (IsExitCond) { 84 const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent()); 85 if (isa<SCEVCouldNotCompute>(ExitCount)) 86 return false; 87 88 Cond.BoundSCEV = ExitCount; 89 return true; 90 } 91 92 // For non-exit condtion, if pred is LT, keep existing bound. 93 if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT) 94 return true; 95 96 // For non-exit condition, if pre is LE, try to convert it to LT. 97 // Range Range 98 // AddRec <= Bound --> AddRec < Bound + 1 99 if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE) 100 return false; 101 102 if (IntegerType *BoundSCEVIntType = 103 dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) { 104 unsigned BitWidth = BoundSCEVIntType->getBitWidth(); 105 APInt Max = ICmpInst::isSigned(Cond.Pred) 106 ? APInt::getSignedMaxValue(BitWidth) 107 : APInt::getMaxValue(BitWidth); 108 const SCEV *MaxSCEV = SE.getConstant(Max); 109 // Check Bound < INT_MAX 110 ICmpInst::Predicate Pred = 111 ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; 112 if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) { 113 const SCEV *BoundPlusOneSCEV = 114 SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType)); 115 Cond.BoundSCEV = BoundPlusOneSCEV; 116 Cond.Pred = Pred; 117 return true; 118 } 119 } 120 121 // ToDo: Support ICMP_NE/EQ. 122 123 return false; 124 } 125 126 static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE, 127 ICmpInst *ICmp, ConditionInfo &Cond, 128 bool IsExitCond) { 129 analyzeICmp(SE, ICmp, Cond, L); 130 131 // The BoundSCEV should be evaluated at loop entry. 132 if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L)) 133 return false; 134 135 // Allowed AddRec as induction variable. 136 if (!Cond.AddRecSCEV) 137 return false; 138 139 if (!Cond.AddRecSCEV->isAffine()) 140 return false; 141 142 const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE); 143 // Allowed constant step. 144 if (!isa<SCEVConstant>(StepRecSCEV)) 145 return false; 146 147 ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue(); 148 // Allowed positive step for now. 149 // TODO: Support negative step. 150 if (StepCI->isNegative() || StepCI->isZero()) 151 return false; 152 153 // Calculate upper bound. 154 if (!calculateUpperBound(L, SE, Cond, IsExitCond)) 155 return false; 156 157 return true; 158 } 159 160 static bool isProcessableCondBI(const ScalarEvolution &SE, 161 const BranchInst *BI) { 162 BasicBlock *TrueSucc = nullptr; 163 BasicBlock *FalseSucc = nullptr; 164 ICmpInst::Predicate Pred; 165 Value *LHS, *RHS; 166 if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), 167 m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) 168 return false; 169 170 if (!SE.isSCEVable(LHS->getType())) 171 return false; 172 assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable"); 173 174 if (TrueSucc == FalseSucc) 175 return false; 176 177 return true; 178 } 179 180 static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT, 181 ScalarEvolution &SE, ConditionInfo &Cond) { 182 // Skip function with optsize. 183 if (L.getHeader()->getParent()->hasOptSize()) 184 return false; 185 186 // Split only innermost loop. 187 if (!L.isInnermost()) 188 return false; 189 190 // Check loop is in simplified form. 191 if (!L.isLoopSimplifyForm()) 192 return false; 193 194 // Check loop is in LCSSA form. 195 if (!L.isLCSSAForm(DT)) 196 return false; 197 198 // Skip loop that cannot be cloned. 199 if (!L.isSafeToClone()) 200 return false; 201 202 BasicBlock *ExitingBB = L.getExitingBlock(); 203 // Assumed only one exiting block. 204 if (!ExitingBB) 205 return false; 206 207 BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); 208 if (!ExitingBI) 209 return false; 210 211 // Allowed only conditional branch with ICmp. 212 if (!isProcessableCondBI(SE, ExitingBI)) 213 return false; 214 215 // Check the condition is processable. 216 ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition()); 217 if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true)) 218 return false; 219 220 Cond.BI = ExitingBI; 221 return true; 222 } 223 224 static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) { 225 // If the conditional branch splits a loop into two halves, we could 226 // generally say it is profitable. 227 // 228 // ToDo: Add more profitable cases here. 229 230 // Check this branch causes diamond CFG. 231 BasicBlock *Succ0 = BI->getSuccessor(0); 232 BasicBlock *Succ1 = BI->getSuccessor(1); 233 234 BasicBlock *Succ0Succ = Succ0->getSingleSuccessor(); 235 BasicBlock *Succ1Succ = Succ1->getSingleSuccessor(); 236 if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ) 237 return false; 238 239 // ToDo: Calculate each successor's instruction cost. 240 241 return true; 242 } 243 244 static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE, 245 ConditionInfo &ExitingCond, 246 ConditionInfo &SplitCandidateCond) { 247 for (auto *BB : L.blocks()) { 248 // Skip condition of backedge. 249 if (L.getLoopLatch() == BB) 250 continue; 251 252 auto *BI = dyn_cast<BranchInst>(BB->getTerminator()); 253 if (!BI) 254 continue; 255 256 // Check conditional branch with ICmp. 257 if (!isProcessableCondBI(SE, BI)) 258 continue; 259 260 // Skip loop invariant condition. 261 if (L.isLoopInvariant(BI->getCondition())) 262 continue; 263 264 // Check the condition is processable. 265 ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition()); 266 if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond, 267 /*IsExitCond*/ false)) 268 continue; 269 270 if (ExitingCond.BoundSCEV->getType() != 271 SplitCandidateCond.BoundSCEV->getType()) 272 continue; 273 274 // After transformation, we assume the split condition of the pre-loop is 275 // always true. In order to guarantee it, we need to check the start value 276 // of the split cond AddRec satisfies the split condition. 277 if (!SE.isLoopEntryGuardedByCond(&L, SplitCandidateCond.Pred, 278 SplitCandidateCond.AddRecSCEV->getStart(), 279 SplitCandidateCond.BoundSCEV)) 280 continue; 281 282 SplitCandidateCond.BI = BI; 283 return BI; 284 } 285 286 return nullptr; 287 } 288 289 static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI, 290 ScalarEvolution &SE, LPMUpdater &U) { 291 ConditionInfo SplitCandidateCond; 292 ConditionInfo ExitingCond; 293 294 // Check we can split this loop's bound. 295 if (!canSplitLoopBound(L, DT, SE, ExitingCond)) 296 return false; 297 298 if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond)) 299 return false; 300 301 if (!isProfitableToTransform(L, SplitCandidateCond.BI)) 302 return false; 303 304 // Now, we have a split candidate. Let's build a form as below. 305 // +--------------------+ 306 // | preheader | 307 // | set up newbound | 308 // +--------------------+ 309 // | /----------------\ 310 // +--------v----v------+ | 311 // | header |---\ | 312 // | with true condition| | | 313 // +--------------------+ | | 314 // | | | 315 // +--------v-----------+ | | 316 // | if.then.BB | | | 317 // +--------------------+ | | 318 // | | | 319 // +--------v-----------<---/ | 320 // | latch >----------/ 321 // | with newbound | 322 // +--------------------+ 323 // | 324 // +--------v-----------+ 325 // | preheader2 |--------------\ 326 // | if (AddRec i != | | 327 // | org bound) | | 328 // +--------------------+ | 329 // | /----------------\ | 330 // +--------v----v------+ | | 331 // | header2 |---\ | | 332 // | conditional branch | | | | 333 // |with false condition| | | | 334 // +--------------------+ | | | 335 // | | | | 336 // +--------v-----------+ | | | 337 // | if.then.BB2 | | | | 338 // +--------------------+ | | | 339 // | | | | 340 // +--------v-----------<---/ | | 341 // | latch2 >----------/ | 342 // | with org bound | | 343 // +--------v-----------+ | 344 // | | 345 // | +---------------+ | 346 // +--> exit <-------/ 347 // +---------------+ 348 349 // Let's create post loop. 350 SmallVector<BasicBlock *, 8> PostLoopBlocks; 351 Loop *PostLoop; 352 ValueToValueMapTy VMap; 353 BasicBlock *PreHeader = L.getLoopPreheader(); 354 BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI); 355 PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap, 356 ".split", &LI, &DT, PostLoopBlocks); 357 remapInstructionsInBlocks(PostLoopBlocks, VMap); 358 359 BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader(); 360 IRBuilder<> Builder(&PostLoopPreHeader->front()); 361 362 // Update phi nodes in header of post-loop. 363 bool isExitingLatch = 364 (L.getExitingBlock() == L.getLoopLatch()) ? true : false; 365 Value *ExitingCondLCSSAPhi = nullptr; 366 for (PHINode &PN : L.getHeader()->phis()) { 367 // Create LCSSA phi node in preheader of post-loop. 368 PHINode *LCSSAPhi = 369 Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa"); 370 LCSSAPhi->setDebugLoc(PN.getDebugLoc()); 371 // If the exiting block is loop latch, the phi does not have the update at 372 // last iteration. In this case, update lcssa phi with value from backedge. 373 LCSSAPhi->addIncoming( 374 isExitingLatch ? PN.getIncomingValueForBlock(L.getLoopLatch()) : &PN, 375 L.getExitingBlock()); 376 377 // Update the start value of phi node in post-loop with the LCSSA phi node. 378 PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]); 379 PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, LCSSAPhi); 380 381 // Find PHI with exiting condition from pre-loop. The PHI should be 382 // SCEVAddRecExpr and have same incoming value from backedge with 383 // ExitingCond. 384 if (!SE.isSCEVable(PN.getType())) 385 continue; 386 387 const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN)); 388 if (PhiSCEV && ExitingCond.NonPHIAddRecValue == 389 PN.getIncomingValueForBlock(L.getLoopLatch())) 390 ExitingCondLCSSAPhi = LCSSAPhi; 391 } 392 393 // Add conditional branch to check we can skip post-loop in its preheader. 394 Instruction *OrigBI = PostLoopPreHeader->getTerminator(); 395 ICmpInst::Predicate Pred = ICmpInst::ICMP_NE; 396 Value *Cond = 397 Builder.CreateICmp(Pred, ExitingCondLCSSAPhi, ExitingCond.BoundValue); 398 Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock()); 399 OrigBI->eraseFromParent(); 400 401 // Create new loop bound and add it into preheader of pre-loop. 402 const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV; 403 const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV; 404 NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred) 405 ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV) 406 : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV); 407 408 SCEVExpander Expander( 409 SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split"); 410 Instruction *InsertPt = SplitLoopPH->getTerminator(); 411 Value *NewBoundValue = 412 Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt); 413 NewBoundValue->setName("new.bound"); 414 415 // Replace exiting bound value of pre-loop NewBound. 416 ExitingCond.ICmp->setOperand(1, NewBoundValue); 417 418 // Replace SplitCandidateCond.BI's condition of pre-loop by True. 419 LLVMContext &Context = PreHeader->getContext(); 420 SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context)); 421 422 // Replace cloned SplitCandidateCond.BI's condition in post-loop by False. 423 BranchInst *ClonedSplitCandidateBI = 424 cast<BranchInst>(VMap[SplitCandidateCond.BI]); 425 ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context)); 426 427 // Replace exit branch target of pre-loop by post-loop's preheader. 428 if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0)) 429 ExitingCond.BI->setSuccessor(0, PostLoopPreHeader); 430 else 431 ExitingCond.BI->setSuccessor(1, PostLoopPreHeader); 432 433 // Update phi node in exit block of post-loop. 434 Builder.SetInsertPoint(&PostLoopPreHeader->front()); 435 for (PHINode &PN : PostLoop->getExitBlock()->phis()) { 436 for (auto i : seq<int>(0, PN.getNumOperands())) { 437 // Check incoming block is pre-loop's exiting block. 438 if (PN.getIncomingBlock(i) == L.getExitingBlock()) { 439 Value *IncomingValue = PN.getIncomingValue(i); 440 441 // Create LCSSA phi node for incoming value. 442 PHINode *LCSSAPhi = 443 Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa"); 444 LCSSAPhi->setDebugLoc(PN.getDebugLoc()); 445 LCSSAPhi->addIncoming(IncomingValue, PN.getIncomingBlock(i)); 446 447 // Replace pre-loop's exiting block by post-loop's preheader. 448 PN.setIncomingBlock(i, PostLoopPreHeader); 449 // Replace incoming value by LCSSAPhi. 450 PN.setIncomingValue(i, LCSSAPhi); 451 // Add a new incoming value with post-loop's exiting block. 452 PN.addIncoming(VMap[IncomingValue], PostLoop->getExitingBlock()); 453 } 454 } 455 } 456 457 // Update dominator tree. 458 DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock()); 459 DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader); 460 461 // Invalidate cached SE information. 462 SE.forgetLoop(&L); 463 464 // Canonicalize loops. 465 simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true); 466 simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true); 467 468 // Add new post-loop to loop pass manager. 469 U.addSiblingLoops(PostLoop); 470 471 return true; 472 } 473 474 PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM, 475 LoopStandardAnalysisResults &AR, 476 LPMUpdater &U) { 477 Function &F = *L.getHeader()->getParent(); 478 (void)F; 479 480 LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L 481 << "\n"); 482 483 if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U)) 484 return PreservedAnalyses::all(); 485 486 assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast)); 487 AR.LI.verify(AR.DT); 488 489 return getLoopPassPreservedAnalyses(); 490 } 491 492 } // end namespace llvm 493