//===- LoadStoreVectorizer.cpp - GPU Load & Store Vectorizer --------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/OrderedBasicBlock.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <tuple>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "load-store-vectorizer"

STATISTIC(NumVectorInstructions, "Number of vector accesses generated");
STATISTIC(NumScalarsVectorized, "Number of scalar accesses vectorized");

// FIXME: Assuming stack alignment of 4 is always good enough
static const unsigned StackAdjustedAlignment = 4;

namespace {

using InstrList = SmallVector<Instruction *, 8>;
using InstrListMap = MapVector<Value *, InstrList>;

class Vectorizer {
  Function &F;
  AliasAnalysis &AA;
  DominatorTree &DT;
  ScalarEvolution &SE;
  TargetTransformInfo &TTI;
  const DataLayout &DL;
  IRBuilder<> Builder;

public:
  Vectorizer(Function &F, AliasAnalysis &AA, DominatorTree &DT,
             ScalarEvolution &SE, TargetTransformInfo &TTI)
      : F(F), AA(AA), DT(DT), SE(SE), TTI(TTI),
        DL(F.getParent()->getDataLayout()), Builder(SE.getContext()) {}

  bool run();

private:
  Value *getPointerOperand(Value *I) const;

  GetElementPtrInst *getSourceGEP(Value *Src) const;

  unsigned getPointerAddressSpace(Value *I);

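  /// Returns the explicit alignment of \p LI, or the ABI alignment of the
  /// loaded type if no alignment was specified.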
  unsigned getAlignment(LoadInst *LI) const {
    unsigned Align = LI->getAlignment();
    if (Align != 0)
      return Align;

    return DL.getABITypeAlignment(LI->getType());
  }

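  /// Returns the explicit alignment of \p SI, or the ABI alignment of the
  /// stored value's type if no alignment was specified.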
  unsigned getAlignment(StoreInst *SI) const {
    unsigned Align = SI->getAlignment();
    if (Align != 0)
      return Align;

    return DL.getABITypeAlignment(SI->getValueOperand()->getType());
  }

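  /// Returns true if \p B accesses the memory location immediately following
  /// the one accessed by \p A, i.e. A's address plus A's store size equals
  /// B's address.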
  bool isConsecutiveAccess(Value *A, Value *B);

  /// After vectorization, reorder the instructions that I depends on
  /// (the instructions defining its operands), to ensure they dominate I.
  void reorder(Instruction *I);

  /// Returns the first and the last instructions in Chain.
  std::pair<BasicBlock::iterator, BasicBlock::iterator>
  getBoundaryInstrs(ArrayRef<Instruction *> Chain);

  /// Erases the original instructions after vectorizing.
  void eraseInstructions(ArrayRef<Instruction *> Chain);

  /// "Legalize" the vector type that would be produced by combining \p
  /// ElementSizeBits elements in \p Chain. Break into two pieces such that the
  /// total size of each piece is 1, 2 or a multiple of 4 bytes. \p Chain is
  /// expected to have more than 4 elements.
  std::pair<ArrayRef<Instruction *>, ArrayRef<Instruction *>>
  splitOddVectorElts(ArrayRef<Instruction *> Chain, unsigned ElementSizeBits);

  /// Finds the largest prefix of Chain that's vectorizable, checking for
  /// intervening instructions which may affect the memory accessed by the
  /// instructions within Chain.
  ///
  /// The elements of \p Chain must be all loads or all stores and must be in
  /// address order.
  ArrayRef<Instruction *> getVectorizablePrefix(ArrayRef<Instruction *> Chain);

  /// Collects load and store instructions to vectorize.
  std::pair<InstrListMap, InstrListMap> collectInstructions(BasicBlock *BB);

  /// Processes the collected instructions, the \p Map. The values of \p Map
  /// should be all loads or all stores.
  bool vectorizeChains(InstrListMap &Map);

  /// Finds loads and stores to consecutive memory addresses and vectorizes
  /// them.
  bool vectorizeInstructions(ArrayRef<Instruction *> Instrs);

  /// Vectorizes the load instructions in Chain.
  bool
  vectorizeLoadChain(ArrayRef<Instruction *> Chain,
                     SmallPtrSet<Instruction *, 16> *InstructionsProcessed);

  /// Vectorizes the store instructions in Chain.
  bool
  vectorizeStoreChain(ArrayRef<Instruction *> Chain,
                      SmallPtrSet<Instruction *, 16> *InstructionsProcessed);

  /// Checks whether this load/store access is misaligned.
  bool accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace,
                          unsigned Alignment);
};

class LoadStoreVectorizer : public FunctionPass {
public:
  static char ID;

  LoadStoreVectorizer() : FunctionPass(ID) {
    initializeLoadStoreVectorizerPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "GPU Load and Store Vectorizer";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};

} // end anonymous namespace

char LoadStoreVectorizer::ID = 0;

INITIALIZE_PASS_BEGIN(LoadStoreVectorizer, DEBUG_TYPE,
                      "Vectorize load and store instructions", false, false)
INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(LoadStoreVectorizer, DEBUG_TYPE,
                    "Vectorize load and store instructions", false, false)

Pass *llvm::createLoadStoreVectorizerPass() {
  return new LoadStoreVectorizer();
}

// The real propagateMetadata expects a SmallVector<Value*>, but we deal in
// vectors of Instructions.
static void propagateMetadata(Instruction *I, ArrayRef<Instruction *> IL) {
  SmallVector<Value *, 8> VL(IL.begin(), IL.end());
  propagateMetadata(I, VL);
}

bool LoadStoreVectorizer::runOnFunction(Function &F) {
  // Don't vectorize when the attribute NoImplicitFloat is used.
  if (skipFunction(F) || F.hasFnAttribute(Attribute::NoImplicitFloat))
    return false;

  AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
  DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  TargetTransformInfo &TTI =
      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  Vectorizer V(F, AA, DT, SE, TTI);
  return V.run();
}

// Vectorizer Implementation
bool Vectorizer::run() {
  bool Changed = false;

  // Scan the blocks in the function in post order.
  for (BasicBlock *BB : post_order(&F)) {
    InstrListMap LoadRefs, StoreRefs;
    std::tie(LoadRefs, StoreRefs) = collectInstructions(BB);
    Changed |= vectorizeChains(LoadRefs);
    Changed |= vectorizeChains(StoreRefs);
  }

  return Changed;
}

Value *Vectorizer::getPointerOperand(Value *I) const {
  if (LoadInst *LI = dyn_cast<LoadInst>(I))
    return LI->getPointerOperand();
  if (StoreInst *SI = dyn_cast<StoreInst>(I))
    return SI->getPointerOperand();
  return nullptr;
}

unsigned Vectorizer::getPointerAddressSpace(Value *I) {
  if (LoadInst *L = dyn_cast<LoadInst>(I))
    return L->getPointerAddressSpace();
  if (StoreInst *S = dyn_cast<StoreInst>(I))
    return S->getPointerAddressSpace();
  return -1;
}

GetElementPtrInst *Vectorizer::getSourceGEP(Value *Src) const {
  // First strip pointer bitcasts. Make sure pointee size is the same with
  // and without casts.
  // TODO: a stride set by the add instruction below can match the difference
  // in pointee type size here. Currently it will not be vectorized.
  Value *SrcPtr = getPointerOperand(Src);
  Value *SrcBase = SrcPtr->stripPointerCasts();
  if (DL.getTypeStoreSize(SrcPtr->getType()->getPointerElementType()) ==
      DL.getTypeStoreSize(SrcBase->getType()->getPointerElementType()))
    SrcPtr = SrcBase;
  return dyn_cast<GetElementPtrInst>(SrcPtr);
}

// FIXME: Merge with llvm::isConsecutiveAccess
bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) {
  Value *PtrA = getPointerOperand(A);
  Value *PtrB = getPointerOperand(B);
  unsigned ASA = getPointerAddressSpace(A);
  unsigned ASB = getPointerAddressSpace(B);

  // Check that the address spaces match and that the pointers are valid.
  if (!PtrA || !PtrB || (ASA != ASB))
    return false;

  // Make sure that A and B are different pointers whose pointee types have
  // the same store size.
  unsigned PtrBitWidth = DL.getPointerSizeInBits(ASA);
  Type *PtrATy = PtrA->getType()->getPointerElementType();
  Type *PtrBTy = PtrB->getType()->getPointerElementType();
  if (PtrA == PtrB ||
      DL.getTypeStoreSize(PtrATy) != DL.getTypeStoreSize(PtrBTy) ||
      DL.getTypeStoreSize(PtrATy->getScalarType()) !=
          DL.getTypeStoreSize(PtrBTy->getScalarType()))
    return false;

  APInt Size(PtrBitWidth, DL.getTypeStoreSize(PtrATy));

  APInt OffsetA(PtrBitWidth, 0), OffsetB(PtrBitWidth, 0);
  PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA);
  PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB);

  APInt OffsetDelta = OffsetB - OffsetA;

  // Check if they are based on the same pointer. That makes the offsets
  // sufficient.
  if (PtrA == PtrB)
    return OffsetDelta == Size;

  // Compute the necessary base pointer delta to have the necessary final delta
  // equal to the size.
  APInt BaseDelta = Size - OffsetDelta;

  // Compute the distance with SCEV between the base pointers.
  const SCEV *PtrSCEVA = SE.getSCEV(PtrA);
  const SCEV *PtrSCEVB = SE.getSCEV(PtrB);
  const SCEV *C = SE.getConstant(BaseDelta);
  const SCEV *X = SE.getAddExpr(PtrSCEVA, C);
  if (X == PtrSCEVB)
    return true;

  // Sometimes even this doesn't work, because SCEV can't always see through
  // patterns that look like (gep (ext (add (shl X, C1), C2))). Try checking
  // things the hard way.

  // Look through GEPs after checking they're the same except for the last
  // index.
  GetElementPtrInst *GEPA = getSourceGEP(A);
  GetElementPtrInst *GEPB = getSourceGEP(B);
  if (!GEPA || !GEPB || GEPA->getNumOperands() != GEPB->getNumOperands())
    return false;
  unsigned FinalIndex = GEPA->getNumOperands() - 1;
  for (unsigned i = 0; i < FinalIndex; i++)
    if (GEPA->getOperand(i) != GEPB->getOperand(i))
      return false;

  Instruction *OpA = dyn_cast<Instruction>(GEPA->getOperand(FinalIndex));
  Instruction *OpB = dyn_cast<Instruction>(GEPB->getOperand(FinalIndex));
  if (!OpA || !OpB || OpA->getOpcode() != OpB->getOpcode() ||
      OpA->getType() != OpB->getType())
    return false;

  // Only look through a ZExt/SExt.
  if (!isa<SExtInst>(OpA) && !isa<ZExtInst>(OpA))
    return false;

  bool Signed = isa<SExtInst>(OpA);

  OpA = dyn_cast<Instruction>(OpA->getOperand(0));
  OpB = dyn_cast<Instruction>(OpB->getOperand(0));
  if (!OpA || !OpB || OpA->getType() != OpB->getType())
    return false;

  // Now we need to prove that adding 1 to OpA won't overflow.
  bool Safe = false;
  // First attempt: if OpB is an add with NSW/NUW, and OpB is 1 added to OpA,
  // we're okay.
  if (OpB->getOpcode() == Instruction::Add &&
      isa<ConstantInt>(OpB->getOperand(1)) &&
      cast<ConstantInt>(OpB->getOperand(1))->getSExtValue() > 0) {
    if (Signed)
      Safe = cast<BinaryOperator>(OpB)->hasNoSignedWrap();
    else
      Safe = cast<BinaryOperator>(OpB)->hasNoUnsignedWrap();
  }

  unsigned BitWidth = OpA->getType()->getScalarSizeInBits();

  // Second attempt:
  // If any bits are known to be zero other than the sign bit in OpA, we can
  // add 1 to it while guaranteeing no overflow of any sort.
  if (!Safe) {
    KnownBits Known(BitWidth);
    computeKnownBits(OpA, Known, DL, 0, nullptr, OpA, &DT);
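    // If the longest possible run of trailing one bits is shorter than
    // BitWidth - 1, some bit below the sign bit is known to be zero, so
    // adding 1 cannot carry into the sign bit.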
    if (Known.countMaxTrailingOnes() < (BitWidth - 1))
      Safe = true;
  }

  if (!Safe)
    return false;

  const SCEV *OffsetSCEVA = SE.getSCEV(OpA);
  const SCEV *OffsetSCEVB = SE.getSCEV(OpB);
  const SCEV *One = SE.getConstant(APInt(BitWidth, 1));
  const SCEV *X2 = SE.getAddExpr(OffsetSCEVA, One);
  return X2 == OffsetSCEVB;
}

void Vectorizer::reorder(Instruction *I) {
  OrderedBasicBlock OBB(I->getParent());
  SmallPtrSet<Instruction *, 16> InstructionsToMove;
  SmallVector<Instruction *, 16> Worklist;

  Worklist.push_back(I);
  while (!Worklist.empty()) {
    Instruction *IW = Worklist.pop_back_val();
    int NumOperands = IW->getNumOperands();
    for (int i = 0; i < NumOperands; i++) {
      Instruction *IM = dyn_cast<Instruction>(IW->getOperand(i));
      if (!IM || IM->getOpcode() == Instruction::PHI)
        continue;

      // If IM is in another BB, no need to move it, because this pass only
      // vectorizes instructions within one BB.
      if (IM->getParent() != I->getParent())
        continue;

      if (!OBB.dominates(IM, I)) {
        InstructionsToMove.insert(IM);
        Worklist.push_back(IM);
      }
    }
  }

  // All instructions to move should follow I. Start from I, not from begin().
  for (auto BBI = I->getIterator(), E = I->getParent()->end(); BBI != E;
       ++BBI) {
    if (!InstructionsToMove.count(&*BBI))
      continue;
    Instruction *IM = &*BBI;
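    // Step the iterator back before unlinking IM so it remains valid once IM
    // is moved in front of I.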
    --BBI;
    IM->removeFromParent();
    IM->insertBefore(I);
  }
}

std::pair<BasicBlock::iterator, BasicBlock::iterator>
Vectorizer::getBoundaryInstrs(ArrayRef<Instruction *> Chain) {
  Instruction *C0 = Chain[0];
  BasicBlock::iterator FirstInstr = C0->getIterator();
  BasicBlock::iterator LastInstr = C0->getIterator();

  BasicBlock *BB = C0->getParent();
  unsigned NumFound = 0;
  for (Instruction &I : *BB) {
    if (!is_contained(Chain, &I))
      continue;

    ++NumFound;
    if (NumFound == 1) {
      FirstInstr = I.getIterator();
    }
    if (NumFound == Chain.size()) {
      LastInstr = I.getIterator();
      break;
    }
  }

  // Range is [first, last).
  return std::make_pair(FirstInstr, ++LastInstr);
}

void Vectorizer::eraseInstructions(ArrayRef<Instruction *> Chain) {
  SmallVector<Instruction *, 16> Instrs;
  for (Instruction *I : Chain) {
    Value *PtrOperand = getPointerOperand(I);
    assert(PtrOperand && "Instruction must have a pointer operand.");
    Instrs.push_back(I);
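    // Also queue the GEP feeding this access; it may become dead once the
    // scalar access itself has been erased.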
    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(PtrOperand))
      Instrs.push_back(GEP);
  }

  // Erase instructions.
  for (Instruction *I : Instrs)
    if (I->use_empty())
      I->eraseFromParent();
}

std::pair<ArrayRef<Instruction *>, ArrayRef<Instruction *>>
Vectorizer::splitOddVectorElts(ArrayRef<Instruction *> Chain,
                               unsigned ElementSizeBits) {
  unsigned ElementSizeBytes = ElementSizeBits / 8;
  unsigned SizeBytes = ElementSizeBytes * Chain.size();
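  // Number of elements that fit into the largest prefix whose total size is a
  // multiple of 4 bytes.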
  unsigned NumLeft = (SizeBytes - (SizeBytes % 4)) / ElementSizeBytes;
  if (NumLeft == Chain.size()) {
    if ((NumLeft & 1) == 0)
      NumLeft /= 2; // Split even in half
    else
      --NumLeft;    // Split off last element
  } else if (NumLeft == 0)
    NumLeft = 1;
  return std::make_pair(Chain.slice(0, NumLeft), Chain.slice(NumLeft));
}

ArrayRef<Instruction *>
Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
  // These are in BB order, unlike Chain, which is in address order.
  SmallVector<Instruction *, 16> MemoryInstrs;
  SmallVector<Instruction *, 16> ChainInstrs;

  bool IsLoadChain = isa<LoadInst>(Chain[0]);
  DEBUG({
    for (Instruction *I : Chain) {
      if (IsLoadChain)
        assert(isa<LoadInst>(I) &&
               "All elements of Chain must be loads, or all must be stores.");
      else
        assert(isa<StoreInst>(I) &&
               "All elements of Chain must be loads, or all must be stores.");
    }
  });

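  // Walk the range spanned by the chain, splitting loads and stores into
  // chain members and other memory instructions, and stop at the first
  // instruction that could interfere in a way we cannot analyze.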
  for (Instruction &I : make_range(getBoundaryInstrs(Chain))) {
    if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
      if (!is_contained(Chain, &I))
        MemoryInstrs.push_back(&I);
      else
        ChainInstrs.push_back(&I);
    } else if (isa<IntrinsicInst>(&I) &&
               cast<IntrinsicInst>(&I)->getIntrinsicID() ==
                   Intrinsic::sideeffect) {
      // Ignore llvm.sideeffect calls.
    } else if (IsLoadChain && (I.mayWriteToMemory() || I.mayThrow())) {
      DEBUG(dbgs() << "LSV: Found may-write/throw operation: " << I << '\n');
      break;
    } else if (!IsLoadChain && (I.mayReadOrWriteMemory() || I.mayThrow())) {
      DEBUG(dbgs() << "LSV: Found may-read/write/throw operation: " << I
                   << '\n');
      break;
    }
  }

  OrderedBasicBlock OBB(Chain[0]->getParent());

  // Loop until we find an instruction in ChainInstrs that we can't vectorize.
  unsigned ChainInstrIdx = 0;
  Instruction *BarrierMemoryInstr = nullptr;

  for (unsigned E = ChainInstrs.size(); ChainInstrIdx < E; ++ChainInstrIdx) {
    Instruction *ChainInstr = ChainInstrs[ChainInstrIdx];

    // If a barrier memory instruction was found, chain instructions that follow
    // will not be added to the valid prefix.
    if (BarrierMemoryInstr && OBB.dominates(BarrierMemoryInstr, ChainInstr))
      break;

    // Check (in BB order) if any instruction prevents ChainInstr from being
    // vectorized. Find and store the first such "conflicting" instruction.
    for (Instruction *MemInstr : MemoryInstrs) {
      // If a barrier memory instruction was found, do not check past it.
      if (BarrierMemoryInstr && OBB.dominates(BarrierMemoryInstr, MemInstr))
        break;

      if (isa<LoadInst>(MemInstr) && isa<LoadInst>(ChainInstr))
        continue;

      // We can ignore the alias as long as the load comes before the store,
      // because that means we won't be moving the load past the store to
      // vectorize it (the vectorized load is inserted at the location of the
      // first load in the chain).
      if (isa<StoreInst>(MemInstr) && isa<LoadInst>(ChainInstr) &&
          OBB.dominates(ChainInstr, MemInstr))
        continue;

      // Same case, but in reverse.
      if (isa<LoadInst>(MemInstr) && isa<StoreInst>(ChainInstr) &&
          OBB.dominates(MemInstr, ChainInstr))
        continue;

      if (!AA.isNoAlias(MemoryLocation::get(MemInstr),
                        MemoryLocation::get(ChainInstr))) {
        DEBUG({
          dbgs() << "LSV: Found alias:\n"
                    "  Aliasing instruction and pointer:\n"
                 << "  " << *MemInstr << '\n'
                 << "  " << *getPointerOperand(MemInstr) << '\n'
                 << "  Aliased instruction and pointer:\n"
                 << "  " << *ChainInstr << '\n'
                 << "  " << *getPointerOperand(ChainInstr) << '\n';
        });
        // Save this aliasing memory instruction as a barrier, but allow other
        // instructions that precede the barrier to be vectorized with this one.
        BarrierMemoryInstr = MemInstr;
        break;
      }
    }
    // Continue the search only for store chains, since vectorizing stores that
    // precede an aliasing load is valid. Conversely, vectorizing loads is valid
    // up to an aliasing store, but should not pull loads from further down in
    // the basic block.
    if (IsLoadChain && BarrierMemoryInstr) {
      // The BarrierMemoryInstr is a store that precedes ChainInstr.
      assert(OBB.dominates(BarrierMemoryInstr, ChainInstr));
      break;
    }
  }

  // Find the largest prefix of Chain whose elements are all in
  // ChainInstrs[0, ChainInstrIdx).  This is the largest vectorizable prefix of
  // Chain.  (Recall that Chain is in address order, but ChainInstrs is in BB
  // order.)
  SmallPtrSet<Instruction *, 8> VectorizableChainInstrs(
      ChainInstrs.begin(), ChainInstrs.begin() + ChainInstrIdx);
  unsigned ChainIdx = 0;
  for (unsigned ChainLen = Chain.size(); ChainIdx < ChainLen; ++ChainIdx) {
    if (!VectorizableChainInstrs.count(Chain[ChainIdx]))
      break;
  }
  return Chain.slice(0, ChainIdx);
}

std::pair<InstrListMap, InstrListMap>
Vectorizer::collectInstructions(BasicBlock *BB) {
  InstrListMap LoadRefs;
  InstrListMap StoreRefs;

  for (Instruction &I : *BB) {
    if (!I.mayReadOrWriteMemory())
      continue;

    if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
      if (!LI->isSimple())
        continue;

      // Skip if it's not legal.
      if (!TTI.isLegalToVectorizeLoad(LI))
        continue;

      Type *Ty = LI->getType();
      if (!VectorType::isValidElementType(Ty->getScalarType()))
        continue;

      // Skip weird non-byte sizes. They probably aren't worth the effort of
      // handling correctly.
      unsigned TySize = DL.getTypeSizeInBits(Ty);
      if ((TySize % 8) != 0)
        continue;

      // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain
      // functions currently use an integer type for the vectorized load/store,
      // and do not support casting between the integer type and a vector of
      // pointers (e.g. i64 to <2 x i16*>).
      if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy())
        continue;

      Value *Ptr = LI->getPointerOperand();
      unsigned AS = Ptr->getType()->getPointerAddressSpace();
      unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);

      // No point in looking at these if they're too big to vectorize.
      if (TySize > VecRegSize / 2)
        continue;

      // Make sure all the users of a vector are constant-index extracts.
      if (isa<VectorType>(Ty) && !llvm::all_of(LI->users(), [](const User *U) {
            const ExtractElementInst *EEI = dyn_cast<ExtractElementInst>(U);
            return EEI && isa<ConstantInt>(EEI->getOperand(1));
          }))
        continue;

      // Save the load locations.
      Value *ObjPtr = GetUnderlyingObject(Ptr, DL);
      LoadRefs[ObjPtr].push_back(LI);
    } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
      if (!SI->isSimple())
        continue;

      // Skip if it's not legal.
      if (!TTI.isLegalToVectorizeStore(SI))
        continue;

      Type *Ty = SI->getValueOperand()->getType();
      if (!VectorType::isValidElementType(Ty->getScalarType()))
        continue;

      // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain
      // functions currently use an integer type for the vectorized load/store,
      // and do not support casting between the integer type and a vector of
      // pointers (e.g. i64 to <2 x i16*>).
      if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy())
        continue;

      // Skip weird non-byte sizes. They probably aren't worth the effort of
      // handling correctly.
      unsigned TySize = DL.getTypeSizeInBits(Ty);
      if ((TySize % 8) != 0)
        continue;

      Value *Ptr = SI->getPointerOperand();
      unsigned AS = Ptr->getType()->getPointerAddressSpace();
      unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);

      // No point in looking at these if they're too big to vectorize.
      if (TySize > VecRegSize / 2)
        continue;

      if (isa<VectorType>(Ty) && !llvm::all_of(SI->users(), [](const User *U) {
            const ExtractElementInst *EEI = dyn_cast<ExtractElementInst>(U);
            return EEI && isa<ConstantInt>(EEI->getOperand(1));
          }))
        continue;

      // Save store location.
      Value *ObjPtr = GetUnderlyingObject(Ptr, DL);
      StoreRefs[ObjPtr].push_back(SI);
    }
  }

  return {LoadRefs, StoreRefs};
}

bool Vectorizer::vectorizeChains(InstrListMap &Map) {
  bool Changed = false;

  for (const std::pair<Value *, InstrList> &Chain : Map) {
    unsigned Size = Chain.second.size();
    if (Size < 2)
      continue;

    DEBUG(dbgs() << "LSV: Analyzing a chain of length " << Size << ".\n");

    // Process the instructions in chunks of 64.
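    // (64 is the capacity of the fixed-size ConsecutiveChain array used by
    // vectorizeInstructions.)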
    for (unsigned CI = 0, CE = Size; CI < CE; CI += 64) {
      unsigned Len = std::min<unsigned>(CE - CI, 64);
      ArrayRef<Instruction *> Chunk(&Chain.second[CI], Len);
      Changed |= vectorizeInstructions(Chunk);
    }
  }

  return Changed;
}

bool Vectorizer::vectorizeInstructions(ArrayRef<Instruction *> Instrs) {
  DEBUG(dbgs() << "LSV: Vectorizing " << Instrs.size() << " instructions.\n");
  SmallVector<int, 16> Heads, Tails;
  int ConsecutiveChain[64];

  // Do a quadratic search on all of the given loads/stores and find all of the
  // pairs of loads/stores that follow each other.
  for (int i = 0, e = Instrs.size(); i < e; ++i) {
    ConsecutiveChain[i] = -1;
    for (int j = e - 1; j >= 0; --j) {
      if (i == j)
        continue;

      if (isConsecutiveAccess(Instrs[i], Instrs[j])) {
        if (ConsecutiveChain[i] != -1) {
          int CurDistance = std::abs(ConsecutiveChain[i] - i);
          int NewDistance = std::abs(ConsecutiveChain[i] - j);
          if (j < i || NewDistance > CurDistance)
            continue; // Should not insert.
        }

        Tails.push_back(j);
        Heads.push_back(i);
        ConsecutiveChain[i] = j;
      }
    }
  }

  bool Changed = false;
  SmallPtrSet<Instruction *, 16> InstructionsProcessed;

  for (int Head : Heads) {
    if (InstructionsProcessed.count(Instrs[Head]))
      continue;
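    // Skip this head if it is also the tail of another, still unprocessed
    // chain; starting from that chain's head yields a longer chain that
    // already includes this one.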
    bool LongerChainExists = false;
    for (unsigned TIt = 0; TIt < Tails.size(); TIt++)
      if (Head == Tails[TIt] &&
          !InstructionsProcessed.count(Instrs[Heads[TIt]])) {
        LongerChainExists = true;
        break;
      }
    if (LongerChainExists)
      continue;

    // We found an instr that starts a chain. Now follow the chain and try to
    // vectorize it.
    SmallVector<Instruction *, 16> Operands;
    int I = Head;
    while (I != -1 && (is_contained(Tails, I) || is_contained(Heads, I))) {
      if (InstructionsProcessed.count(Instrs[I]))
        break;

      Operands.push_back(Instrs[I]);
      I = ConsecutiveChain[I];
    }

    bool Vectorized = false;
    if (isa<LoadInst>(*Operands.begin()))
      Vectorized = vectorizeLoadChain(Operands, &InstructionsProcessed);
    else
      Vectorized = vectorizeStoreChain(Operands, &InstructionsProcessed);

    Changed |= Vectorized;
  }

  return Changed;
}

bool Vectorizer::vectorizeStoreChain(
    ArrayRef<Instruction *> Chain,
    SmallPtrSet<Instruction *, 16> *InstructionsProcessed) {
  StoreInst *S0 = cast<StoreInst>(Chain[0]);

  // If the vector has an int element, default to int for the whole store.
  Type *StoreTy;
  for (Instruction *I : Chain) {
    StoreTy = cast<StoreInst>(I)->getValueOperand()->getType();
    if (StoreTy->isIntOrIntVectorTy())
      break;

    if (StoreTy->isPtrOrPtrVectorTy()) {
      StoreTy = Type::getIntNTy(F.getParent()->getContext(),
                                DL.getTypeSizeInBits(StoreTy));
      break;
    }
  }

  unsigned Sz = DL.getTypeSizeInBits(StoreTy);
  unsigned AS = S0->getPointerAddressSpace();
  unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
  unsigned VF = VecRegSize / Sz;
  unsigned ChainSize = Chain.size();
  unsigned Alignment = getAlignment(S0);

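  // Give up on non-power-of-two element sizes and on chains or registers too
  // small to hold at least two elements.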
  if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) {
    InstructionsProcessed->insert(Chain.begin(), Chain.end());
    return false;
  }

  ArrayRef<Instruction *> NewChain = getVectorizablePrefix(Chain);
  if (NewChain.empty()) {
    // No vectorization possible.
    InstructionsProcessed->insert(Chain.begin(), Chain.end());
    return false;
  }
  if (NewChain.size() == 1) {
    // Failed after the first instruction. Discard it and try the smaller chain.
    InstructionsProcessed->insert(NewChain.front());
    return false;
  }

  // Update Chain to the valid vectorizable subchain.
  Chain = NewChain;
  ChainSize = Chain.size();

  // Check if it's legal to vectorize this chain. If not, split the chain and
  // try again.
  unsigned EltSzInBytes = Sz / 8;
  unsigned SzInBytes = EltSzInBytes * ChainSize;
  if (!TTI.isLegalToVectorizeStoreChain(SzInBytes, Alignment, AS)) {
    auto Chains = splitOddVectorElts(Chain, Sz);
    return vectorizeStoreChain(Chains.first, InstructionsProcessed) |
           vectorizeStoreChain(Chains.second, InstructionsProcessed);
  }

  VectorType *VecTy;
  VectorType *VecStoreTy = dyn_cast<VectorType>(StoreTy);
  if (VecStoreTy)
    VecTy = VectorType::get(StoreTy->getScalarType(),
                            Chain.size() * VecStoreTy->getNumElements());
  else
    VecTy = VectorType::get(StoreTy, Chain.size());

  // If it's more than the max vector size or the target has a better
  // vector factor, break it into two pieces.
  unsigned TargetVF = TTI.getStoreVectorFactor(VF, Sz, SzInBytes, VecTy);
  if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) {
    DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor."
                    " Creating two separate arrays.\n");
    return vectorizeStoreChain(Chain.slice(0, TargetVF),
                               InstructionsProcessed) |
           vectorizeStoreChain(Chain.slice(TargetVF), InstructionsProcessed);
  }

  DEBUG({
    dbgs() << "LSV: Stores to vectorize:\n";
    for (Instruction *I : Chain)
      dbgs() << "  " << *I << "\n";
  });

  // We won't try again to vectorize the elements of the chain, regardless of
  // whether we succeed below.
  InstructionsProcessed->insert(Chain.begin(), Chain.end());

  // If the store is going to be misaligned, don't vectorize it.
  if (accessIsMisaligned(SzInBytes, AS, Alignment)) {
    if (S0->getPointerAddressSpace() != 0)
      return false;

    unsigned NewAlign = getOrEnforceKnownAlignment(S0->getPointerOperand(),
                                                   StackAdjustedAlignment,
                                                   DL, S0, nullptr, &DT);
    if (NewAlign < StackAdjustedAlignment)
      return false;
  }

  BasicBlock::iterator First, Last;
  std::tie(First, Last) = getBoundaryInstrs(Chain);
  Builder.SetInsertPoint(&*Last);

  Value *Vec = UndefValue::get(VecTy);

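  // Build the wide value to store: for vector-typed chain elements, copy each
  // source lane into the combined vector; otherwise insert each scalar store
  // operand directly.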
  if (VecStoreTy) {
    unsigned VecWidth = VecStoreTy->getNumElements();
    for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
      StoreInst *Store = cast<StoreInst>(Chain[I]);
      for (unsigned J = 0, NE = VecStoreTy->getNumElements(); J != NE; ++J) {
        unsigned NewIdx = J + I * VecWidth;
        Value *Extract = Builder.CreateExtractElement(Store->getValueOperand(),
                                                      Builder.getInt32(J));
        if (Extract->getType() != StoreTy->getScalarType())
          Extract = Builder.CreateBitCast(Extract, StoreTy->getScalarType());

        Value *Insert =
            Builder.CreateInsertElement(Vec, Extract, Builder.getInt32(NewIdx));
        Vec = Insert;
      }
    }
  } else {
    for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
      StoreInst *Store = cast<StoreInst>(Chain[I]);
      Value *Extract = Store->getValueOperand();
      if (Extract->getType() != StoreTy->getScalarType())
        Extract =
            Builder.CreateBitOrPointerCast(Extract, StoreTy->getScalarType());

      Value *Insert =
          Builder.CreateInsertElement(Vec, Extract, Builder.getInt32(I));
      Vec = Insert;
    }
  }

  // This cast is safe because Builder.CreateStore() always creates a bona fide
  // StoreInst.
  StoreInst *SI = cast<StoreInst>(
      Builder.CreateStore(Vec, Builder.CreateBitCast(S0->getPointerOperand(),
                                                     VecTy->getPointerTo(AS))));
  propagateMetadata(SI, Chain);
  SI->setAlignment(Alignment);

  eraseInstructions(Chain);
  ++NumVectorInstructions;
  NumScalarsVectorized += Chain.size();
  return true;
}

bool Vectorizer::vectorizeLoadChain(
    ArrayRef<Instruction *> Chain,
    SmallPtrSet<Instruction *, 16> *InstructionsProcessed) {
  LoadInst *L0 = cast<LoadInst>(Chain[0]);

  // If the vector has an int element, default to int for the whole load.
  Type *LoadTy;
  for (const auto &V : Chain) {
    LoadTy = cast<LoadInst>(V)->getType();
    if (LoadTy->isIntOrIntVectorTy())
      break;

    if (LoadTy->isPtrOrPtrVectorTy()) {
      LoadTy = Type::getIntNTy(F.getParent()->getContext(),
                               DL.getTypeSizeInBits(LoadTy));
      break;
    }
  }

  unsigned Sz = DL.getTypeSizeInBits(LoadTy);
  unsigned AS = L0->getPointerAddressSpace();
  unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
  unsigned VF = VecRegSize / Sz;
  unsigned ChainSize = Chain.size();
  unsigned Alignment = getAlignment(L0);

  if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) {
    InstructionsProcessed->insert(Chain.begin(), Chain.end());
    return false;
  }

  ArrayRef<Instruction *> NewChain = getVectorizablePrefix(Chain);
  if (NewChain.empty()) {
    // No vectorization possible.
    InstructionsProcessed->insert(Chain.begin(), Chain.end());
    return false;
  }
  if (NewChain.size() == 1) {
    // Failed after the first instruction. Discard it and try the smaller chain.
    InstructionsProcessed->insert(NewChain.front());
    return false;
  }

  // Update Chain to the valid vectorizable subchain.
  Chain = NewChain;
  ChainSize = Chain.size();

  // Check if it's legal to vectorize this chain. If not, split the chain and
  // try again.
  unsigned EltSzInBytes = Sz / 8;
  unsigned SzInBytes = EltSzInBytes * ChainSize;
  if (!TTI.isLegalToVectorizeLoadChain(SzInBytes, Alignment, AS)) {
    auto Chains = splitOddVectorElts(Chain, Sz);
    return vectorizeLoadChain(Chains.first, InstructionsProcessed) |
           vectorizeLoadChain(Chains.second, InstructionsProcessed);
  }

  VectorType *VecTy;
  VectorType *VecLoadTy = dyn_cast<VectorType>(LoadTy);
  if (VecLoadTy)
    VecTy = VectorType::get(LoadTy->getScalarType(),
                            Chain.size() * VecLoadTy->getNumElements());
  else
    VecTy = VectorType::get(LoadTy, Chain.size());

  // If it's more than the max vector size or the target has a better
  // vector factor, break it into two pieces.
  unsigned TargetVF = TTI.getLoadVectorFactor(VF, Sz, SzInBytes, VecTy);
  if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) {
    DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor."
                    " Creating two separate arrays.\n");
    return vectorizeLoadChain(Chain.slice(0, TargetVF), InstructionsProcessed) |
           vectorizeLoadChain(Chain.slice(TargetVF), InstructionsProcessed);
  }

  // We won't try again to vectorize the elements of the chain, regardless of
  // whether we succeed below.
  InstructionsProcessed->insert(Chain.begin(), Chain.end());

  // If the load is going to be misaligned, don't vectorize it.
  if (accessIsMisaligned(SzInBytes, AS, Alignment)) {
    if (L0->getPointerAddressSpace() != 0)
      return false;

    unsigned NewAlign = getOrEnforceKnownAlignment(L0->getPointerOperand(),
                                                   StackAdjustedAlignment,
                                                   DL, L0, nullptr, &DT);
    if (NewAlign < StackAdjustedAlignment)
      return false;

    Alignment = NewAlign;
  }

  DEBUG({
    dbgs() << "LSV: Loads to vectorize:\n";
    for (Instruction *I : Chain)
      I->dump();
  });

  // getVectorizablePrefix already computed getBoundaryInstrs.  The value of
  // Last may have changed since then, but the value of First won't have.  If it
  // matters, we could compute getBoundaryInstrs only once and reuse it here.
  BasicBlock::iterator First, Last;
  std::tie(First, Last) = getBoundaryInstrs(Chain);
  Builder.SetInsertPoint(&*First);

  Value *Bitcast =
      Builder.CreateBitCast(L0->getPointerOperand(), VecTy->getPointerTo(AS));
  // This cast is safe because Builder.CreateLoad always creates a bona fide
  // LoadInst.
  LoadInst *LI = cast<LoadInst>(Builder.CreateLoad(Bitcast));
  propagateMetadata(LI, Chain);
  LI->setAlignment(Alignment);

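  // Rewrite the scalar users: for vector-typed chain elements, retarget each
  // constant-index extractelement at the wide load; otherwise replace each
  // original load with an extract of the corresponding lane.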
  if (VecLoadTy) {
    SmallVector<Instruction *, 16> InstrsToErase;

    unsigned VecWidth = VecLoadTy->getNumElements();
    for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
      for (auto Use : Chain[I]->users()) {
        // All users of vector loads are ExtractElement instructions with
        // constant indices, otherwise we would have bailed before now.
        Instruction *UI = cast<Instruction>(Use);
        unsigned Idx = cast<ConstantInt>(UI->getOperand(1))->getZExtValue();
        unsigned NewIdx = Idx + I * VecWidth;
        Value *V = Builder.CreateExtractElement(LI, Builder.getInt32(NewIdx),
                                                UI->getName());
        if (V->getType() != UI->getType())
          V = Builder.CreateBitCast(V, UI->getType());

        // Replace the old instruction.
        UI->replaceAllUsesWith(V);
        InstrsToErase.push_back(UI);
      }
    }

    // Bitcast might not be an Instruction, if the value being loaded is a
    // constant.  In that case, no need to reorder anything.
    if (Instruction *BitcastInst = dyn_cast<Instruction>(Bitcast))
      reorder(BitcastInst);

    for (auto I : InstrsToErase)
      I->eraseFromParent();
  } else {
    for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
      Value *CV = Chain[I];
      Value *V =
          Builder.CreateExtractElement(LI, Builder.getInt32(I), CV->getName());
      if (V->getType() != CV->getType()) {
        V = Builder.CreateBitOrPointerCast(V, CV->getType());
      }

      // Replace the old instruction.
      CV->replaceAllUsesWith(V);
    }

    if (Instruction *BitcastInst = dyn_cast<Instruction>(Bitcast))
      reorder(BitcastInst);
  }

  eraseInstructions(Chain);

  ++NumVectorInstructions;
  NumScalarsVectorized += Chain.size();
  return true;
}

bool Vectorizer::accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace,
                                    unsigned Alignment) {
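  // An access whose alignment is a multiple of its size is never considered
  // misaligned.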
  if (Alignment % SzInBytes == 0)
    return false;

  bool Fast = false;
  bool Allows = TTI.allowsMisalignedMemoryAccesses(F.getParent()->getContext(),
                                                   SzInBytes * 8, AddressSpace,
                                                   Alignment, &Fast);
  DEBUG(dbgs() << "LSV: Target said misaligned is allowed? " << Allows
               << " and fast? " << Fast << "\n";);
  return !Allows || !Fast;
}