18fb3d57eSArtur Pilipenko //===-- LoopPredication.cpp - Guard based loop predication pass -----------===//
28fb3d57eSArtur Pilipenko //
38fb3d57eSArtur Pilipenko //                     The LLVM Compiler Infrastructure
48fb3d57eSArtur Pilipenko //
58fb3d57eSArtur Pilipenko // This file is distributed under the University of Illinois Open Source
68fb3d57eSArtur Pilipenko // License. See LICENSE.TXT for details.
78fb3d57eSArtur Pilipenko //
88fb3d57eSArtur Pilipenko //===----------------------------------------------------------------------===//
98fb3d57eSArtur Pilipenko //
108fb3d57eSArtur Pilipenko // The LoopPredication pass tries to convert loop variant range checks to loop
118fb3d57eSArtur Pilipenko // invariant by widening checks across loop iterations. For example, it will
128fb3d57eSArtur Pilipenko // convert
138fb3d57eSArtur Pilipenko //
148fb3d57eSArtur Pilipenko //   for (i = 0; i < n; i++) {
158fb3d57eSArtur Pilipenko //     guard(i < len);
168fb3d57eSArtur Pilipenko //     ...
178fb3d57eSArtur Pilipenko //   }
188fb3d57eSArtur Pilipenko //
198fb3d57eSArtur Pilipenko // to
208fb3d57eSArtur Pilipenko //
218fb3d57eSArtur Pilipenko //   for (i = 0; i < n; i++) {
228fb3d57eSArtur Pilipenko //     guard(n - 1 < len);
238fb3d57eSArtur Pilipenko //     ...
248fb3d57eSArtur Pilipenko //   }
258fb3d57eSArtur Pilipenko //
268fb3d57eSArtur Pilipenko // After this transformation the condition of the guard is loop invariant, so
278fb3d57eSArtur Pilipenko // loop-unswitch can later unswitch the loop by this condition which basically
288fb3d57eSArtur Pilipenko // predicates the loop by the widened condition:
298fb3d57eSArtur Pilipenko //
308fb3d57eSArtur Pilipenko //   if (n - 1 < len)
318fb3d57eSArtur Pilipenko //     for (i = 0; i < n; i++) {
328fb3d57eSArtur Pilipenko //       ...
338fb3d57eSArtur Pilipenko //     }
348fb3d57eSArtur Pilipenko //   else
358fb3d57eSArtur Pilipenko //     deoptimize
368fb3d57eSArtur Pilipenko //
378fb3d57eSArtur Pilipenko //===----------------------------------------------------------------------===//
388fb3d57eSArtur Pilipenko 
398fb3d57eSArtur Pilipenko #include "llvm/Transforms/Scalar/LoopPredication.h"
408fb3d57eSArtur Pilipenko #include "llvm/Pass.h"
418fb3d57eSArtur Pilipenko #include "llvm/Analysis/LoopInfo.h"
428fb3d57eSArtur Pilipenko #include "llvm/Analysis/LoopPass.h"
438fb3d57eSArtur Pilipenko #include "llvm/Analysis/ScalarEvolution.h"
448fb3d57eSArtur Pilipenko #include "llvm/Analysis/ScalarEvolutionExpander.h"
458fb3d57eSArtur Pilipenko #include "llvm/Analysis/ScalarEvolutionExpressions.h"
468fb3d57eSArtur Pilipenko #include "llvm/IR/Function.h"
478fb3d57eSArtur Pilipenko #include "llvm/IR/GlobalValue.h"
488fb3d57eSArtur Pilipenko #include "llvm/IR/IntrinsicInst.h"
498fb3d57eSArtur Pilipenko #include "llvm/IR/Module.h"
508fb3d57eSArtur Pilipenko #include "llvm/IR/PatternMatch.h"
518fb3d57eSArtur Pilipenko #include "llvm/Support/Debug.h"
528fb3d57eSArtur Pilipenko #include "llvm/Transforms/Scalar.h"
538fb3d57eSArtur Pilipenko #include "llvm/Transforms/Utils/LoopUtils.h"
548fb3d57eSArtur Pilipenko 
558fb3d57eSArtur Pilipenko #define DEBUG_TYPE "loop-predication"
568fb3d57eSArtur Pilipenko 
578fb3d57eSArtur Pilipenko using namespace llvm;
588fb3d57eSArtur Pilipenko 
598fb3d57eSArtur Pilipenko namespace {
608fb3d57eSArtur Pilipenko class LoopPredication {
618fb3d57eSArtur Pilipenko   ScalarEvolution *SE;
628fb3d57eSArtur Pilipenko 
638fb3d57eSArtur Pilipenko   Loop *L;
648fb3d57eSArtur Pilipenko   const DataLayout *DL;
658fb3d57eSArtur Pilipenko   BasicBlock *Preheader;
668fb3d57eSArtur Pilipenko 
678fb3d57eSArtur Pilipenko   Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
688fb3d57eSArtur Pilipenko                                         IRBuilder<> &Builder);
698fb3d57eSArtur Pilipenko   bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
708fb3d57eSArtur Pilipenko 
718fb3d57eSArtur Pilipenko public:
728fb3d57eSArtur Pilipenko   LoopPredication(ScalarEvolution *SE) : SE(SE){};
738fb3d57eSArtur Pilipenko   bool runOnLoop(Loop *L);
748fb3d57eSArtur Pilipenko };
758fb3d57eSArtur Pilipenko 
768fb3d57eSArtur Pilipenko class LoopPredicationLegacyPass : public LoopPass {
778fb3d57eSArtur Pilipenko public:
788fb3d57eSArtur Pilipenko   static char ID;
798fb3d57eSArtur Pilipenko   LoopPredicationLegacyPass() : LoopPass(ID) {
808fb3d57eSArtur Pilipenko     initializeLoopPredicationLegacyPassPass(*PassRegistry::getPassRegistry());
818fb3d57eSArtur Pilipenko   }
828fb3d57eSArtur Pilipenko 
838fb3d57eSArtur Pilipenko   void getAnalysisUsage(AnalysisUsage &AU) const override {
848fb3d57eSArtur Pilipenko     getLoopAnalysisUsage(AU);
858fb3d57eSArtur Pilipenko   }
868fb3d57eSArtur Pilipenko 
878fb3d57eSArtur Pilipenko   bool runOnLoop(Loop *L, LPPassManager &LPM) override {
888fb3d57eSArtur Pilipenko     if (skipLoop(L))
898fb3d57eSArtur Pilipenko       return false;
908fb3d57eSArtur Pilipenko     auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
918fb3d57eSArtur Pilipenko     LoopPredication LP(SE);
928fb3d57eSArtur Pilipenko     return LP.runOnLoop(L);
938fb3d57eSArtur Pilipenko   }
948fb3d57eSArtur Pilipenko };
958fb3d57eSArtur Pilipenko 
968fb3d57eSArtur Pilipenko char LoopPredicationLegacyPass::ID = 0;
978fb3d57eSArtur Pilipenko } // end namespace llvm
988fb3d57eSArtur Pilipenko 
998fb3d57eSArtur Pilipenko INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication",
1008fb3d57eSArtur Pilipenko                       "Loop predication", false, false)
1018fb3d57eSArtur Pilipenko INITIALIZE_PASS_DEPENDENCY(LoopPass)
1028fb3d57eSArtur Pilipenko INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication",
1038fb3d57eSArtur Pilipenko                     "Loop predication", false, false)
1048fb3d57eSArtur Pilipenko 
1058fb3d57eSArtur Pilipenko Pass *llvm::createLoopPredicationPass() {
1068fb3d57eSArtur Pilipenko   return new LoopPredicationLegacyPass();
1078fb3d57eSArtur Pilipenko }
1088fb3d57eSArtur Pilipenko 
1098fb3d57eSArtur Pilipenko PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
1108fb3d57eSArtur Pilipenko                                            LoopStandardAnalysisResults &AR,
1118fb3d57eSArtur Pilipenko                                            LPMUpdater &U) {
1128fb3d57eSArtur Pilipenko   LoopPredication LP(&AR.SE);
1138fb3d57eSArtur Pilipenko   if (!LP.runOnLoop(&L))
1148fb3d57eSArtur Pilipenko     return PreservedAnalyses::all();
1158fb3d57eSArtur Pilipenko 
1168fb3d57eSArtur Pilipenko   return getLoopPassPreservedAnalyses();
1178fb3d57eSArtur Pilipenko }
1188fb3d57eSArtur Pilipenko 
1198fb3d57eSArtur Pilipenko /// If ICI can be widened to a loop invariant condition emits the loop
1208fb3d57eSArtur Pilipenko /// invariant condition in the loop preheader and return it, otherwise
1218fb3d57eSArtur Pilipenko /// returns None.
1228fb3d57eSArtur Pilipenko Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
1238fb3d57eSArtur Pilipenko                                                        SCEVExpander &Expander,
1248fb3d57eSArtur Pilipenko                                                        IRBuilder<> &Builder) {
1258fb3d57eSArtur Pilipenko   DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
1268fb3d57eSArtur Pilipenko   DEBUG(ICI->dump());
1278fb3d57eSArtur Pilipenko 
1288fb3d57eSArtur Pilipenko   ICmpInst::Predicate Pred = ICI->getPredicate();
1298fb3d57eSArtur Pilipenko   Value *LHS = ICI->getOperand(0);
1308fb3d57eSArtur Pilipenko   Value *RHS = ICI->getOperand(1);
1318fb3d57eSArtur Pilipenko   const SCEV *LHSS = SE->getSCEV(LHS);
1328fb3d57eSArtur Pilipenko   if (isa<SCEVCouldNotCompute>(LHSS))
1338fb3d57eSArtur Pilipenko     return None;
1348fb3d57eSArtur Pilipenko   const SCEV *RHSS = SE->getSCEV(RHS);
1358fb3d57eSArtur Pilipenko   if (isa<SCEVCouldNotCompute>(RHSS))
1368fb3d57eSArtur Pilipenko     return None;
1378fb3d57eSArtur Pilipenko 
1388fb3d57eSArtur Pilipenko   // Canonicalize RHS to be loop invariant bound, LHS - a loop computable index
1398fb3d57eSArtur Pilipenko   if (SE->isLoopInvariant(LHSS, L)) {
1408fb3d57eSArtur Pilipenko     std::swap(LHS, RHS);
1418fb3d57eSArtur Pilipenko     std::swap(LHSS, RHSS);
1428fb3d57eSArtur Pilipenko     Pred = ICmpInst::getSwappedPredicate(Pred);
1438fb3d57eSArtur Pilipenko   }
1448fb3d57eSArtur Pilipenko   if (!SE->isLoopInvariant(RHSS, L))
1458fb3d57eSArtur Pilipenko     return None;
1468fb3d57eSArtur Pilipenko 
1478fb3d57eSArtur Pilipenko   Value *Bound = RHS;
1488fb3d57eSArtur Pilipenko   const SCEVAddRecExpr *IndexAR = dyn_cast<SCEVAddRecExpr>(LHSS);
1498fb3d57eSArtur Pilipenko   if (!IndexAR || IndexAR->getLoop() != L)
1508fb3d57eSArtur Pilipenko     return None;
1518fb3d57eSArtur Pilipenko 
1528fb3d57eSArtur Pilipenko   DEBUG(dbgs() << "IndexAR: ");
1538fb3d57eSArtur Pilipenko   DEBUG(IndexAR->dump());
1548fb3d57eSArtur Pilipenko 
1558fb3d57eSArtur Pilipenko   bool IsIncreasing = false;
1568fb3d57eSArtur Pilipenko   if (!SE->isMonotonicPredicate(IndexAR, Pred, IsIncreasing))
1578fb3d57eSArtur Pilipenko     return None;
1588fb3d57eSArtur Pilipenko 
1598fb3d57eSArtur Pilipenko   // If the predicate is increasing the condition can change from false to true
1608fb3d57eSArtur Pilipenko   // as the loop progresses, in this case take the value on the first iteration
1618fb3d57eSArtur Pilipenko   // for the widened check. Otherwise the condition can change from true to
1628fb3d57eSArtur Pilipenko   // false as the loop progresses, so take the value on the last iteration.
1638fb3d57eSArtur Pilipenko   const SCEV *NewLHSS = IsIncreasing
1648fb3d57eSArtur Pilipenko                             ? IndexAR->getStart()
1658fb3d57eSArtur Pilipenko                             : SE->getSCEVAtScope(IndexAR, L->getParentLoop());
1668fb3d57eSArtur Pilipenko   if (NewLHSS == IndexAR) {
167*2cbaded5SArtur Pilipenko     DEBUG(dbgs() << "Can't compute NewLHSS!\n");
1688fb3d57eSArtur Pilipenko     return None;
1698fb3d57eSArtur Pilipenko   }
1708fb3d57eSArtur Pilipenko 
1718fb3d57eSArtur Pilipenko   DEBUG(dbgs() << "NewLHSS: ");
1728fb3d57eSArtur Pilipenko   DEBUG(NewLHSS->dump());
1738fb3d57eSArtur Pilipenko 
1748fb3d57eSArtur Pilipenko   if (!SE->isLoopInvariant(NewLHSS, L) || !isSafeToExpand(NewLHSS, *SE))
1758fb3d57eSArtur Pilipenko     return None;
1768fb3d57eSArtur Pilipenko 
1778fb3d57eSArtur Pilipenko   DEBUG(dbgs() << "NewLHSS is loop invariant and safe to expand. Expand!\n");
1788fb3d57eSArtur Pilipenko 
1798fb3d57eSArtur Pilipenko   Value *NewLHS = Expander.expandCodeFor(NewLHSS, Bound->getType(),
1808fb3d57eSArtur Pilipenko                                          Preheader->getTerminator());
1818fb3d57eSArtur Pilipenko   return Builder.CreateICmp(Pred, NewLHS, Bound);
1828fb3d57eSArtur Pilipenko }
1838fb3d57eSArtur Pilipenko 
1848fb3d57eSArtur Pilipenko bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
1858fb3d57eSArtur Pilipenko                                            SCEVExpander &Expander) {
1868fb3d57eSArtur Pilipenko   DEBUG(dbgs() << "Processing guard:\n");
1878fb3d57eSArtur Pilipenko   DEBUG(Guard->dump());
1888fb3d57eSArtur Pilipenko 
1898fb3d57eSArtur Pilipenko   IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator()));
1908fb3d57eSArtur Pilipenko 
1918fb3d57eSArtur Pilipenko   // The guard condition is expected to be in form of:
1928fb3d57eSArtur Pilipenko   //   cond1 && cond2 && cond3 ...
1938fb3d57eSArtur Pilipenko   // Iterate over subconditions looking for for icmp conditions which can be
1948fb3d57eSArtur Pilipenko   // widened across loop iterations. Widening these conditions remember the
1958fb3d57eSArtur Pilipenko   // resulting list of subconditions in Checks vector.
1968fb3d57eSArtur Pilipenko   SmallVector<Value *, 4> Worklist(1, Guard->getOperand(0));
1978fb3d57eSArtur Pilipenko   SmallPtrSet<Value *, 4> Visited;
1988fb3d57eSArtur Pilipenko 
1998fb3d57eSArtur Pilipenko   SmallVector<Value *, 4> Checks;
2008fb3d57eSArtur Pilipenko 
2018fb3d57eSArtur Pilipenko   unsigned NumWidened = 0;
2028fb3d57eSArtur Pilipenko   do {
2038fb3d57eSArtur Pilipenko     Value *Condition = Worklist.pop_back_val();
2048fb3d57eSArtur Pilipenko     if (!Visited.insert(Condition).second)
2058fb3d57eSArtur Pilipenko       continue;
2068fb3d57eSArtur Pilipenko 
2078fb3d57eSArtur Pilipenko     Value *LHS, *RHS;
2088fb3d57eSArtur Pilipenko     using namespace llvm::PatternMatch;
2098fb3d57eSArtur Pilipenko     if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) {
2108fb3d57eSArtur Pilipenko       Worklist.push_back(LHS);
2118fb3d57eSArtur Pilipenko       Worklist.push_back(RHS);
2128fb3d57eSArtur Pilipenko       continue;
2138fb3d57eSArtur Pilipenko     }
2148fb3d57eSArtur Pilipenko 
2158fb3d57eSArtur Pilipenko     if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
2168fb3d57eSArtur Pilipenko       if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Builder)) {
2178fb3d57eSArtur Pilipenko         Checks.push_back(NewRangeCheck.getValue());
2188fb3d57eSArtur Pilipenko         NumWidened++;
2198fb3d57eSArtur Pilipenko         continue;
2208fb3d57eSArtur Pilipenko       }
2218fb3d57eSArtur Pilipenko     }
2228fb3d57eSArtur Pilipenko 
2238fb3d57eSArtur Pilipenko     // Save the condition as is if we can't widen it
2248fb3d57eSArtur Pilipenko     Checks.push_back(Condition);
2258fb3d57eSArtur Pilipenko   } while (Worklist.size() != 0);
2268fb3d57eSArtur Pilipenko 
2278fb3d57eSArtur Pilipenko   if (NumWidened == 0)
2288fb3d57eSArtur Pilipenko     return false;
2298fb3d57eSArtur Pilipenko 
2308fb3d57eSArtur Pilipenko   // Emit the new guard condition
2318fb3d57eSArtur Pilipenko   Builder.SetInsertPoint(Guard);
2328fb3d57eSArtur Pilipenko   Value *LastCheck = nullptr;
2338fb3d57eSArtur Pilipenko   for (auto *Check : Checks)
2348fb3d57eSArtur Pilipenko     if (!LastCheck)
2358fb3d57eSArtur Pilipenko       LastCheck = Check;
2368fb3d57eSArtur Pilipenko     else
2378fb3d57eSArtur Pilipenko       LastCheck = Builder.CreateAnd(LastCheck, Check);
2388fb3d57eSArtur Pilipenko   Guard->setOperand(0, LastCheck);
2398fb3d57eSArtur Pilipenko 
2408fb3d57eSArtur Pilipenko   DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
2418fb3d57eSArtur Pilipenko   return true;
2428fb3d57eSArtur Pilipenko }
2438fb3d57eSArtur Pilipenko 
2448fb3d57eSArtur Pilipenko bool LoopPredication::runOnLoop(Loop *Loop) {
2458fb3d57eSArtur Pilipenko   L = Loop;
2468fb3d57eSArtur Pilipenko 
2478fb3d57eSArtur Pilipenko   DEBUG(dbgs() << "Analyzing ");
2488fb3d57eSArtur Pilipenko   DEBUG(L->dump());
2498fb3d57eSArtur Pilipenko 
2508fb3d57eSArtur Pilipenko   Module *M = L->getHeader()->getModule();
2518fb3d57eSArtur Pilipenko 
2528fb3d57eSArtur Pilipenko   // There is nothing to do if the module doesn't use guards
2538fb3d57eSArtur Pilipenko   auto *GuardDecl =
2548fb3d57eSArtur Pilipenko       M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard));
2558fb3d57eSArtur Pilipenko   if (!GuardDecl || GuardDecl->use_empty())
2568fb3d57eSArtur Pilipenko     return false;
2578fb3d57eSArtur Pilipenko 
2588fb3d57eSArtur Pilipenko   DL = &M->getDataLayout();
2598fb3d57eSArtur Pilipenko 
2608fb3d57eSArtur Pilipenko   Preheader = L->getLoopPreheader();
2618fb3d57eSArtur Pilipenko   if (!Preheader)
2628fb3d57eSArtur Pilipenko     return false;
2638fb3d57eSArtur Pilipenko 
2648fb3d57eSArtur Pilipenko   // Collect all the guards into a vector and process later, so as not
2658fb3d57eSArtur Pilipenko   // to invalidate the instruction iterator.
2668fb3d57eSArtur Pilipenko   SmallVector<IntrinsicInst *, 4> Guards;
2678fb3d57eSArtur Pilipenko   for (const auto BB : L->blocks())
2688fb3d57eSArtur Pilipenko     for (auto &I : *BB)
2698fb3d57eSArtur Pilipenko       if (auto *II = dyn_cast<IntrinsicInst>(&I))
2708fb3d57eSArtur Pilipenko         if (II->getIntrinsicID() == Intrinsic::experimental_guard)
2718fb3d57eSArtur Pilipenko           Guards.push_back(II);
2728fb3d57eSArtur Pilipenko 
2738fb3d57eSArtur Pilipenko   SCEVExpander Expander(*SE, *DL, "loop-predication");
2748fb3d57eSArtur Pilipenko 
2758fb3d57eSArtur Pilipenko   bool Changed = false;
2768fb3d57eSArtur Pilipenko   for (auto *Guard : Guards)
2778fb3d57eSArtur Pilipenko     Changed |= widenGuardConditions(Guard, Expander);
2788fb3d57eSArtur Pilipenko 
2798fb3d57eSArtur Pilipenko   return Changed;
2808fb3d57eSArtur Pilipenko }
281