1*9d9cb274SAdam Nemet //===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
2*9d9cb274SAdam Nemet //
3*9d9cb274SAdam Nemet //                     The LLVM Compiler Infrastructure
4*9d9cb274SAdam Nemet //
5*9d9cb274SAdam Nemet // This file is distributed under the University of Illinois Open Source
6*9d9cb274SAdam Nemet // License. See LICENSE.TXT for details.
7*9d9cb274SAdam Nemet //
8*9d9cb274SAdam Nemet //===----------------------------------------------------------------------===//
9*9d9cb274SAdam Nemet //
10*9d9cb274SAdam Nemet // This file implements a Loop Data Prefetching Pass.
11*9d9cb274SAdam Nemet //
12*9d9cb274SAdam Nemet //===----------------------------------------------------------------------===//
13*9d9cb274SAdam Nemet 
14*9d9cb274SAdam Nemet #define DEBUG_TYPE "loop-data-prefetch"
15*9d9cb274SAdam Nemet #include "llvm/Transforms/Scalar.h"
16*9d9cb274SAdam Nemet #include "llvm/ADT/DepthFirstIterator.h"
17*9d9cb274SAdam Nemet #include "llvm/ADT/Statistic.h"
18*9d9cb274SAdam Nemet #include "llvm/Analysis/AssumptionCache.h"
19*9d9cb274SAdam Nemet #include "llvm/Analysis/CodeMetrics.h"
20*9d9cb274SAdam Nemet #include "llvm/Analysis/InstructionSimplify.h"
21*9d9cb274SAdam Nemet #include "llvm/Analysis/LoopInfo.h"
22*9d9cb274SAdam Nemet #include "llvm/Analysis/ScalarEvolution.h"
23*9d9cb274SAdam Nemet #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
24*9d9cb274SAdam Nemet #include "llvm/Analysis/ScalarEvolutionExpander.h"
25*9d9cb274SAdam Nemet #include "llvm/Analysis/ScalarEvolutionExpressions.h"
26*9d9cb274SAdam Nemet #include "llvm/Analysis/TargetTransformInfo.h"
27*9d9cb274SAdam Nemet #include "llvm/Analysis/ValueTracking.h"
28*9d9cb274SAdam Nemet #include "llvm/IR/CFG.h"
29*9d9cb274SAdam Nemet #include "llvm/IR/Dominators.h"
30*9d9cb274SAdam Nemet #include "llvm/IR/Function.h"
31*9d9cb274SAdam Nemet #include "llvm/IR/IntrinsicInst.h"
32*9d9cb274SAdam Nemet #include "llvm/IR/Module.h"
33*9d9cb274SAdam Nemet #include "llvm/Support/CommandLine.h"
34*9d9cb274SAdam Nemet #include "llvm/Support/Debug.h"
35*9d9cb274SAdam Nemet #include "llvm/Transforms/Utils/BasicBlockUtils.h"
36*9d9cb274SAdam Nemet #include "llvm/Transforms/Utils/Local.h"
37*9d9cb274SAdam Nemet #include "llvm/Transforms/Utils/ValueMapper.h"
38*9d9cb274SAdam Nemet using namespace llvm;
39*9d9cb274SAdam Nemet 
40*9d9cb274SAdam Nemet // By default, we limit this to creating 16 PHIs (which is a little over half
41*9d9cb274SAdam Nemet // of the allocatable register set).
42*9d9cb274SAdam Nemet static cl::opt<bool>
43*9d9cb274SAdam Nemet PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false),
44*9d9cb274SAdam Nemet                cl::desc("Prefetch write addresses"));
45*9d9cb274SAdam Nemet 
46*9d9cb274SAdam Nemet namespace llvm {
47*9d9cb274SAdam Nemet   void initializeLoopDataPrefetchPass(PassRegistry&);
48*9d9cb274SAdam Nemet }
49*9d9cb274SAdam Nemet 
50*9d9cb274SAdam Nemet namespace {
51*9d9cb274SAdam Nemet 
52*9d9cb274SAdam Nemet   class LoopDataPrefetch : public FunctionPass {
53*9d9cb274SAdam Nemet   public:
54*9d9cb274SAdam Nemet     static char ID; // Pass ID, replacement for typeid
55*9d9cb274SAdam Nemet     LoopDataPrefetch() : FunctionPass(ID) {
56*9d9cb274SAdam Nemet       initializeLoopDataPrefetchPass(*PassRegistry::getPassRegistry());
57*9d9cb274SAdam Nemet     }
58*9d9cb274SAdam Nemet 
59*9d9cb274SAdam Nemet     void getAnalysisUsage(AnalysisUsage &AU) const override {
60*9d9cb274SAdam Nemet       AU.addRequired<AssumptionCacheTracker>();
61*9d9cb274SAdam Nemet       AU.addPreserved<DominatorTreeWrapperPass>();
62*9d9cb274SAdam Nemet       AU.addRequired<LoopInfoWrapperPass>();
63*9d9cb274SAdam Nemet       AU.addPreserved<LoopInfoWrapperPass>();
64*9d9cb274SAdam Nemet       AU.addRequired<ScalarEvolutionWrapperPass>();
65*9d9cb274SAdam Nemet       // FIXME: For some reason, preserving SE here breaks LSR (even if
66*9d9cb274SAdam Nemet       // this pass changes nothing).
67*9d9cb274SAdam Nemet       // AU.addPreserved<ScalarEvolutionWrapperPass>();
68*9d9cb274SAdam Nemet       AU.addRequired<TargetTransformInfoWrapperPass>();
69*9d9cb274SAdam Nemet     }
70*9d9cb274SAdam Nemet 
71*9d9cb274SAdam Nemet     bool runOnFunction(Function &F) override;
72*9d9cb274SAdam Nemet     bool runOnLoop(Loop *L);
73*9d9cb274SAdam Nemet 
74*9d9cb274SAdam Nemet   private:
75*9d9cb274SAdam Nemet     AssumptionCache *AC;
76*9d9cb274SAdam Nemet     LoopInfo *LI;
77*9d9cb274SAdam Nemet     ScalarEvolution *SE;
78*9d9cb274SAdam Nemet     const TargetTransformInfo *TTI;
79*9d9cb274SAdam Nemet     const DataLayout *DL;
80*9d9cb274SAdam Nemet   };
81*9d9cb274SAdam Nemet }
82*9d9cb274SAdam Nemet 
83*9d9cb274SAdam Nemet char LoopDataPrefetch::ID = 0;
84*9d9cb274SAdam Nemet INITIALIZE_PASS_BEGIN(LoopDataPrefetch, "loop-data-prefetch",
85*9d9cb274SAdam Nemet                       "Loop Data Prefetch", false, false)
86*9d9cb274SAdam Nemet INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
87*9d9cb274SAdam Nemet INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
88*9d9cb274SAdam Nemet INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
89*9d9cb274SAdam Nemet INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
90*9d9cb274SAdam Nemet INITIALIZE_PASS_END(LoopDataPrefetch, "loop-data-prefetch",
91*9d9cb274SAdam Nemet                     "Loop Data Prefetch", false, false)
92*9d9cb274SAdam Nemet 
93*9d9cb274SAdam Nemet FunctionPass *llvm::createLoopDataPrefetchPass() { return new LoopDataPrefetch(); }
94*9d9cb274SAdam Nemet 
95*9d9cb274SAdam Nemet bool LoopDataPrefetch::runOnFunction(Function &F) {
96*9d9cb274SAdam Nemet   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
97*9d9cb274SAdam Nemet   SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
98*9d9cb274SAdam Nemet   DL = &F.getParent()->getDataLayout();
99*9d9cb274SAdam Nemet   AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
100*9d9cb274SAdam Nemet   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
101*9d9cb274SAdam Nemet 
102*9d9cb274SAdam Nemet   assert(TTI->getCacheLineSize() && "Cache line size is not set for target");
103*9d9cb274SAdam Nemet   assert(TTI->getPrefetchDistance() &&
104*9d9cb274SAdam Nemet          "Prefetch distance is not set for target");
105*9d9cb274SAdam Nemet 
106*9d9cb274SAdam Nemet   bool MadeChange = false;
107*9d9cb274SAdam Nemet 
108*9d9cb274SAdam Nemet   for (auto I = LI->begin(), IE = LI->end(); I != IE; ++I)
109*9d9cb274SAdam Nemet     for (auto L = df_begin(*I), LE = df_end(*I); L != LE; ++L)
110*9d9cb274SAdam Nemet       MadeChange |= runOnLoop(*L);
111*9d9cb274SAdam Nemet 
112*9d9cb274SAdam Nemet   return MadeChange;
113*9d9cb274SAdam Nemet }
114*9d9cb274SAdam Nemet 
115*9d9cb274SAdam Nemet bool LoopDataPrefetch::runOnLoop(Loop *L) {
116*9d9cb274SAdam Nemet   bool MadeChange = false;
117*9d9cb274SAdam Nemet 
118*9d9cb274SAdam Nemet   // Only prefetch in the inner-most loop
119*9d9cb274SAdam Nemet   if (!L->empty())
120*9d9cb274SAdam Nemet     return MadeChange;
121*9d9cb274SAdam Nemet 
122*9d9cb274SAdam Nemet   SmallPtrSet<const Value *, 32> EphValues;
123*9d9cb274SAdam Nemet   CodeMetrics::collectEphemeralValues(L, AC, EphValues);
124*9d9cb274SAdam Nemet 
125*9d9cb274SAdam Nemet   // Calculate the number of iterations ahead to prefetch
126*9d9cb274SAdam Nemet   CodeMetrics Metrics;
127*9d9cb274SAdam Nemet   for (Loop::block_iterator I = L->block_begin(), IE = L->block_end();
128*9d9cb274SAdam Nemet        I != IE; ++I) {
129*9d9cb274SAdam Nemet 
130*9d9cb274SAdam Nemet     // If the loop already has prefetches, then assume that the user knows
131*9d9cb274SAdam Nemet     // what he or she is doing and don't add any more.
132*9d9cb274SAdam Nemet     for (BasicBlock::iterator J = (*I)->begin(), JE = (*I)->end();
133*9d9cb274SAdam Nemet          J != JE; ++J)
134*9d9cb274SAdam Nemet       if (CallInst *CI = dyn_cast<CallInst>(J))
135*9d9cb274SAdam Nemet         if (Function *F = CI->getCalledFunction())
136*9d9cb274SAdam Nemet           if (F->getIntrinsicID() == Intrinsic::prefetch)
137*9d9cb274SAdam Nemet             return MadeChange;
138*9d9cb274SAdam Nemet 
139*9d9cb274SAdam Nemet     Metrics.analyzeBasicBlock(*I, *TTI, EphValues);
140*9d9cb274SAdam Nemet   }
141*9d9cb274SAdam Nemet   unsigned LoopSize = Metrics.NumInsts;
142*9d9cb274SAdam Nemet   if (!LoopSize)
143*9d9cb274SAdam Nemet     LoopSize = 1;
144*9d9cb274SAdam Nemet 
145*9d9cb274SAdam Nemet   unsigned ItersAhead = TTI->getPrefetchDistance() / LoopSize;
146*9d9cb274SAdam Nemet   if (!ItersAhead)
147*9d9cb274SAdam Nemet     ItersAhead = 1;
148*9d9cb274SAdam Nemet 
149*9d9cb274SAdam Nemet   SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>, 16> PrefLoads;
150*9d9cb274SAdam Nemet   for (Loop::block_iterator I = L->block_begin(), IE = L->block_end();
151*9d9cb274SAdam Nemet        I != IE; ++I) {
152*9d9cb274SAdam Nemet     for (BasicBlock::iterator J = (*I)->begin(), JE = (*I)->end();
153*9d9cb274SAdam Nemet         J != JE; ++J) {
154*9d9cb274SAdam Nemet       Value *PtrValue;
155*9d9cb274SAdam Nemet       Instruction *MemI;
156*9d9cb274SAdam Nemet 
157*9d9cb274SAdam Nemet       if (LoadInst *LMemI = dyn_cast<LoadInst>(J)) {
158*9d9cb274SAdam Nemet         MemI = LMemI;
159*9d9cb274SAdam Nemet         PtrValue = LMemI->getPointerOperand();
160*9d9cb274SAdam Nemet       } else if (StoreInst *SMemI = dyn_cast<StoreInst>(J)) {
161*9d9cb274SAdam Nemet         if (!PrefetchWrites) continue;
162*9d9cb274SAdam Nemet         MemI = SMemI;
163*9d9cb274SAdam Nemet         PtrValue = SMemI->getPointerOperand();
164*9d9cb274SAdam Nemet       } else continue;
165*9d9cb274SAdam Nemet 
166*9d9cb274SAdam Nemet       unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
167*9d9cb274SAdam Nemet       if (PtrAddrSpace)
168*9d9cb274SAdam Nemet         continue;
169*9d9cb274SAdam Nemet 
170*9d9cb274SAdam Nemet       if (L->isLoopInvariant(PtrValue))
171*9d9cb274SAdam Nemet         continue;
172*9d9cb274SAdam Nemet 
173*9d9cb274SAdam Nemet       const SCEV *LSCEV = SE->getSCEV(PtrValue);
174*9d9cb274SAdam Nemet       const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
175*9d9cb274SAdam Nemet       if (!LSCEVAddRec)
176*9d9cb274SAdam Nemet         continue;
177*9d9cb274SAdam Nemet 
178*9d9cb274SAdam Nemet       // We don't want to double prefetch individual cache lines. If this load
179*9d9cb274SAdam Nemet       // is known to be within one cache line of some other load that has
180*9d9cb274SAdam Nemet       // already been prefetched, then don't prefetch this one as well.
181*9d9cb274SAdam Nemet       bool DupPref = false;
182*9d9cb274SAdam Nemet       for (SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>,
183*9d9cb274SAdam Nemet              16>::iterator K = PrefLoads.begin(), KE = PrefLoads.end();
184*9d9cb274SAdam Nemet            K != KE; ++K) {
185*9d9cb274SAdam Nemet         const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, K->second);
186*9d9cb274SAdam Nemet         if (const SCEVConstant *ConstPtrDiff =
187*9d9cb274SAdam Nemet             dyn_cast<SCEVConstant>(PtrDiff)) {
188*9d9cb274SAdam Nemet           int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
189*9d9cb274SAdam Nemet           if (PD < (int64_t) TTI->getCacheLineSize()) {
190*9d9cb274SAdam Nemet             DupPref = true;
191*9d9cb274SAdam Nemet             break;
192*9d9cb274SAdam Nemet           }
193*9d9cb274SAdam Nemet         }
194*9d9cb274SAdam Nemet       }
195*9d9cb274SAdam Nemet       if (DupPref)
196*9d9cb274SAdam Nemet         continue;
197*9d9cb274SAdam Nemet 
198*9d9cb274SAdam Nemet       const SCEV *NextLSCEV = SE->getAddExpr(LSCEVAddRec, SE->getMulExpr(
199*9d9cb274SAdam Nemet         SE->getConstant(LSCEVAddRec->getType(), ItersAhead),
200*9d9cb274SAdam Nemet         LSCEVAddRec->getStepRecurrence(*SE)));
201*9d9cb274SAdam Nemet       if (!isSafeToExpand(NextLSCEV, *SE))
202*9d9cb274SAdam Nemet         continue;
203*9d9cb274SAdam Nemet 
204*9d9cb274SAdam Nemet       PrefLoads.push_back(std::make_pair(MemI, LSCEVAddRec));
205*9d9cb274SAdam Nemet 
206*9d9cb274SAdam Nemet       Type *I8Ptr = Type::getInt8PtrTy((*I)->getContext(), PtrAddrSpace);
207*9d9cb274SAdam Nemet       SCEVExpander SCEVE(*SE, J->getModule()->getDataLayout(), "prefaddr");
208*9d9cb274SAdam Nemet       Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, MemI);
209*9d9cb274SAdam Nemet 
210*9d9cb274SAdam Nemet       IRBuilder<> Builder(MemI);
211*9d9cb274SAdam Nemet       Module *M = (*I)->getParent()->getParent();
212*9d9cb274SAdam Nemet       Type *I32 = Type::getInt32Ty((*I)->getContext());
213*9d9cb274SAdam Nemet       Value *PrefetchFunc = Intrinsic::getDeclaration(M, Intrinsic::prefetch);
214*9d9cb274SAdam Nemet       Builder.CreateCall(
215*9d9cb274SAdam Nemet           PrefetchFunc,
216*9d9cb274SAdam Nemet           {PrefPtrValue,
217*9d9cb274SAdam Nemet            ConstantInt::get(I32, MemI->mayReadFromMemory() ? 0 : 1),
218*9d9cb274SAdam Nemet            ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
219*9d9cb274SAdam Nemet 
220*9d9cb274SAdam Nemet       MadeChange = true;
221*9d9cb274SAdam Nemet     }
222*9d9cb274SAdam Nemet   }
223*9d9cb274SAdam Nemet 
224*9d9cb274SAdam Nemet   return MadeChange;
225*9d9cb274SAdam Nemet }
226*9d9cb274SAdam Nemet 
227