19d9cb274SAdam Nemet //===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
29d9cb274SAdam Nemet //
39d9cb274SAdam Nemet //                     The LLVM Compiler Infrastructure
49d9cb274SAdam Nemet //
59d9cb274SAdam Nemet // This file is distributed under the University of Illinois Open Source
69d9cb274SAdam Nemet // License. See LICENSE.TXT for details.
79d9cb274SAdam Nemet //
89d9cb274SAdam Nemet //===----------------------------------------------------------------------===//
99d9cb274SAdam Nemet //
109d9cb274SAdam Nemet // This file implements a Loop Data Prefetching Pass.
119d9cb274SAdam Nemet //
129d9cb274SAdam Nemet //===----------------------------------------------------------------------===//
139d9cb274SAdam Nemet 
149d9cb274SAdam Nemet #define DEBUG_TYPE "loop-data-prefetch"
159d9cb274SAdam Nemet #include "llvm/Transforms/Scalar.h"
169d9cb274SAdam Nemet #include "llvm/ADT/DepthFirstIterator.h"
179d9cb274SAdam Nemet #include "llvm/ADT/Statistic.h"
189d9cb274SAdam Nemet #include "llvm/Analysis/AssumptionCache.h"
199d9cb274SAdam Nemet #include "llvm/Analysis/CodeMetrics.h"
209d9cb274SAdam Nemet #include "llvm/Analysis/InstructionSimplify.h"
219d9cb274SAdam Nemet #include "llvm/Analysis/LoopInfo.h"
229d9cb274SAdam Nemet #include "llvm/Analysis/ScalarEvolution.h"
239d9cb274SAdam Nemet #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
249d9cb274SAdam Nemet #include "llvm/Analysis/ScalarEvolutionExpander.h"
259d9cb274SAdam Nemet #include "llvm/Analysis/ScalarEvolutionExpressions.h"
269d9cb274SAdam Nemet #include "llvm/Analysis/TargetTransformInfo.h"
279d9cb274SAdam Nemet #include "llvm/Analysis/ValueTracking.h"
289d9cb274SAdam Nemet #include "llvm/IR/CFG.h"
29*3c5eabfcSAdam Nemet #include "llvm/IR/DiagnosticInfo.h"
309d9cb274SAdam Nemet #include "llvm/IR/Dominators.h"
319d9cb274SAdam Nemet #include "llvm/IR/Function.h"
329d9cb274SAdam Nemet #include "llvm/IR/IntrinsicInst.h"
339d9cb274SAdam Nemet #include "llvm/IR/Module.h"
349d9cb274SAdam Nemet #include "llvm/Support/CommandLine.h"
359d9cb274SAdam Nemet #include "llvm/Support/Debug.h"
369d9cb274SAdam Nemet #include "llvm/Transforms/Utils/BasicBlockUtils.h"
379d9cb274SAdam Nemet #include "llvm/Transforms/Utils/Local.h"
389d9cb274SAdam Nemet #include "llvm/Transforms/Utils/ValueMapper.h"
399d9cb274SAdam Nemet using namespace llvm;
409d9cb274SAdam Nemet 
419d9cb274SAdam Nemet // By default, we limit this to creating 16 PHIs (which is a little over half
429d9cb274SAdam Nemet // of the allocatable register set).
439d9cb274SAdam Nemet static cl::opt<bool>
449d9cb274SAdam Nemet PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false),
459d9cb274SAdam Nemet                cl::desc("Prefetch write addresses"));
469d9cb274SAdam Nemet 
471428d41fSAdam Nemet static cl::opt<unsigned>
481428d41fSAdam Nemet     PrefetchDistance("prefetch-distance",
491428d41fSAdam Nemet                      cl::desc("Number of instructions to prefetch ahead"),
501428d41fSAdam Nemet                      cl::Hidden);
511428d41fSAdam Nemet 
521428d41fSAdam Nemet static cl::opt<unsigned>
531428d41fSAdam Nemet     MinPrefetchStride("min-prefetch-stride",
541428d41fSAdam Nemet                       cl::desc("Min stride to add prefetches"), cl::Hidden);
551428d41fSAdam Nemet 
561428d41fSAdam Nemet static cl::opt<unsigned> MaxPrefetchIterationsAhead(
571428d41fSAdam Nemet     "max-prefetch-iters-ahead",
581428d41fSAdam Nemet     cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden);
591428d41fSAdam Nemet 
6034785ecfSAdam Nemet STATISTIC(NumPrefetches, "Number of prefetches inserted");
6134785ecfSAdam Nemet 
629d9cb274SAdam Nemet namespace llvm {
639d9cb274SAdam Nemet   void initializeLoopDataPrefetchPass(PassRegistry&);
649d9cb274SAdam Nemet }
659d9cb274SAdam Nemet 
669d9cb274SAdam Nemet namespace {
679d9cb274SAdam Nemet 
689d9cb274SAdam Nemet   class LoopDataPrefetch : public FunctionPass {
699d9cb274SAdam Nemet   public:
709d9cb274SAdam Nemet     static char ID; // Pass ID, replacement for typeid
719d9cb274SAdam Nemet     LoopDataPrefetch() : FunctionPass(ID) {
729d9cb274SAdam Nemet       initializeLoopDataPrefetchPass(*PassRegistry::getPassRegistry());
739d9cb274SAdam Nemet     }
749d9cb274SAdam Nemet 
759d9cb274SAdam Nemet     void getAnalysisUsage(AnalysisUsage &AU) const override {
769d9cb274SAdam Nemet       AU.addRequired<AssumptionCacheTracker>();
779d9cb274SAdam Nemet       AU.addPreserved<DominatorTreeWrapperPass>();
789d9cb274SAdam Nemet       AU.addRequired<LoopInfoWrapperPass>();
799d9cb274SAdam Nemet       AU.addPreserved<LoopInfoWrapperPass>();
809d9cb274SAdam Nemet       AU.addRequired<ScalarEvolutionWrapperPass>();
819d9cb274SAdam Nemet       // FIXME: For some reason, preserving SE here breaks LSR (even if
829d9cb274SAdam Nemet       // this pass changes nothing).
839d9cb274SAdam Nemet       // AU.addPreserved<ScalarEvolutionWrapperPass>();
849d9cb274SAdam Nemet       AU.addRequired<TargetTransformInfoWrapperPass>();
859d9cb274SAdam Nemet     }
869d9cb274SAdam Nemet 
879d9cb274SAdam Nemet     bool runOnFunction(Function &F) override;
8885fba393SAdam Nemet 
8985fba393SAdam Nemet   private:
909d9cb274SAdam Nemet     bool runOnLoop(Loop *L);
919d9cb274SAdam Nemet 
926d8beecaSAdam Nemet     /// \brief Check if the the stride of the accesses is large enough to
936d8beecaSAdam Nemet     /// warrant a prefetch.
946d8beecaSAdam Nemet     bool isStrideLargeEnough(const SCEVAddRecExpr *AR);
956d8beecaSAdam Nemet 
961428d41fSAdam Nemet     unsigned getMinPrefetchStride() {
971428d41fSAdam Nemet       if (MinPrefetchStride.getNumOccurrences() > 0)
981428d41fSAdam Nemet         return MinPrefetchStride;
991428d41fSAdam Nemet       return TTI->getMinPrefetchStride();
1001428d41fSAdam Nemet     }
1011428d41fSAdam Nemet 
1021428d41fSAdam Nemet     unsigned getPrefetchDistance() {
1031428d41fSAdam Nemet       if (PrefetchDistance.getNumOccurrences() > 0)
1041428d41fSAdam Nemet         return PrefetchDistance;
1051428d41fSAdam Nemet       return TTI->getPrefetchDistance();
1061428d41fSAdam Nemet     }
1071428d41fSAdam Nemet 
1081428d41fSAdam Nemet     unsigned getMaxPrefetchIterationsAhead() {
1091428d41fSAdam Nemet       if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0)
1101428d41fSAdam Nemet         return MaxPrefetchIterationsAhead;
1111428d41fSAdam Nemet       return TTI->getMaxPrefetchIterationsAhead();
1121428d41fSAdam Nemet     }
1131428d41fSAdam Nemet 
1149d9cb274SAdam Nemet     AssumptionCache *AC;
1159d9cb274SAdam Nemet     LoopInfo *LI;
1169d9cb274SAdam Nemet     ScalarEvolution *SE;
1179d9cb274SAdam Nemet     const TargetTransformInfo *TTI;
1189d9cb274SAdam Nemet     const DataLayout *DL;
1199d9cb274SAdam Nemet   };
1209d9cb274SAdam Nemet }
1219d9cb274SAdam Nemet 
1229d9cb274SAdam Nemet char LoopDataPrefetch::ID = 0;
1239d9cb274SAdam Nemet INITIALIZE_PASS_BEGIN(LoopDataPrefetch, "loop-data-prefetch",
1249d9cb274SAdam Nemet                       "Loop Data Prefetch", false, false)
1259d9cb274SAdam Nemet INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
1269d9cb274SAdam Nemet INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
1279d9cb274SAdam Nemet INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
1289d9cb274SAdam Nemet INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
1299d9cb274SAdam Nemet INITIALIZE_PASS_END(LoopDataPrefetch, "loop-data-prefetch",
1309d9cb274SAdam Nemet                     "Loop Data Prefetch", false, false)
1319d9cb274SAdam Nemet 
1329d9cb274SAdam Nemet FunctionPass *llvm::createLoopDataPrefetchPass() { return new LoopDataPrefetch(); }
1339d9cb274SAdam Nemet 
1346d8beecaSAdam Nemet bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) {
1351428d41fSAdam Nemet   unsigned TargetMinStride = getMinPrefetchStride();
1366d8beecaSAdam Nemet   // No need to check if any stride goes.
1376d8beecaSAdam Nemet   if (TargetMinStride <= 1)
1386d8beecaSAdam Nemet     return true;
1396d8beecaSAdam Nemet 
1406d8beecaSAdam Nemet   const auto *ConstStride = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE));
1416d8beecaSAdam Nemet   // If MinStride is set, don't prefetch unless we can ensure that stride is
1426d8beecaSAdam Nemet   // larger.
1436d8beecaSAdam Nemet   if (!ConstStride)
1446d8beecaSAdam Nemet     return false;
1456d8beecaSAdam Nemet 
1466d8beecaSAdam Nemet   unsigned AbsStride = std::abs(ConstStride->getAPInt().getSExtValue());
1476d8beecaSAdam Nemet   return TargetMinStride <= AbsStride;
1486d8beecaSAdam Nemet }
1496d8beecaSAdam Nemet 
1509d9cb274SAdam Nemet bool LoopDataPrefetch::runOnFunction(Function &F) {
15150271f78SAndrew Kaylor   if (skipFunction(F))
15250271f78SAndrew Kaylor     return false;
15350271f78SAndrew Kaylor 
1549d9cb274SAdam Nemet   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1559d9cb274SAdam Nemet   SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
1569d9cb274SAdam Nemet   DL = &F.getParent()->getDataLayout();
1579d9cb274SAdam Nemet   AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
1589d9cb274SAdam Nemet   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
1599d9cb274SAdam Nemet 
160bb3680bdSAdam Nemet   // If PrefetchDistance is not set, don't run the pass.  This gives an
161bb3680bdSAdam Nemet   // opportunity for targets to run this pass for selected subtargets only
162bb3680bdSAdam Nemet   // (whose TTI sets PrefetchDistance).
1631428d41fSAdam Nemet   if (getPrefetchDistance() == 0)
164bb3680bdSAdam Nemet     return false;
1659d9cb274SAdam Nemet   assert(TTI->getCacheLineSize() && "Cache line size is not set for target");
1669d9cb274SAdam Nemet 
1679d9cb274SAdam Nemet   bool MadeChange = false;
1689d9cb274SAdam Nemet 
1699d9cb274SAdam Nemet   for (auto I = LI->begin(), IE = LI->end(); I != IE; ++I)
1709d9cb274SAdam Nemet     for (auto L = df_begin(*I), LE = df_end(*I); L != LE; ++L)
1719d9cb274SAdam Nemet       MadeChange |= runOnLoop(*L);
1729d9cb274SAdam Nemet 
1739d9cb274SAdam Nemet   return MadeChange;
1749d9cb274SAdam Nemet }
1759d9cb274SAdam Nemet 
1769d9cb274SAdam Nemet bool LoopDataPrefetch::runOnLoop(Loop *L) {
1779d9cb274SAdam Nemet   bool MadeChange = false;
1789d9cb274SAdam Nemet 
1799d9cb274SAdam Nemet   // Only prefetch in the inner-most loop
1809d9cb274SAdam Nemet   if (!L->empty())
1819d9cb274SAdam Nemet     return MadeChange;
1829d9cb274SAdam Nemet 
1839d9cb274SAdam Nemet   SmallPtrSet<const Value *, 32> EphValues;
1849d9cb274SAdam Nemet   CodeMetrics::collectEphemeralValues(L, AC, EphValues);
1859d9cb274SAdam Nemet 
1869d9cb274SAdam Nemet   // Calculate the number of iterations ahead to prefetch
1879d9cb274SAdam Nemet   CodeMetrics Metrics;
1889d9cb274SAdam Nemet   for (Loop::block_iterator I = L->block_begin(), IE = L->block_end();
1899d9cb274SAdam Nemet        I != IE; ++I) {
1909d9cb274SAdam Nemet 
1919d9cb274SAdam Nemet     // If the loop already has prefetches, then assume that the user knows
1929d9cb274SAdam Nemet     // what he or she is doing and don't add any more.
1939d9cb274SAdam Nemet     for (BasicBlock::iterator J = (*I)->begin(), JE = (*I)->end();
1949d9cb274SAdam Nemet          J != JE; ++J)
1959d9cb274SAdam Nemet       if (CallInst *CI = dyn_cast<CallInst>(J))
1969d9cb274SAdam Nemet         if (Function *F = CI->getCalledFunction())
1979d9cb274SAdam Nemet           if (F->getIntrinsicID() == Intrinsic::prefetch)
1989d9cb274SAdam Nemet             return MadeChange;
1999d9cb274SAdam Nemet 
2009d9cb274SAdam Nemet     Metrics.analyzeBasicBlock(*I, *TTI, EphValues);
2019d9cb274SAdam Nemet   }
2029d9cb274SAdam Nemet   unsigned LoopSize = Metrics.NumInsts;
2039d9cb274SAdam Nemet   if (!LoopSize)
2049d9cb274SAdam Nemet     LoopSize = 1;
2059d9cb274SAdam Nemet 
2061428d41fSAdam Nemet   unsigned ItersAhead = getPrefetchDistance() / LoopSize;
2079d9cb274SAdam Nemet   if (!ItersAhead)
2089d9cb274SAdam Nemet     ItersAhead = 1;
2099d9cb274SAdam Nemet 
2101428d41fSAdam Nemet   if (ItersAhead > getMaxPrefetchIterationsAhead())
211709e3046SAdam Nemet     return MadeChange;
212709e3046SAdam Nemet 
213*3c5eabfcSAdam Nemet   Function *F = L->getHeader()->getParent();
21434785ecfSAdam Nemet   DEBUG(dbgs() << "Prefetching " << ItersAhead
21534785ecfSAdam Nemet                << " iterations ahead (loop size: " << LoopSize << ") in "
216*3c5eabfcSAdam Nemet                << F->getName() << ": " << *L);
21734785ecfSAdam Nemet 
2189d9cb274SAdam Nemet   SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>, 16> PrefLoads;
2199d9cb274SAdam Nemet   for (Loop::block_iterator I = L->block_begin(), IE = L->block_end();
2209d9cb274SAdam Nemet        I != IE; ++I) {
2219d9cb274SAdam Nemet     for (BasicBlock::iterator J = (*I)->begin(), JE = (*I)->end();
2229d9cb274SAdam Nemet         J != JE; ++J) {
2239d9cb274SAdam Nemet       Value *PtrValue;
2249d9cb274SAdam Nemet       Instruction *MemI;
2259d9cb274SAdam Nemet 
2269d9cb274SAdam Nemet       if (LoadInst *LMemI = dyn_cast<LoadInst>(J)) {
2279d9cb274SAdam Nemet         MemI = LMemI;
2289d9cb274SAdam Nemet         PtrValue = LMemI->getPointerOperand();
2299d9cb274SAdam Nemet       } else if (StoreInst *SMemI = dyn_cast<StoreInst>(J)) {
2309d9cb274SAdam Nemet         if (!PrefetchWrites) continue;
2319d9cb274SAdam Nemet         MemI = SMemI;
2329d9cb274SAdam Nemet         PtrValue = SMemI->getPointerOperand();
2339d9cb274SAdam Nemet       } else continue;
2349d9cb274SAdam Nemet 
2359d9cb274SAdam Nemet       unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
2369d9cb274SAdam Nemet       if (PtrAddrSpace)
2379d9cb274SAdam Nemet         continue;
2389d9cb274SAdam Nemet 
2399d9cb274SAdam Nemet       if (L->isLoopInvariant(PtrValue))
2409d9cb274SAdam Nemet         continue;
2419d9cb274SAdam Nemet 
2429d9cb274SAdam Nemet       const SCEV *LSCEV = SE->getSCEV(PtrValue);
2439d9cb274SAdam Nemet       const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
2449d9cb274SAdam Nemet       if (!LSCEVAddRec)
2459d9cb274SAdam Nemet         continue;
2469d9cb274SAdam Nemet 
2476d8beecaSAdam Nemet       // Check if the the stride of the accesses is large enough to warrant a
2486d8beecaSAdam Nemet       // prefetch.
2496d8beecaSAdam Nemet       if (!isStrideLargeEnough(LSCEVAddRec))
2506d8beecaSAdam Nemet         continue;
2516d8beecaSAdam Nemet 
2529d9cb274SAdam Nemet       // We don't want to double prefetch individual cache lines. If this load
2539d9cb274SAdam Nemet       // is known to be within one cache line of some other load that has
2549d9cb274SAdam Nemet       // already been prefetched, then don't prefetch this one as well.
2559d9cb274SAdam Nemet       bool DupPref = false;
2569d9cb274SAdam Nemet       for (SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>,
2579d9cb274SAdam Nemet              16>::iterator K = PrefLoads.begin(), KE = PrefLoads.end();
2589d9cb274SAdam Nemet            K != KE; ++K) {
2599d9cb274SAdam Nemet         const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, K->second);
2609d9cb274SAdam Nemet         if (const SCEVConstant *ConstPtrDiff =
2619d9cb274SAdam Nemet             dyn_cast<SCEVConstant>(PtrDiff)) {
2629d9cb274SAdam Nemet           int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
2639d9cb274SAdam Nemet           if (PD < (int64_t) TTI->getCacheLineSize()) {
2649d9cb274SAdam Nemet             DupPref = true;
2659d9cb274SAdam Nemet             break;
2669d9cb274SAdam Nemet           }
2679d9cb274SAdam Nemet         }
2689d9cb274SAdam Nemet       }
2699d9cb274SAdam Nemet       if (DupPref)
2709d9cb274SAdam Nemet         continue;
2719d9cb274SAdam Nemet 
2729d9cb274SAdam Nemet       const SCEV *NextLSCEV = SE->getAddExpr(LSCEVAddRec, SE->getMulExpr(
2739d9cb274SAdam Nemet         SE->getConstant(LSCEVAddRec->getType(), ItersAhead),
2749d9cb274SAdam Nemet         LSCEVAddRec->getStepRecurrence(*SE)));
2759d9cb274SAdam Nemet       if (!isSafeToExpand(NextLSCEV, *SE))
2769d9cb274SAdam Nemet         continue;
2779d9cb274SAdam Nemet 
2789d9cb274SAdam Nemet       PrefLoads.push_back(std::make_pair(MemI, LSCEVAddRec));
2799d9cb274SAdam Nemet 
2809d9cb274SAdam Nemet       Type *I8Ptr = Type::getInt8PtrTy((*I)->getContext(), PtrAddrSpace);
2819d9cb274SAdam Nemet       SCEVExpander SCEVE(*SE, J->getModule()->getDataLayout(), "prefaddr");
2829d9cb274SAdam Nemet       Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, MemI);
2839d9cb274SAdam Nemet 
2849d9cb274SAdam Nemet       IRBuilder<> Builder(MemI);
2859d9cb274SAdam Nemet       Module *M = (*I)->getParent()->getParent();
2869d9cb274SAdam Nemet       Type *I32 = Type::getInt32Ty((*I)->getContext());
2879d9cb274SAdam Nemet       Value *PrefetchFunc = Intrinsic::getDeclaration(M, Intrinsic::prefetch);
2889d9cb274SAdam Nemet       Builder.CreateCall(
2899d9cb274SAdam Nemet           PrefetchFunc,
2909d9cb274SAdam Nemet           {PrefPtrValue,
2919d9cb274SAdam Nemet            ConstantInt::get(I32, MemI->mayReadFromMemory() ? 0 : 1),
2929d9cb274SAdam Nemet            ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
29334785ecfSAdam Nemet       ++NumPrefetches;
29434785ecfSAdam Nemet       DEBUG(dbgs() << "  Access: " << *PtrValue << ", SCEV: " << *LSCEV
29534785ecfSAdam Nemet                    << "\n");
296*3c5eabfcSAdam Nemet       emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F,
297*3c5eabfcSAdam Nemet                              MemI->getDebugLoc(), "prefetched memory access");
298*3c5eabfcSAdam Nemet 
2999d9cb274SAdam Nemet 
3009d9cb274SAdam Nemet       MadeChange = true;
3019d9cb274SAdam Nemet     }
3029d9cb274SAdam Nemet   }
3039d9cb274SAdam Nemet 
3049d9cb274SAdam Nemet   return MadeChange;
3059d9cb274SAdam Nemet }
3069d9cb274SAdam Nemet 
307