1 //===- Scalarizer.cpp - Scalarize vector operations -----------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass converts vector operations into scalar operations, in order
11 // to expose optimization opportunities on the individual scalar operations.
12 // It is mainly intended for targets that do not have vector units, but it
13 // may also be useful for revectorizing code to different vector widths.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/ADT/PostOrderIterator.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/Twine.h"
20 #include "llvm/Analysis/VectorUtils.h"
21 #include "llvm/IR/Argument.h"
22 #include "llvm/IR/BasicBlock.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/DataLayout.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/IRBuilder.h"
28 #include "llvm/IR/InstVisitor.h"
29 #include "llvm/IR/InstrTypes.h"
30 #include "llvm/IR/Instruction.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/Intrinsics.h"
33 #include "llvm/IR/LLVMContext.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/IR/Type.h"
36 #include "llvm/IR/Value.h"
37 #include "llvm/Pass.h"
38 #include "llvm/Support/Casting.h"
39 #include "llvm/Support/MathExtras.h"
40 #include "llvm/Support/Options.h"
41 #include "llvm/Transforms/Scalar.h"
42 #include <cassert>
43 #include <cstdint>
44 #include <iterator>
45 #include <map>
46 #include <utility>
47 
48 using namespace llvm;
49 
50 #define DEBUG_TYPE "scalarizer"
51 
52 namespace {
53 
// Used to store the scattered form of a vector: one scalar Value per element.
using ValueVector = SmallVector<Value *, 8>;

// Used to map a vector Value to its scattered form.  We use std::map
// because we want iterators to persist across insertion and because the
// values are relatively large.
using ScatterMap = std::map<Value *, ValueVector>;

// Lists Instructions that have been replaced with scalar implementations,
// along with a pointer to their scattered forms.
using GatherList = SmallVector<std::pair<Instruction *, ValueVector *>, 16>;
65 
// Provides a very limited vector-like interface for lazily accessing one
// component of a scattered vector or vector pointer.
class Scatterer {
public:
  Scatterer() = default;

  // Scatter V into Size components.  If new instructions are needed,
  // insert them before BBI in BB.  If Cache is nonnull, use it to cache
  // the results.
  Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
            ValueVector *cachePtr = nullptr);

  // Return component I, creating a new Value for it if necessary.
  Value *operator[](unsigned I);

  // Return the number of components.
  unsigned size() const { return Size; }

private:
  BasicBlock *BB;           // Block in which any new instructions are created.
  BasicBlock::iterator BBI; // Insertion point within BB.
  Value *V;                 // The vector (or pointer to vector) being split.
  ValueVector *CachePtr;    // External cache of components, may be null.
  PointerType *PtrTy;       // Nonnull iff V is a pointer to a vector.
  ValueVector Tmp;          // Local component cache, used when CachePtr is null.
  unsigned Size;            // Number of vector elements.
};
93 
94 // FCmpSpliiter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp
95 // called Name that compares X and Y in the same way as FCI.
96 struct FCmpSplitter {
97   FCmpSplitter(FCmpInst &fci) : FCI(fci) {}
98 
99   Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
100                     const Twine &Name) const {
101     return Builder.CreateFCmp(FCI.getPredicate(), Op0, Op1, Name);
102   }
103 
104   FCmpInst &FCI;
105 };
106 
107 // ICmpSpliiter(ICI)(Builder, X, Y, Name) uses Builder to create an ICmp
108 // called Name that compares X and Y in the same way as ICI.
109 struct ICmpSplitter {
110   ICmpSplitter(ICmpInst &ici) : ICI(ici) {}
111 
112   Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
113                     const Twine &Name) const {
114     return Builder.CreateICmp(ICI.getPredicate(), Op0, Op1, Name);
115   }
116 
117   ICmpInst &ICI;
118 };
119 
120 // BinarySpliiter(BO)(Builder, X, Y, Name) uses Builder to create
121 // a binary operator like BO called Name with operands X and Y.
122 struct BinarySplitter {
123   BinarySplitter(BinaryOperator &bo) : BO(bo) {}
124 
125   Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
126                     const Twine &Name) const {
127     return Builder.CreateBinOp(BO.getOpcode(), Op0, Op1, Name);
128   }
129 
130   BinaryOperator &BO;
131 };
132 
// Information about a load or store that we're scalarizing.
struct VectorLayout {
  VectorLayout() = default;

