//===----- LoadStoreVectorizer.cpp - GPU Load & Store Vectorizer ----------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
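// This pass merges adjacent loads and stores of consecutive memory addresses
// into wider vector accesses. It is aimed primarily at GPU targets, where
// wide memory operations are comparatively cheap.
//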
//===----------------------------------------------------------------------===//

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;

#define DEBUG_TYPE "load-store-vectorizer"
STATISTIC(NumVectorInstructions, "Number of vector accesses generated");
STATISTIC(NumScalarsVectorized, "Number of scalar accesses vectorized");

namespace {

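// Fallback alignment (in bytes): the alignment this pass assumes it may give
// to stack objects, and the fallback divisor used by accessIsMisaligned().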
// TODO: Remove this
static const unsigned TargetBaseAlign = 4;

typedef SmallVector<Value *, 8> ValueList;
typedef MapVector<Value *, ValueList> ValueListMap;

class Vectorizer {
  Function &F;
  AliasAnalysis &AA;
  DominatorTree &DT;
  ScalarEvolution &SE;
  TargetTransformInfo &TTI;
  const DataLayout &DL;
  IRBuilder<> Builder;

public:
  Vectorizer(Function &F, AliasAnalysis &AA, DominatorTree &DT,
             ScalarEvolution &SE, TargetTransformInfo &TTI)
      : F(F), AA(AA), DT(DT), SE(SE), TTI(TTI),
        DL(F.getParent()->getDataLayout()), Builder(SE.getContext()) {}

  bool run();

private:
  Value *getPointerOperand(Value *I);

  unsigned getPointerAddressSpace(Value *I);

  unsigned getAlignment(LoadInst *LI) const {
    unsigned Align = LI->getAlignment();
    if (Align != 0)
      return Align;

    return DL.getABITypeAlignment(LI->getType());
  }

  unsigned getAlignment(StoreInst *SI) const {
    unsigned Align = SI->getAlignment();
    if (Align != 0)
      return Align;

    return DL.getABITypeAlignment(SI->getValueOperand()->getType());
  }

  bool isConsecutiveAccess(Value *A, Value *B);

  /// After vectorization, reorder the instructions that I depends on
  /// (the instructions defining its operands), to ensure they dominate I.
  void reorder(Instruction *I);

  /// Returns the first and the last instructions in Chain.
  std::pair<BasicBlock::iterator, BasicBlock::iterator>
  getBoundaryInstrs(ArrayRef<Value *> Chain);

  /// Erases the original instructions after vectorizing.
  void eraseInstructions(ArrayRef<Value *> Chain);

  /// "Legalize" the vector type that would be produced by combining \p
  /// ElementSizeBits elements in \p Chain. Break into two pieces such that the
  /// total size of each piece is 1, 2 or a multiple of 4 bytes. \p Chain is
  /// expected to have more than 4 elements.
  std::pair<ArrayRef<Value *>, ArrayRef<Value *>>
  splitOddVectorElts(ArrayRef<Value *> Chain, unsigned ElementSizeBits);

  /// Finds the largest prefix of Chain that's vectorizable, checking for
  /// intervening instructions which may affect the memory accessed by the
  /// instructions within Chain.
  ///
  /// The elements of \p Chain must be all loads or all stores and must be in
  /// address order.
  ArrayRef<Value *> getVectorizablePrefix(ArrayRef<Value *> Chain);

  /// Collects load and store instructions to vectorize.
  std::pair<ValueListMap, ValueListMap> collectInstructions(BasicBlock *BB);

  /// Processes the collected instructions in \p Map. The elements of \p Map
  /// must be all loads or all stores.
  bool vectorizeChains(ValueListMap &Map);

  /// Finds loads/stores to consecutive memory addresses and vectorizes them.
  bool vectorizeInstructions(ArrayRef<Value *> Instrs);

  /// Vectorizes the load instructions in Chain.
  bool vectorizeLoadChain(ArrayRef<Value *> Chain,
                          SmallPtrSet<Value *, 16> *InstructionsProcessed);

  /// Vectorizes the store instructions in Chain.
  bool vectorizeStoreChain(ArrayRef<Value *> Chain,
                           SmallPtrSet<Value *, 16> *InstructionsProcessed);

  /// Checks whether the given load/store access is misaligned.
  bool accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace,
                          unsigned Alignment);
};

class LoadStoreVectorizer : public FunctionPass {
public:
  static char ID;

  LoadStoreVectorizer() : FunctionPass(ID) {
    initializeLoadStoreVectorizerPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  const char *getPassName() const override {
    return "GPU Load and Store Vectorizer";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
}

INITIALIZE_PASS_BEGIN(LoadStoreVectorizer, DEBUG_TYPE,
                      "Vectorize load and store instructions", false, false)
INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(LoadStoreVectorizer, DEBUG_TYPE,
                    "Vectorize load and store instructions", false, false)

char LoadStoreVectorizer::ID = 0;

Pass *llvm::createLoadStoreVectorizerPass() {
  return new LoadStoreVectorizer();
}

bool LoadStoreVectorizer::runOnFunction(Function &F) {
  // Don't vectorize when the attribute NoImplicitFloat is used.
  if (skipFunction(F) || F.hasFnAttribute(Attribute::NoImplicitFloat))
    return false;

  AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
  DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  TargetTransformInfo &TTI =
      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  Vectorizer V(F, AA, DT, SE, TTI);
  return V.run();
}

// Vectorizer Implementation
bool Vectorizer::run() {
  bool Changed = false;

  // Scan the blocks in the function in post order.
  for (BasicBlock *BB : post_order(&F)) {
    ValueListMap LoadRefs, StoreRefs;
    std::tie(LoadRefs, StoreRefs) = collectInstructions(BB);
    Changed |= vectorizeChains(LoadRefs);
    Changed |= vectorizeChains(StoreRefs);
  }

  return Changed;
}

Value *Vectorizer::getPointerOperand(Value *I) {
  if (LoadInst *LI = dyn_cast<LoadInst>(I))
    return LI->getPointerOperand();
  if (StoreInst *SI = dyn_cast<StoreInst>(I))
    return SI->getPointerOperand();
  return nullptr;
}

unsigned Vectorizer::getPointerAddressSpace(Value *I) {
  if (LoadInst *L = dyn_cast<LoadInst>(I))
    return L->getPointerAddressSpace();
  if (StoreInst *S = dyn_cast<StoreInst>(I))
    return S->getPointerAddressSpace();
  return -1;
}

// FIXME: Merge with llvm::isConsecutiveAccess
bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) {
  Value *PtrA = getPointerOperand(A);
  Value *PtrB = getPointerOperand(B);
  unsigned ASA = getPointerAddressSpace(A);
  unsigned ASB = getPointerAddressSpace(B);

  // Check that the address spaces match and that the pointers are valid.
  if (!PtrA || !PtrB || (ASA != ASB))
    return false;

  // Make sure that A and B are different pointers and that their pointee
  // types have the same store size.
  unsigned PtrBitWidth = DL.getPointerSizeInBits(ASA);
  Type *PtrATy = PtrA->getType()->getPointerElementType();
  Type *PtrBTy = PtrB->getType()->getPointerElementType();
  if (PtrA == PtrB ||
      DL.getTypeStoreSize(PtrATy) != DL.getTypeStoreSize(PtrBTy) ||
      DL.getTypeStoreSize(PtrATy->getScalarType()) !=
          DL.getTypeStoreSize(PtrBTy->getScalarType()))
    return false;

  APInt Size(PtrBitWidth, DL.getTypeStoreSize(PtrATy));

  APInt OffsetA(PtrBitWidth, 0), OffsetB(PtrBitWidth, 0);
  PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA);
  PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB);

  APInt OffsetDelta = OffsetB - OffsetA;

  // Check if they are based on the same pointer. That makes the offsets
  // sufficient.
  if (PtrA == PtrB)
    return OffsetDelta == Size;

  // Compute the base pointer delta needed to make the final offset delta
  // equal to the size.
  APInt BaseDelta = Size - OffsetDelta;

  // Compute the distance with SCEV between the base pointers.
  const SCEV *PtrSCEVA = SE.getSCEV(PtrA);
  const SCEV *PtrSCEVB = SE.getSCEV(PtrB);
  const SCEV *C = SE.getConstant(BaseDelta);
  const SCEV *X = SE.getAddExpr(PtrSCEVA, C);
  if (X == PtrSCEVB)
    return true;

  // Sometimes even this doesn't work, because SCEV can't always see through
  // patterns that look like (gep (ext (add (shl X, C1), C2))). Try checking
  // things the hard way.

  // Look through GEPs after checking they're the same except for the last
  // index.
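  //
  // For example (a sketch with hypothetical value names): GEPs whose final
  // indices are (sext i32 %x to i64) and (sext i32 (add nsw i32 %x, 1) to
  // i64) are one element apart, but SCEV may not prove it because the sign
  // extensions block the folding. The code below peels off the matching ext
  // pair and reasons about the narrower indices directly.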
  GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(getPointerOperand(A));
  GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(getPointerOperand(B));
  if (!GEPA || !GEPB || GEPA->getNumOperands() != GEPB->getNumOperands())
    return false;
  unsigned FinalIndex = GEPA->getNumOperands() - 1;
  for (unsigned i = 0; i < FinalIndex; i++)
    if (GEPA->getOperand(i) != GEPB->getOperand(i))
      return false;

  Instruction *OpA = dyn_cast<Instruction>(GEPA->getOperand(FinalIndex));
  Instruction *OpB = dyn_cast<Instruction>(GEPB->getOperand(FinalIndex));
  if (!OpA || !OpB || OpA->getOpcode() != OpB->getOpcode() ||
      OpA->getType() != OpB->getType())
    return false;

  // Only look through a ZExt/SExt.
  if (!isa<SExtInst>(OpA) && !isa<ZExtInst>(OpA))
    return false;

  bool Signed = isa<SExtInst>(OpA);

  OpA = dyn_cast<Instruction>(OpA->getOperand(0));
  OpB = dyn_cast<Instruction>(OpB->getOperand(0));
  if (!OpA || !OpB || OpA->getType() != OpB->getType())
    return false;

  // Now we need to prove that adding 1 to OpA won't overflow.
  bool Safe = false;
  // First attempt: if OpB is an add with NSW/NUW of a positive constant, and
  // OpB is in fact OpA + 1 (verified by the SCEV check at the end), then
  // adding 1 to OpA cannot wrap.
  if (OpB->getOpcode() == Instruction::Add &&
      isa<ConstantInt>(OpB->getOperand(1)) &&
      cast<ConstantInt>(OpB->getOperand(1))->getSExtValue() > 0) {
    if (Signed)
      Safe = cast<BinaryOperator>(OpB)->hasNoSignedWrap();
    else
      Safe = cast<BinaryOperator>(OpB)->hasNoUnsignedWrap();
  }

  unsigned BitWidth = OpA->getType()->getScalarSizeInBits();

  // Second attempt:
  // If any bits are known to be zero other than the sign bit in OpA, we can
  // add 1 to it while guaranteeing no overflow of any sort.
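  //
  // This works because the carry from adding 1 propagates only through the
  // low-order run of one bits and stops at the first zero bit. If some bit
  // below the sign bit is known zero, the carry can never reach the sign
  // bit, so neither signed nor unsigned overflow can occur.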
  if (!Safe) {
    APInt KnownZero(BitWidth, 0);
    APInt KnownOne(BitWidth, 0);
    computeKnownBits(OpA, KnownZero, KnownOne, DL, 0, nullptr, OpA, &DT);
    KnownZero &= ~APInt::getHighBitsSet(BitWidth, 1);
    if (KnownZero != 0)
      Safe = true;
  }

  if (!Safe)
    return false;

  const SCEV *OffsetSCEVA = SE.getSCEV(OpA);
  const SCEV *OffsetSCEVB = SE.getSCEV(OpB);
  const SCEV *One = SE.getConstant(APInt(BitWidth, 1));
  const SCEV *X2 = SE.getAddExpr(OffsetSCEVA, One);
  return X2 == OffsetSCEVB;
}

void Vectorizer::reorder(Instruction *I) {
  SmallPtrSet<Instruction *, 16> InstructionsToMove;
  SmallVector<Instruction *, 16> Worklist;

  Worklist.push_back(I);
  while (!Worklist.empty()) {
    Instruction *IW = Worklist.pop_back_val();
    int NumOperands = IW->getNumOperands();
    for (int i = 0; i < NumOperands; i++) {
      Instruction *IM = dyn_cast<Instruction>(IW->getOperand(i));
      if (!IM || IM->getOpcode() == Instruction::PHI)
        continue;

      if (!DT.dominates(IM, I)) {
        InstructionsToMove.insert(IM);
        Worklist.push_back(IM);
        assert(IM->getParent() == IW->getParent() &&
               "Instructions to move should be in the same basic block");
      }
    }
  }

  // All instructions to move should follow I. Start from I, not from begin().
  for (auto BBI = I->getIterator(), E = I->getParent()->end(); BBI != E;
       ++BBI) {
    if (!is_contained(InstructionsToMove, &*BBI))
      continue;
    Instruction *IM = &*BBI;
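    // Step back before unlinking IM so the iterator stays valid; the ++BBI
    // in the loop header then lands on the instruction after the moved one.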
    --BBI;
    IM->removeFromParent();
    IM->insertBefore(I);
  }
}

std::pair<BasicBlock::iterator, BasicBlock::iterator>
Vectorizer::getBoundaryInstrs(ArrayRef<Value *> Chain) {
  Instruction *C0 = cast<Instruction>(Chain[0]);
  BasicBlock::iterator FirstInstr = C0->getIterator();
  BasicBlock::iterator LastInstr = C0->getIterator();

  BasicBlock *BB = C0->getParent();
  unsigned NumFound = 0;
  for (Instruction &I : *BB) {
    if (!is_contained(Chain, &I))
      continue;

    ++NumFound;
    if (NumFound == 1) {
      FirstInstr = I.getIterator();
    }
    if (NumFound == Chain.size()) {
      LastInstr = I.getIterator();
      break;
    }
  }

  // Range is [first, last).
  return std::make_pair(FirstInstr, ++LastInstr);
}

void Vectorizer::eraseInstructions(ArrayRef<Value *> Chain) {
  SmallVector<Instruction *, 16> Instrs;
  for (Value *V : Chain) {
    Value *PtrOperand = getPointerOperand(V);
    assert(PtrOperand && "Instruction must have a pointer operand.");
    Instrs.push_back(cast<Instruction>(V));
    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(PtrOperand))
      Instrs.push_back(GEP);
  }

  // Erase instructions.
  for (Value *V : Instrs) {
    Instruction *Instr = cast<Instruction>(V);
    if (Instr->use_empty())
      Instr->eraseFromParent();
  }
}

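// Worked example for splitOddVectorElts: a chain of five 16-bit elements is
// 10 bytes; 10 % 4 == 2, so NumRight = 2 / 2 = 1 and NumLeft = 4. The front
// piece holds 4 elements (8 bytes, a multiple of 4B) and the tail holds the
// odd element (2 bytes).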
std::pair<ArrayRef<Value *>, ArrayRef<Value *>>
Vectorizer::splitOddVectorElts(ArrayRef<Value *> Chain,
                               unsigned ElementSizeBits) {
  unsigned ElemSizeInBytes = ElementSizeBits / 8;
  unsigned SizeInBytes = ElemSizeInBytes * Chain.size();
  unsigned NumRight = (SizeInBytes % 4) / ElemSizeInBytes;
  unsigned NumLeft = Chain.size() - NumRight;
  return std::make_pair(Chain.slice(0, NumLeft), Chain.slice(NumLeft));
}

ArrayRef<Value *> Vectorizer::getVectorizablePrefix(ArrayRef<Value *> Chain) {
  // These are in BB order, unlike Chain, which is in address order.
  SmallVector<std::pair<Value *, unsigned>, 16> MemoryInstrs;
  SmallVector<std::pair<Value *, unsigned>, 16> ChainInstrs;

  bool IsLoadChain = isa<LoadInst>(Chain[0]);
  DEBUG({
    for (Value *V : Chain) {
      if (IsLoadChain)
        assert(isa<LoadInst>(V) &&
               "All elements of Chain must be loads, or all must be stores.");
      else
        assert(isa<StoreInst>(V) &&
               "All elements of Chain must be loads, or all must be stores.");
    }
  });

  unsigned InstrIdx = 0;
  for (Instruction &I : make_range(getBoundaryInstrs(Chain))) {
    ++InstrIdx;
    if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
      if (!is_contained(Chain, &I))
        MemoryInstrs.push_back({&I, InstrIdx});
      else
        ChainInstrs.push_back({&I, InstrIdx});
    } else if (IsLoadChain && (I.mayWriteToMemory() || I.mayThrow())) {
      DEBUG(dbgs() << "LSV: Found may-write/throw operation: " << I << '\n');
      break;
    } else if (!IsLoadChain && (I.mayReadOrWriteMemory() || I.mayThrow())) {
      DEBUG(dbgs() << "LSV: Found may-read/write/throw operation: " << I
                   << '\n');
      break;
    }
  }

  // Loop until we find an instruction in ChainInstrs that we can't vectorize.
  unsigned ChainInstrIdx, ChainInstrsLen;
  for (ChainInstrIdx = 0, ChainInstrsLen = ChainInstrs.size();
       ChainInstrIdx < ChainInstrsLen; ++ChainInstrIdx) {
    Value *ChainInstr = ChainInstrs[ChainInstrIdx].first;
    unsigned ChainInstrLoc = ChainInstrs[ChainInstrIdx].second;
    bool AliasFound = false;
    for (auto EntryMem : MemoryInstrs) {
      Value *MemInstr = EntryMem.first;
      unsigned MemInstrLoc = EntryMem.second;
      if (isa<LoadInst>(MemInstr) && isa<LoadInst>(ChainInstr))
        continue;

      // We can ignore the alias as long as the load comes before the store,
      // because that means we won't be moving the load past the store to
      // vectorize it (the vectorized load is inserted at the location of the
      // first load in the chain).
      if (isa<StoreInst>(MemInstr) && isa<LoadInst>(ChainInstr) &&
          ChainInstrLoc < MemInstrLoc)
        continue;

      // Same case, but in reverse.
      if (isa<LoadInst>(MemInstr) && isa<StoreInst>(ChainInstr) &&
          ChainInstrLoc > MemInstrLoc)
        continue;

      Instruction *M0 = cast<Instruction>(MemInstr);
      Instruction *M1 = cast<Instruction>(ChainInstr);

      if (!AA.isNoAlias(MemoryLocation::get(M0), MemoryLocation::get(M1))) {
        DEBUG({
          Value *Ptr0 = getPointerOperand(M0);
          Value *Ptr1 = getPointerOperand(M1);
          dbgs() << "LSV: Found alias:\n"
                    "  Aliasing instruction and pointer:\n"
                 << "  " << *MemInstr << '\n'
                 << "  " << *Ptr0 << '\n'
                 << "  Aliased instruction and pointer:\n"
                 << "  " << *ChainInstr << '\n'
                 << "  " << *Ptr1 << '\n';
        });
        AliasFound = true;
        break;
      }
    }
    if (AliasFound)
      break;
  }

  // Find the largest prefix of Chain whose elements are all in
  // ChainInstrs[0, ChainInstrIdx).  This is the largest vectorizable prefix of
  // Chain.  (Recall that Chain is in address order, but ChainInstrs is in BB
  // order.)
  auto VectorizableChainInstrs =
      makeArrayRef(ChainInstrs.data(), ChainInstrIdx);
  unsigned ChainIdx, ChainLen;
  for (ChainIdx = 0, ChainLen = Chain.size(); ChainIdx < ChainLen; ++ChainIdx) {
    Value *V = Chain[ChainIdx];
    if (!any_of(VectorizableChainInstrs,
                [V](std::pair<Value *, unsigned> CI) { return V == CI.first; }))
      break;
  }
  return Chain.slice(0, ChainIdx);
}

std::pair<ValueListMap, ValueListMap>
Vectorizer::collectInstructions(BasicBlock *BB) {
  ValueListMap LoadRefs;
  ValueListMap StoreRefs;

  for (Instruction &I : *BB) {
    if (!I.mayReadOrWriteMemory())
      continue;

    if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
      if (!LI->isSimple())
        continue;

      Type *Ty = LI->getType();
      if (!VectorType::isValidElementType(Ty->getScalarType()))
        continue;

      // Skip weird non-byte sizes. They probably aren't worth the effort of
      // handling correctly.
      unsigned TySize = DL.getTypeSizeInBits(Ty);
      if (TySize < 8)
        continue;

      Value *Ptr = LI->getPointerOperand();
      unsigned AS = Ptr->getType()->getPointerAddressSpace();
      unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);

      // No point in looking at these if they're too big to vectorize.
      if (TySize > VecRegSize / 2)
        continue;

      // Make sure all the users of a vector are constant-index extracts.
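      // (Only then can every user be rewired to a constant-index extract
      // from the wider vector produced by vectorizeLoadChain.)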
      if (isa<VectorType>(Ty) && !all_of(LI->users(), [LI](const User *U) {
            const Instruction *UI = cast<Instruction>(U);
            return isa<ExtractElementInst>(UI) &&
                   isa<ConstantInt>(UI->getOperand(1));
          }))
        continue;

      // TODO: Target hook to filter types.

      // Save the load locations.
      Value *ObjPtr = GetUnderlyingObject(Ptr, DL);
      LoadRefs[ObjPtr].push_back(LI);

    } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
      if (!SI->isSimple())
        continue;

      Type *Ty = SI->getValueOperand()->getType();
      if (!VectorType::isValidElementType(Ty->getScalarType()))
        continue;

      // Skip weird non-byte sizes. They probably aren't worth the effort of
      // handling correctly.
      unsigned TySize = DL.getTypeSizeInBits(Ty);
      if (TySize < 8)
        continue;

      Value *Ptr = SI->getPointerOperand();
      unsigned AS = Ptr->getType()->getPointerAddressSpace();
      unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
      if (TySize > VecRegSize / 2)
        continue;

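      // FIXME: A StoreInst has no users, so this all_of is vacuously true;
      // it presumably means to inspect the users of the stored value.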
      if (isa<VectorType>(Ty) && !all_of(SI->users(), [SI](const User *U) {
            const Instruction *UI = cast<Instruction>(U);
            return isa<ExtractElementInst>(UI) &&
                   isa<ConstantInt>(UI->getOperand(1));
          }))
        continue;

      // Save store location.
      Value *ObjPtr = GetUnderlyingObject(Ptr, DL);
      StoreRefs[ObjPtr].push_back(SI);
    }
  }

  return {LoadRefs, StoreRefs};
}

bool Vectorizer::vectorizeChains(ValueListMap &Map) {
  bool Changed = false;

  for (const std::pair<Value *, ValueList> &Chain : Map) {
    unsigned Size = Chain.second.size();
    if (Size < 2)
      continue;

    DEBUG(dbgs() << "LSV: Analyzing a chain of length " << Size << ".\n");

    // Process the chain elements in chunks of 64.
    for (unsigned CI = 0, CE = Size; CI < CE; CI += 64) {
      unsigned Len = std::min<unsigned>(CE - CI, 64);
      ArrayRef<Value *> Chunk(&Chain.second[CI], Len);
      Changed |= vectorizeInstructions(Chunk);
    }
  }

  return Changed;
}

bool Vectorizer::vectorizeInstructions(ArrayRef<Value *> Instrs) {
  DEBUG(dbgs() << "LSV: Vectorizing " << Instrs.size() << " instructions.\n");
  SmallSetVector<int, 16> Heads, Tails;
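  // ConsecutiveChain[i] holds the index of the instruction that immediately
  // follows Instrs[i] in memory, or -1 if none was found. Instrs has at most
  // 64 elements (the chunk size used by vectorizeChains).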
  int ConsecutiveChain[64];

  // Do a quadratic search over the given instructions and find all of the
  // pairs that access consecutive memory.
  for (int i = 0, e = Instrs.size(); i < e; ++i) {
    ConsecutiveChain[i] = -1;
    for (int j = e - 1; j >= 0; --j) {
      if (i == j)
        continue;

      if (isConsecutiveAccess(Instrs[i], Instrs[j])) {
        if (ConsecutiveChain[i] != -1) {
          // Keep the existing successor unless j is a closer one.
          int CurDistance = std::abs(ConsecutiveChain[i] - i);
          int NewDistance = std::abs(i - j);
          if (j < i || NewDistance > CurDistance)
            continue; // Should not insert.
        }

        Tails.insert(j);
        Heads.insert(i);
        ConsecutiveChain[i] = j;
      }
    }
  }

  bool Changed = false;
  SmallPtrSet<Value *, 16> InstructionsProcessed;

  for (int Head : Heads) {
    if (InstructionsProcessed.count(Instrs[Head]))
      continue;
    bool LongerChainExists = false;
    for (unsigned TIt = 0; TIt < Tails.size(); TIt++)
      if (Head == Tails[TIt] &&
          !InstructionsProcessed.count(Instrs[Heads[TIt]])) {
        LongerChainExists = true;
        break;
      }
    if (LongerChainExists)
      continue;

    // We found an instr that starts a chain. Now follow the chain and try to
    // vectorize it.
    SmallVector<Value *, 16> Operands;
    int I = Head;
    while (I != -1 && (Tails.count(I) || Heads.count(I))) {
      if (InstructionsProcessed.count(Instrs[I]))
        break;

      Operands.push_back(Instrs[I]);
      I = ConsecutiveChain[I];
    }

    bool Vectorized = false;
    if (isa<LoadInst>(*Operands.begin()))
      Vectorized = vectorizeLoadChain(Operands, &InstructionsProcessed);
    else
      Vectorized = vectorizeStoreChain(Operands, &InstructionsProcessed);

    Changed |= Vectorized;
  }

  return Changed;
}

bool Vectorizer::vectorizeStoreChain(
    ArrayRef<Value *> Chain, SmallPtrSet<Value *, 16> *InstructionsProcessed) {
  StoreInst *S0 = cast<StoreInst>(Chain[0]);

  // If the vector has an int element, default to int for the whole store.
  Type *StoreTy;
  for (const auto &V : Chain) {
    StoreTy = cast<StoreInst>(V)->getValueOperand()->getType();
    if (StoreTy->isIntOrIntVectorTy())
      break;

    if (StoreTy->isPtrOrPtrVectorTy()) {
      StoreTy = Type::getIntNTy(F.getParent()->getContext(),
                                DL.getTypeSizeInBits(StoreTy));
      break;
    }
  }

  unsigned Sz = DL.getTypeSizeInBits(StoreTy);
  unsigned AS = S0->getPointerAddressSpace();
  unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
  unsigned VF = VecRegSize / Sz;
  unsigned ChainSize = Chain.size();

  if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) {
    InstructionsProcessed->insert(Chain.begin(), Chain.end());
    return false;
  }

  ArrayRef<Value *> NewChain = getVectorizablePrefix(Chain);
  if (NewChain.empty()) {
    // No vectorization possible.
    InstructionsProcessed->insert(Chain.begin(), Chain.end());
    return false;
  }
  if (NewChain.size() == 1) {
    // Failed after the first instruction. Discard it and try the smaller chain.
    InstructionsProcessed->insert(NewChain.front());
    return false;
  }

  // Update Chain to the valid vectorizable subchain.
  Chain = NewChain;
  ChainSize = Chain.size();

  // Store size should be 1B, 2B or multiple of 4B.
  // TODO: Target hook for size constraint?
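  // For example, a chain of three 8-bit stores (3 bytes) simply drops its
  // last element, while a chain of five 16-bit stores (10 bytes) is split by
  // splitOddVectorElts into an 8-byte piece and a 2-byte piece.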
  unsigned SzInBytes = (Sz / 8) * ChainSize;
  if (SzInBytes > 2 && SzInBytes % 4 != 0) {
    DEBUG(dbgs() << "LSV: Size should be 1B, 2B "
                    "or multiple of 4B. Splitting.\n");
    if (SzInBytes == 3)
      return vectorizeStoreChain(Chain.slice(0, ChainSize - 1),
                                 InstructionsProcessed);

    auto Chains = splitOddVectorElts(Chain, Sz);
    return vectorizeStoreChain(Chains.first, InstructionsProcessed) |
           vectorizeStoreChain(Chains.second, InstructionsProcessed);
  }

  VectorType *VecTy;
  VectorType *VecStoreTy = dyn_cast<VectorType>(StoreTy);
  if (VecStoreTy)
    VecTy = VectorType::get(StoreTy->getScalarType(),
                            Chain.size() * VecStoreTy->getNumElements());
  else
    VecTy = VectorType::get(StoreTy, Chain.size());

  // If it's more than the max vector size, break it into two pieces.
  // TODO: Target hook to control types to split to.
  if (ChainSize > VF) {
    DEBUG(dbgs() << "LSV: Vector factor is too big."
                    " Creating two separate arrays.\n");
    return vectorizeStoreChain(Chain.slice(0, VF), InstructionsProcessed) |
           vectorizeStoreChain(Chain.slice(VF), InstructionsProcessed);
  }

  DEBUG({
    dbgs() << "LSV: Stores to vectorize:\n";
    for (Value *V : Chain)
      dbgs() << "  " << *V << "\n";
  });

  // We won't try again to vectorize the elements of the chain, regardless of
  // whether we succeed below.
  InstructionsProcessed->insert(Chain.begin(), Chain.end());

  // Check alignment restrictions.
  unsigned Alignment = getAlignment(S0);

  // If the store is going to be misaligned, don't vectorize it.
  if (accessIsMisaligned(SzInBytes, AS, Alignment)) {
    if (S0->getPointerAddressSpace() != 0)
      return false;

    // If we're storing to an object on the stack, we control its alignment,
    // so we can cheat and change it!
    Value *V = GetUnderlyingObject(S0->getPointerOperand(), DL);
    if (AllocaInst *AI = dyn_cast_or_null<AllocaInst>(V)) {
      AI->setAlignment(TargetBaseAlign);
      Alignment = TargetBaseAlign;
    } else {
      return false;
    }
  }

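  // The vectorized store is emitted just before the instruction following the
  // last store in the chain (getBoundaryInstrs returns a half-open range), so
  // every scalar value being stored is already available there.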
  BasicBlock::iterator First, Last;
  std::tie(First, Last) = getBoundaryInstrs(Chain);
  Builder.SetInsertPoint(&*Last);

  Value *Vec = UndefValue::get(VecTy);

  if (VecStoreTy) {
    unsigned VecWidth = VecStoreTy->getNumElements();
    for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
      StoreInst *Store = cast<StoreInst>(Chain[I]);
      for (unsigned J = 0, NE = VecStoreTy->getNumElements(); J != NE; ++J) {
        unsigned NewIdx = J + I * VecWidth;
        Value *Extract = Builder.CreateExtractElement(Store->getValueOperand(),
                                                      Builder.getInt32(J));
        if (Extract->getType() != StoreTy->getScalarType())
          Extract = Builder.CreateBitCast(Extract, StoreTy->getScalarType());

        Value *Insert =
            Builder.CreateInsertElement(Vec, Extract, Builder.getInt32(NewIdx));
        Vec = Insert;
      }
    }
  } else {
    for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
      StoreInst *Store = cast<StoreInst>(Chain[I]);
      Value *Extract = Store->getValueOperand();
      if (Extract->getType() != StoreTy->getScalarType())
        Extract =
            Builder.CreateBitOrPointerCast(Extract, StoreTy->getScalarType());

      Value *Insert =
          Builder.CreateInsertElement(Vec, Extract, Builder.getInt32(I));
      Vec = Insert;
    }
  }

  Value *Bitcast =
      Builder.CreateBitCast(S0->getPointerOperand(), VecTy->getPointerTo(AS));
  StoreInst *SI = Builder.CreateStore(Vec, Bitcast);
  propagateMetadata(SI, Chain);
  SI->setAlignment(Alignment);

  eraseInstructions(Chain);
  ++NumVectorInstructions;
  NumScalarsVectorized += Chain.size();
  return true;
}

bool Vectorizer::vectorizeLoadChain(
    ArrayRef<Value *> Chain, SmallPtrSet<Value *, 16> *InstructionsProcessed) {
  LoadInst *L0 = cast<LoadInst>(Chain[0]);

  // If the vector has an int element, default to int for the whole load.
  Type *LoadTy;
  for (const auto &V : Chain) {
    LoadTy = cast<LoadInst>(V)->getType();
    if (LoadTy->isIntOrIntVectorTy())
      break;

    if (LoadTy->isPtrOrPtrVectorTy()) {
      LoadTy = Type::getIntNTy(F.getParent()->getContext(),
                               DL.getTypeSizeInBits(LoadTy));
      break;
    }
  }

  unsigned Sz = DL.getTypeSizeInBits(LoadTy);
  unsigned AS = L0->getPointerAddressSpace();
  unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
  unsigned VF = VecRegSize / Sz;
  unsigned ChainSize = Chain.size();

  if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) {
    InstructionsProcessed->insert(Chain.begin(), Chain.end());
    return false;
  }

  ArrayRef<Value *> NewChain = getVectorizablePrefix(Chain);
  if (NewChain.empty()) {
    // No vectorization possible.
    InstructionsProcessed->insert(Chain.begin(), Chain.end());
    return false;
  }
  if (NewChain.size() == 1) {
    // Failed after the first instruction. Discard it and try the smaller chain.
    InstructionsProcessed->insert(NewChain.front());
    return false;
  }

  // Update Chain to the valid vectorizable subchain.
  Chain = NewChain;
  ChainSize = Chain.size();

  // Load size should be 1B, 2B or multiple of 4B.
  // TODO: Should size constraint be a target hook?
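  // (Same constraint and splitting scheme as in vectorizeStoreChain above.)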
  unsigned SzInBytes = (Sz / 8) * ChainSize;
  if (SzInBytes > 2 && SzInBytes % 4 != 0) {
    DEBUG(dbgs() << "LSV: Size should be 1B, 2B "
                    "or multiple of 4B. Splitting.\n");
    if (SzInBytes == 3)
      return vectorizeLoadChain(Chain.slice(0, ChainSize - 1),
                                InstructionsProcessed);
    auto Chains = splitOddVectorElts(Chain, Sz);
    return vectorizeLoadChain(Chains.first, InstructionsProcessed) |
           vectorizeLoadChain(Chains.second, InstructionsProcessed);
  }

  VectorType *VecTy;
  VectorType *VecLoadTy = dyn_cast<VectorType>(LoadTy);
  if (VecLoadTy)
    VecTy = VectorType::get(LoadTy->getScalarType(),
                            Chain.size() * VecLoadTy->getNumElements());
  else
    VecTy = VectorType::get(LoadTy, Chain.size());

  // If it's more than the max vector size, break it into two pieces.
  // TODO: Target hook to control types to split to.
  if (ChainSize > VF) {
    DEBUG(dbgs() << "LSV: Vector factor is too big. "
                    "Creating two separate arrays.\n");
    return vectorizeLoadChain(Chain.slice(0, VF), InstructionsProcessed) |
           vectorizeLoadChain(Chain.slice(VF), InstructionsProcessed);
  }

  // We won't try again to vectorize the elements of the chain, regardless of
  // whether we succeed below.
  InstructionsProcessed->insert(Chain.begin(), Chain.end());

  // Check alignment restrictions.
  unsigned Alignment = getAlignment(L0);

  // If the load is going to be misaligned, don't vectorize it.
  if (accessIsMisaligned(SzInBytes, AS, Alignment)) {
    if (L0->getPointerAddressSpace() != 0)
      return false;

    // If we're loading from an object on the stack, we control its alignment,
    // so we can cheat and change it!
    Value *V = GetUnderlyingObject(L0->getPointerOperand(), DL);
    if (AllocaInst *AI = dyn_cast_or_null<AllocaInst>(V)) {
      AI->setAlignment(TargetBaseAlign);
      Alignment = TargetBaseAlign;
    } else {
      return false;
    }
  }

  DEBUG({
    dbgs() << "LSV: Loads to vectorize:\n";
    for (Value *V : Chain)
      dbgs() << "  " << *V << "\n";
  });

  // getVectorizablePrefix already computed getBoundaryInstrs.  The value of
  // Last may have changed since then, but the value of First won't have.  If it
  // matters, we could compute getBoundaryInstrs only once and reuse it here.
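  //
  // The vector load is emitted at the first load's position; reorder() is
  // called below on the new address bitcast to hoist any operand definitions
  // that do not dominate this earlier insertion point.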
  BasicBlock::iterator First, Last;
  std::tie(First, Last) = getBoundaryInstrs(Chain);
  Builder.SetInsertPoint(&*First);

  Value *Bitcast =
      Builder.CreateBitCast(L0->getPointerOperand(), VecTy->getPointerTo(AS));

  LoadInst *LI = Builder.CreateLoad(Bitcast);
  propagateMetadata(LI, Chain);
  LI->setAlignment(Alignment);

  if (VecLoadTy) {
    SmallVector<Instruction *, 16> InstrsToErase;
    SmallVector<Instruction *, 16> InstrsToReorder;
    InstrsToReorder.push_back(cast<Instruction>(Bitcast));

    unsigned VecWidth = VecLoadTy->getNumElements();
    for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
      for (User *U : Chain[I]->users()) {
        Instruction *UI = cast<Instruction>(U);
        unsigned Idx = cast<ConstantInt>(UI->getOperand(1))->getZExtValue();
        unsigned NewIdx = Idx + I * VecWidth;
        Value *V = Builder.CreateExtractElement(LI, Builder.getInt32(NewIdx));
        Instruction *Extracted = cast<Instruction>(V);
        if (Extracted->getType() != UI->getType())
          Extracted = cast<Instruction>(
              Builder.CreateBitCast(Extracted, UI->getType()));

        // Replace the old instruction.
        UI->replaceAllUsesWith(Extracted);
        InstrsToErase.push_back(UI);
      }
    }

    for (Instruction *ModUser : InstrsToReorder)
      reorder(ModUser);

    for (auto I : InstrsToErase)
      I->eraseFromParent();
  } else {
    SmallVector<Instruction *, 16> InstrsToReorder;
    InstrsToReorder.push_back(cast<Instruction>(Bitcast));

    for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
      Value *V = Builder.CreateExtractElement(LI, Builder.getInt32(I));
      Instruction *Extracted = cast<Instruction>(V);
      Instruction *UI = cast<Instruction>(Chain[I]);
      if (Extracted->getType() != UI->getType()) {
        Extracted = cast<Instruction>(
            Builder.CreateBitOrPointerCast(Extracted, UI->getType()));
      }

      // Replace the old instruction.
      UI->replaceAllUsesWith(Extracted);
    }

    for (Instruction *ModUser : InstrsToReorder)
      reorder(ModUser);
  }

  eraseInstructions(Chain);

  ++NumVectorInstructions;
  NumScalarsVectorized += Chain.size();
  return true;
}

bool Vectorizer::accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace,
                                    unsigned Alignment) {
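  // An access counts as misaligned only if the target does not report it as
  // allowed-and-fast AND the alignment is a multiple of neither the access
  // size nor TargetBaseAlign.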
  bool Fast = false;
  bool Allows = TTI.allowsMisalignedMemoryAccesses(SzInBytes * 8, AddressSpace,
                                                   Alignment, &Fast);
  // TODO: Remove TargetBaseAlign
  return !(Allows && Fast) && (Alignment % SzInBytes) != 0 &&
         (Alignment % TargetBaseAlign) != 0;
}