  // Return the alignment of element I, derived from the vector's own
  // alignment and the element's byte offset (I * ElemSize).
  uint64_t getElemAlign(unsigned I) {
    return MinAlign(VecAlign, I * ElemSize);
  }

  // The type of the vector.
  VectorType *VecTy = nullptr;

  // The type of each element.
  Type *ElemTy = nullptr;

  // The alignment of the vector.
  uint64_t VecAlign = 0;

  // The size of each element in bytes.
  uint64_t ElemSize = 0;
};
154 
155 class Scalarizer : public FunctionPass,
156                    public InstVisitor<Scalarizer, bool> {
157 public:
158   static char ID;
159 
160   Scalarizer() : FunctionPass(ID) {
161     initializeScalarizerPass(*PassRegistry::getPassRegistry());
162   }
163 
164   bool doInitialization(Module &M) override;
165   bool runOnFunction(Function &F) override;
166 
167   // InstVisitor methods.  They return true if the instruction was scalarized,
168   // false if nothing changed.
169   bool visitInstruction(Instruction &I) { return false; }
170   bool visitSelectInst(SelectInst &SI);
171   bool visitICmpInst(ICmpInst &ICI);
172   bool visitFCmpInst(FCmpInst &FCI);
173   bool visitBinaryOperator(BinaryOperator &BO);
174   bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
175   bool visitCastInst(CastInst &CI);
176   bool visitBitCastInst(BitCastInst &BCI);
177   bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
178   bool visitPHINode(PHINode &PHI);
179   bool visitLoadInst(LoadInst &LI);
180   bool visitStoreInst(StoreInst &SI);
181   bool visitCallInst(CallInst &ICI);
182 
183   static void registerOptions() {
184     // This is disabled by default because having separate loads and stores
185     // makes it more likely that the -combiner-alias-analysis limits will be
186     // reached.
187     OptionRegistry::registerOption<bool, Scalarizer,
188                                  &Scalarizer::ScalarizeLoadStore>(
189         "scalarize-load-store",
190         "Allow the scalarizer pass to scalarize loads and store", false);
191   }
192 
193 private:
194   Scatterer scatter(Instruction *Point, Value *V);
195   void gather(Instruction *Op, const ValueVector &CV);
196   bool canTransferMetadata(unsigned Kind);
197   void transferMetadata(Instruction *Op, const ValueVector &CV);
198   bool getVectorLayout(Type *Ty, unsigned Alignment, VectorLayout &Layout,
199                        const DataLayout &DL);
200   bool finish();
201 
202   template<typename T> bool splitBinary(Instruction &, const T &);
203 
204   bool splitCall(CallInst &CI);
205 
206   ScatterMap Scattered;
207   GatherList Gathered;
208   unsigned ParallelLoopAccessMDKind;
209   bool ScalarizeLoadStore;
210 };
211 
212 } // end anonymous namespace
213 
// Pass identification; the address of ID is what uniquely identifies the pass.
char Scalarizer::ID = 0;

INITIALIZE_PASS_WITH_OPTIONS(Scalarizer, "scalarizer",
                             "Scalarize vector operations", false, false)
218 
// Set up a Scatterer for value v, which must be a vector or a pointer to a
// vector.  New instructions are inserted before bbi in bb; components are
// cached in *cachePtr if it is nonnull, otherwise in the local Tmp vector.
Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
                     ValueVector *cachePtr)
  : BB(bb), BBI(bbi), V(v), CachePtr(cachePtr) {
  Type *Ty = V->getType();
  // For a pointer-to-vector, the number of components comes from the
  // pointee vector type.
  PtrTy = dyn_cast<PointerType>(Ty);
  if (PtrTy)
    Ty = PtrTy->getElementType();
  Size = Ty->getVectorNumElements();
  if (!CachePtr)
    Tmp.resize(Size, nullptr);
  else if (CachePtr->empty())
    CachePtr->resize(Size, nullptr);
  else
    assert(Size == CachePtr->size() && "Inconsistent vector sizes");
}
234 
// Return component I, creating a new Value for it if necessary.
Value *Scatterer::operator[](unsigned I) {
  ValueVector &CV = (CachePtr ? *CachePtr : Tmp);
  // Try to reuse a previous value.
  if (CV[I])
    return CV[I];
  IRBuilder<> Builder(BB, BBI);
  if (PtrTy) {
    // Pointer-to-vector case: component I is a pointer to element I.
    // CV[0] holds V bitcast to a pointer to the element type; other
    // components are GEPs off CV[0].
    if (!CV[0]) {
      Type *Ty =
        PointerType::get(PtrTy->getElementType()->getVectorElementType(),
                         PtrTy->getAddressSpace());
      CV[0] = Builder.CreateBitCast(V, Ty, V->getName() + ".i0");
    }
    if (I != 0)
      CV[I] = Builder.CreateConstGEP1_32(nullptr, CV[0], I,
                                         V->getName() + ".i" + Twine(I));
  } else {
    // Search through a chain of InsertElementInsts looking for element I.
    // Record other elements in the cache.  The new V is still suitable
    // for all uncached indices.
    while (true) {
      InsertElementInst *Insert = dyn_cast<InsertElementInst>(V);
      if (!Insert)
        break;
      ConstantInt *Idx = dyn_cast<ConstantInt>(Insert->getOperand(2));
      if (!Idx)
        break;
      unsigned J = Idx->getZExtValue();
      V = Insert->getOperand(0);
      if (I == J) {
        // Found the element we wanted directly in the insert chain.
        CV[J] = Insert->getOperand(1);
        return CV[J];
      } else if (!CV[J]) {
        // Only cache the first entry we find for each index we're not actively
        // searching for. This prevents us from going too far up the chain and
        // caching incorrect entries.
        CV[J] = Insert->getOperand(1);
      }
    }
    // Fall back to an explicit extractelement.
    CV[I] = Builder.CreateExtractElement(V, Builder.getInt32(I),
                                         V->getName() + ".i" + Twine(I));
  }
  return CV[I];
}
280 
// Cache per-module state: the metadata kind ID used by canTransferMetadata()
// and the registered value of the scalarize-load-store option.  Never
// modifies the module, so always returns false.
bool Scalarizer::doInitialization(Module &M) {
  ParallelLoopAccessMDKind =
      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
  ScalarizeLoadStore =
      M.getContext().getOption<bool, Scalarizer, &Scalarizer::ScalarizeLoadStore>();
  return false;
}
288 
// Visit every instruction in F, scalarizing where possible, then run
// finish() to erase scalarized instructions and rebuild any vector results
// that are still needed.
bool Scalarizer::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;
  assert(Gathered.empty() && Scattered.empty());

  // To ensure we replace gathered components correctly we need to do an ordered
  // traversal of the basic blocks in the function.
  ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
  for (BasicBlock *BB : RPOT) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      Instruction *I = &*II;
      bool Done = visit(I);
      // Advance before any possible erase so the iterator stays valid.
      ++II;
      // Scalarized void instructions (e.g. stores) produce no value and can
      // be deleted immediately; others are deleted later in finish().
      if (Done && I->getType()->isVoidTy())
        I->eraseFromParent();
    }
  }
  return finish();
}
308 
// Return a scattered form of V that can be accessed by Point.  V must be a
// vector or a pointer to a vector.
Scatterer Scalarizer::scatter(Instruction *Point, Value *V) {
  if (Argument *VArg = dyn_cast<Argument>(V)) {
    // Put the scattered form of arguments in the entry block,
    // so that it can be used everywhere.
    Function *F = VArg->getParent();
    BasicBlock *BB = &F->getEntryBlock();
    return Scatterer(BB, BB->begin(), V, &Scattered[V]);
  }
  if (Instruction *VOp = dyn_cast<Instruction>(V)) {
    // Put the scattered form of an instruction directly after the
    // instruction.
    BasicBlock *BB = VOp->getParent();
    return Scatterer(BB, std::next(BasicBlock::iterator(VOp)),
                     V, &Scattered[V]);
  }
  // In the fallback case (e.g. constants), just put the scattered form
  // before Point and keep the result local to Point (no shared cache entry).
  return Scatterer(Point->getParent(), Point->getIterator(), V);
}
330 
// Replace Op with the gathered form of the components in CV.  Defer the
// deletion of Op and creation of the gathered form to the end of the pass,
// so that we can avoid creating the gathered form if all uses of Op are
// replaced with uses of CV.
void Scalarizer::gather(Instruction *Op, const ValueVector &CV) {
  // Since we're not deleting Op yet, stub out its operands, so that it
  // doesn't make anything live unnecessarily.
  for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I)
    Op->setOperand(I, UndefValue::get(Op->getOperand(I)->getType()));

  transferMetadata(Op, CV);

  // If we already have a scattered form of Op (created from ExtractElements
  // of Op itself), replace them with the new form.
  ValueVector &SV = Scattered[Op];
  if (!SV.empty()) {
    for (unsigned I = 0, E = SV.size(); I != E; ++I) {
      Value *V = SV[I];
      // Component slots that were never materialized stay null.
      if (V == nullptr)
        continue;

      Instruction *Old = cast<Instruction>(V);
      CV[I]->takeName(Old);
      Old->replaceAllUsesWith(CV[I]);
      Old->eraseFromParent();
    }
  }
  // Record the new components; finish() consults Gathered for cleanup.
  SV = CV;
  Gathered.push_back(GatherList::value_type(Op, &SV));
}
361 
362 // Return true if it is safe to transfer the given metadata tag from
363 // vector to scalar instructions.
364 bool Scalarizer::canTransferMetadata(unsigned Tag) {
365   return (Tag == LLVMContext::MD_tbaa
366           || Tag == LLVMContext::MD_fpmath
367           || Tag == LLVMContext::MD_tbaa_struct
368           || Tag == LLVMContext::MD_invariant_load
369           || Tag == LLVMContext::MD_alias_scope
370           || Tag == LLVMContext::MD_noalias
371           || Tag == ParallelLoopAccessMDKind);
372 }
373 
374 // Transfer metadata from Op to the instructions in CV if it is known
375 // to be safe to do so.
376 void Scalarizer::transferMetadata(Instruction *Op, const ValueVector &CV) {
377   SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
378   Op->getAllMetadataOtherThanDebugLoc(MDs);
379   for (unsigned I = 0, E = CV.size(); I != E; ++I) {
380     if (Instruction *New = dyn_cast<Instruction>(CV[I])) {
381       for (const auto &MD : MDs)
382         if (canTransferMetadata(MD.first))
383           New->setMetadata(MD.first, MD.second);
384       if (Op->getDebugLoc() && !New->getDebugLoc())
385         New->setDebugLoc(Op->getDebugLoc());
386     }
387   }
388 }
389 
390 // Try to fill in Layout from Ty, returning true on success.  Alignment is
391 // the alignment of the vector, or 0 if the ABI default should be used.
392 bool Scalarizer::getVectorLayout(Type *Ty, unsigned Alignment,
393                                  VectorLayout &Layout, const DataLayout &DL) {
394   // Make sure we're dealing with a vector.
395   Layout.VecTy = dyn_cast<VectorType>(Ty);
396   if (!Layout.VecTy)
397     return false;
398 
399   // Check that we're dealing with full-byte elements.
400   Layout.ElemTy = Layout.VecTy->getElementType();
401   if (DL.getTypeSizeInBits(Layout.ElemTy) !=
402       DL.getTypeStoreSizeInBits(Layout.ElemTy))
403     return false;
404 
405   if (Alignment)
406     Layout.VecAlign = Alignment;
407   else
408     Layout.VecAlign = DL.getABITypeAlignment(Layout.VecTy);
409   Layout.ElemSize = DL.getTypeStoreSize(Layout.ElemTy);
410   return true;
411 }
412 
// Scalarize two-operand instruction I, using Split(Builder, X, Y, Name)
// to create an instruction like I with operands X and Y and name Name.
// Returns false (no change) when I does not produce a vector.
template<typename Splitter>
bool Scalarizer::splitBinary(Instruction &I, const Splitter &Split) {
  VectorType *VT = dyn_cast<VectorType>(I.getType());
  if (!VT)
    return false;

  unsigned NumElems = VT->getNumElements();
  IRBuilder<> Builder(&I);
  Scatterer Op0 = scatter(&I, I.getOperand(0));
  Scatterer Op1 = scatter(&I, I.getOperand(1));
  assert(Op0.size() == NumElems && "Mismatched binary operation");
  assert(Op1.size() == NumElems && "Mismatched binary operation");
  ValueVector Res;
  Res.resize(NumElems);
  for (unsigned Elem = 0; Elem < NumElems; ++Elem)
    Res[Elem] = Split(Builder, Op0[Elem], Op1[Elem],
                      I.getName() + ".i" + Twine(Elem));
  gather(&I, Res);
  return true;
}
435 
// Return true if the intrinsic can be scalarized by calling the scalar form
// once per element.  Currently identical to the triviallly-vectorizable set.
static bool isTriviallyScalariable(Intrinsic::ID ID) {
  return isTriviallyVectorizable(ID);
}
439 
// Return the declaration of the scalar form of intrinsic ID in module M,
// mangled on the element type of Ty.
// All of the current scalarizable intrinsics only have one mangled type.
static Function *getScalarIntrinsicDeclaration(Module *M,
                                               Intrinsic::ID ID,
                                               VectorType *Ty) {
  return Intrinsic::getDeclaration(M, ID, { Ty->getScalarType() });
}
446 
447 /// If a call to a vector typed intrinsic function, split into a scalar call per
448 /// element if possible for the intrinsic.
449 bool Scalarizer::splitCall(CallInst &CI) {
450   VectorType *VT = dyn_cast<VectorType>(CI.getType());
451   if (!VT)
452     return false;
453 
454   Function *F = CI.getCalledFunction();
455   if (!F)
456     return false;
457 
458   Intrinsic::ID ID = F->getIntrinsicID();
459   if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID))
460     return false;
461 
462   unsigned NumElems = VT->getNumElements();
463   unsigned NumArgs = CI.getNumArgOperands();
464 
465   ValueVector ScalarOperands(NumArgs);
466   SmallVector<Scatterer, 8> Scattered(NumArgs);
467 
468   Scattered.resize(NumArgs);
469 
470   // Assumes that any vector type has the same number of elements as the return
471   // vector type, which is true for all current intrinsics.
472   for (unsigned I = 0; I != NumArgs; ++I) {
473     Value *OpI = CI.getOperand(I);
474     if (OpI->getType()->isVectorTy()) {
475       Scattered[I] = scatter(&CI, OpI);
476       assert(Scattered[I].size() == NumElems && "mismatched call operands");
477     } else {
478       ScalarOperands[I] = OpI;
479     }
480   }
481 
482   ValueVector Res(NumElems);
483   ValueVector ScalarCallOps(NumArgs);
484 
485   Function *NewIntrin = getScalarIntrinsicDeclaration(F->getParent(), ID, VT);
486   IRBuilder<> Builder(&CI);
487 
488   // Perform actual scalarization, taking care to preserve any scalar operands.
489   for (unsigned Elem = 0; Elem < NumElems; ++Elem) {
490     ScalarCallOps.clear();
491 
492     for (unsigned J = 0; J != NumArgs; ++J) {
493       if (hasVectorInstrinsicScalarOpd(ID, J))
494         ScalarCallOps.push_back(ScalarOperands[J]);
495       else
496         ScalarCallOps.push_back(Scattered[J][Elem]);
497     }
498 
499     Res[Elem] = Builder.CreateCall(NewIntrin, ScalarCallOps,
500                                    CI.getName() + ".i" + Twine(Elem));
501   }
502 
503   gather(&CI, Res);
504   return true;
505 }
506 
// Scalarize a vector select.  The condition may be either a vector (one
// bit per element) or a single scalar shared by all elements.
bool Scalarizer::visitSelectInst(SelectInst &SI) {
  VectorType *VT = dyn_cast<VectorType>(SI.getType());
  if (!VT)
    return false;

  unsigned NumElems = VT->getNumElements();
  IRBuilder<> Builder(&SI);
  Scatterer Op1 = scatter(&SI, SI.getOperand(1));
  Scatterer Op2 = scatter(&SI, SI.getOperand(2));
  assert(Op1.size() == NumElems && "Mismatched select");
  assert(Op2.size() == NumElems && "Mismatched select");
  ValueVector Res;
  Res.resize(NumElems);

  if (SI.getOperand(0)->getType()->isVectorTy()) {
    // Vector condition: each element has its own condition bit.
    Scatterer Op0 = scatter(&SI, SI.getOperand(0));
    assert(Op0.size() == NumElems && "Mismatched select");
    for (unsigned I = 0; I < NumElems; ++I)
      Res[I] = Builder.CreateSelect(Op0[I], Op1[I], Op2[I],
                                    SI.getName() + ".i" + Twine(I));
  } else {
    // Scalar condition: reuse it for every element.
    Value *Op0 = SI.getOperand(0);
    for (unsigned I = 0; I < NumElems; ++I)
      Res[I] = Builder.CreateSelect(Op0, Op1[I], Op2[I],
                                    SI.getName() + ".i" + Twine(I));
  }
  gather(&SI, Res);
  return true;
}
536 
537 bool Scalarizer::visitICmpInst(ICmpInst &ICI) {
538   return splitBinary(ICI, ICmpSplitter(ICI));
539 }
540 
541 bool Scalarizer::visitFCmpInst(FCmpInst &FCI) {
542   return splitBinary(FCI, FCmpSplitter(FCI));
543 }
544 
545 bool Scalarizer::visitBinaryOperator(BinaryOperator &BO) {
546   return splitBinary(BO, BinarySplitter(BO));
547 }
548 
// Scalarize a vector GEP into one scalar GEP per result element.  Scalar
// base pointers and scalar indices are splatted first so every operand can
// be scattered uniformly.
bool Scalarizer::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
  VectorType *VT = dyn_cast<VectorType>(GEPI.getType());
  if (!VT)
    return false;

  IRBuilder<> Builder(&GEPI);
  unsigned NumElems = VT->getNumElements();
  unsigned NumIndices = GEPI.getNumIndices();

  // The base pointer might be scalar even if it's a vector GEP. In those cases,
  // splat the pointer into a vector value, and scatter that vector.
  Value *Op0 = GEPI.getOperand(0);
  if (!Op0->getType()->isVectorTy())
    Op0 = Builder.CreateVectorSplat(NumElems, Op0);
  Scatterer Base = scatter(&GEPI, Op0);

  SmallVector<Scatterer, 8> Ops;
  Ops.resize(NumIndices);
  for (unsigned I = 0; I < NumIndices; ++I) {
    Value *Op = GEPI.getOperand(I + 1);

    // The indices might be scalars even if it's a vector GEP. In those cases,
    // splat the scalar into a vector value, and scatter that vector.
    if (!Op->getType()->isVectorTy())
      Op = Builder.CreateVectorSplat(NumElems, Op);

    Ops[I] = scatter(&GEPI, Op);
  }

  ValueVector Res;
  Res.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I) {
    // Collect the I'th component of every index, then build the scalar GEP.
    SmallVector<Value *, 8> Indices;
    Indices.resize(NumIndices);
    for (unsigned J = 0; J < NumIndices; ++J)
      Indices[J] = Ops[J][I];
    Res[I] = Builder.CreateGEP(GEPI.getSourceElementType(), Base[I], Indices,
                               GEPI.getName() + ".i" + Twine(I));
    // Preserve the inbounds flag on each scalar GEP (CreateGEP may have
    // folded to a non-GEP constant, hence the dyn_cast).
    if (GEPI.isInBounds())
      if (GetElementPtrInst *NewGEPI = dyn_cast<GetElementPtrInst>(Res[I]))
        NewGEPI->setIsInBounds();
  }
  gather(&GEPI, Res);
  return true;
}
594 
595 bool Scalarizer::visitCastInst(CastInst &CI) {
596   VectorType *VT = dyn_cast<VectorType>(CI.getDestTy());
597   if (!VT)
598     return false;
599 
600   unsigned NumElems = VT->getNumElements();
601   IRBuilder<> Builder(&CI);
602   Scatterer Op0 = scatter(&CI, CI.getOperand(0));
603   assert(Op0.size() == NumElems && "Mismatched cast");
604   ValueVector Res;
605   Res.resize(NumElems);
606   for (unsigned I = 0; I < NumElems; ++I)
607     Res[I] = Builder.CreateCast(CI.getOpcode(), Op0[I], VT->getElementType(),
608                                 CI.getName() + ".i" + Twine(I));
609   gather(&CI, Res);
610   return true;
611 }
612 
// Scalarize a vector-to-vector bitcast.  Handles three cases: equal element
// counts (element-wise bitcast), fan-out (each source element expands to
// several destination elements) and fan-in (several source elements combine
// into one destination element).
bool Scalarizer::visitBitCastInst(BitCastInst &BCI) {
  VectorType *DstVT = dyn_cast<VectorType>(BCI.getDestTy());
  VectorType *SrcVT = dyn_cast<VectorType>(BCI.getSrcTy());
  if (!DstVT || !SrcVT)
    return false;

  unsigned DstNumElems = DstVT->getNumElements();
  unsigned SrcNumElems = SrcVT->getNumElements();
  IRBuilder<> Builder(&BCI);
  Scatterer Op0 = scatter(&BCI, BCI.getOperand(0));
  ValueVector Res;
  Res.resize(DstNumElems);

  if (DstNumElems == SrcNumElems) {
    // 1:1 element mapping: bitcast each component independently.
    for (unsigned I = 0; I < DstNumElems; ++I)
      Res[I] = Builder.CreateBitCast(Op0[I], DstVT->getElementType(),
                                     BCI.getName() + ".i" + Twine(I));
  } else if (DstNumElems > SrcNumElems) {
    // <M x t1> -> <N*M x t2>.  Convert each t1 to <N x t2> and copy the
    // individual elements to the destination.
    unsigned FanOut = DstNumElems / SrcNumElems;
    Type *MidTy = VectorType::get(DstVT->getElementType(), FanOut);
    unsigned ResI = 0;
    for (unsigned Op0I = 0; Op0I < SrcNumElems; ++Op0I) {
      Value *V = Op0[Op0I];
      Instruction *VI;
      // Look through any existing bitcasts before converting to <N x t2>.
      // In the best case, the resulting conversion might be a no-op.
      while ((VI = dyn_cast<Instruction>(V)) &&
             VI->getOpcode() == Instruction::BitCast)
        V = VI->getOperand(0);
      V = Builder.CreateBitCast(V, MidTy, V->getName() + ".cast");
      // Scatter the mid-vector and copy its components to the result.
      Scatterer Mid = scatter(&BCI, V);
      for (unsigned MidI = 0; MidI < FanOut; ++MidI)
        Res[ResI++] = Mid[MidI];
    }
  } else {
    // <N*M x t1> -> <M x t2>.  Convert each group of <N x t1> into a t2.
    unsigned FanIn = SrcNumElems / DstNumElems;
    Type *MidTy = VectorType::get(SrcVT->getElementType(), FanIn);
    unsigned Op0I = 0;
    for (unsigned ResI = 0; ResI < DstNumElems; ++ResI) {
      // Rebuild an <N x t1> from source components, then bitcast it to t2.
      Value *V = UndefValue::get(MidTy);
      for (unsigned MidI = 0; MidI < FanIn; ++MidI)
        V = Builder.CreateInsertElement(V, Op0[Op0I++], Builder.getInt32(MidI),
                                        BCI.getName() + ".i" + Twine(ResI)
                                        + ".upto" + Twine(MidI));
      Res[ResI] = Builder.CreateBitCast(V, DstVT->getElementType(),
                                        BCI.getName() + ".i" + Twine(ResI));
    }
  }
  gather(&BCI, Res);
  return true;
}
667 
668 bool Scalarizer::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
669   VectorType *VT = dyn_cast<VectorType>(SVI.getType());
670   if (!VT)
671     return false;
672 
673   unsigned NumElems = VT->getNumElements();
674   Scatterer Op0 = scatter(&SVI, SVI.getOperand(0));
675   Scatterer Op1 = scatter(&SVI, SVI.getOperand(1));
676   ValueVector Res;
677   Res.resize(NumElems);
678 
679   for (unsigned I = 0; I < NumElems; ++I) {
680     int Selector = SVI.getMaskValue(I);
681     if (Selector < 0)
682       Res[I] = UndefValue::get(VT->getElementType());
683     else if (unsigned(Selector) < Op0.size())
684       Res[I] = Op0[Selector];
685     else
686       Res[I] = Op1[Selector - Op0.size()];
687   }
688   gather(&SVI, Res);
689   return true;
690 }
691 
// Scalarize a vector PHI into one scalar PHI per element, then wire up the
// incoming values from each predecessor's scattered form.
bool Scalarizer::visitPHINode(PHINode &PHI) {
  VectorType *VT = dyn_cast<VectorType>(PHI.getType());
  if (!VT)
    return false;

  unsigned NumElems = VT->getNumElements();
  IRBuilder<> Builder(&PHI);
  ValueVector Res;
  Res.resize(NumElems);

  // Create the scalar PHIs first so they exist before any incoming values
  // (which may themselves be scattered) are attached.
  unsigned NumOps = PHI.getNumOperands();
  for (unsigned I = 0; I < NumElems; ++I)
    Res[I] = Builder.CreatePHI(VT->getElementType(), NumOps,
                               PHI.getName() + ".i" + Twine(I));

  for (unsigned I = 0; I < NumOps; ++I) {
    Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I));
    BasicBlock *IncomingBlock = PHI.getIncomingBlock(I);
    for (unsigned J = 0; J < NumElems; ++J)
      cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock);
  }
  gather(&PHI, Res);
  return true;
}
716 
// Scalarize a vector load into one aligned scalar load per element.  Only
// runs when the scalarize-load-store option is enabled; bails out on
// non-simple (volatile/atomic) loads and on vectors whose elements are not
// full bytes (see getVectorLayout).
bool Scalarizer::visitLoadInst(LoadInst &LI) {
  if (!ScalarizeLoadStore)
    return false;
  if (!LI.isSimple())
    return false;

  VectorLayout Layout;
  if (!getVectorLayout(LI.getType(), LI.getAlignment(), Layout,
                       LI.getModule()->getDataLayout()))
    return false;

  unsigned NumElems = Layout.VecTy->getNumElements();
  IRBuilder<> Builder(&LI);
  // Scatter the pointer operand into per-element pointers.
  Scatterer Ptr = scatter(&LI, LI.getPointerOperand());
  ValueVector Res;
  Res.resize(NumElems);

  for (unsigned I = 0; I < NumElems; ++I)
    Res[I] = Builder.CreateAlignedLoad(Ptr[I], Layout.getElemAlign(I),
                                       LI.getName() + ".i" + Twine(I));
  gather(&LI, Res);
  return true;
}
740 
// Scalarize a vector store into one aligned scalar store per element.
// Mirrors visitLoadInst's preconditions.  Stores produce no value, so we
// only transfer metadata here; the original store is void and is erased by
// runOnFunction when this returns true.
bool Scalarizer::visitStoreInst(StoreInst &SI) {
  if (!ScalarizeLoadStore)
    return false;
  if (!SI.isSimple())
    return false;

  VectorLayout Layout;
  Value *FullValue = SI.getValueOperand();
  if (!getVectorLayout(FullValue->getType(), SI.getAlignment(), Layout,
                       SI.getModule()->getDataLayout()))
    return false;

  unsigned NumElems = Layout.VecTy->getNumElements();
  IRBuilder<> Builder(&SI);
  // Scatter both the pointer and the value being stored.
  Scatterer Ptr = scatter(&SI, SI.getPointerOperand());
  Scatterer Val = scatter(&SI, FullValue);

  ValueVector Stores;
  Stores.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I) {
    unsigned Align = Layout.getElemAlign(I);
    Stores[I] = Builder.CreateAlignedStore(Val[I], Ptr[I], Align);
  }
  transferMetadata(&SI, Stores);
  return true;
}
767 
// Scalarize a call: only intrinsic calls qualify, handled by splitCall.
bool Scalarizer::visitCallInst(CallInst &CI) {
  return splitCall(CI);
}
771 
// Delete the instructions that we scalarized.  If a full vector result
// is still needed, recreate it using InsertElements.
bool Scalarizer::finish() {
  // The presence of data in Gathered or Scattered indicates changes
  // made to the Function.
  if (Gathered.empty() && Scattered.empty())
    return false;
  for (const auto &GMI : Gathered) {
    Instruction *Op = GMI.first;
    ValueVector &CV = *GMI.second;
    if (!Op->use_empty()) {
      // The value is still needed, so recreate it using a series of
      // InsertElements.
      Type *Ty = Op->getType();
      Value *Res = UndefValue::get(Ty);
      BasicBlock *BB = Op->getParent();
      unsigned Count = Ty->getVectorNumElements();
      IRBuilder<> Builder(Op);
      // PHIs must stay at the top of the block, so insert the rebuild
      // sequence at the block's first insertion point instead.
      if (isa<PHINode>(Op))
        Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
      for (unsigned I = 0; I < Count; ++I)
        Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I),
                                          Op->getName() + ".upto" + Twine(I));
      Res->takeName(Op);
      Op->replaceAllUsesWith(Res);
    }
    Op->eraseFromParent();
  }
  Gathered.clear();
  Scattered.clear();
  return true;
}
804 
// Public factory used by the pass pipeline to create this pass.
FunctionPass *llvm::createScalarizerPass() {
  return new Scalarizer();
}
808