//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"

STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
STATISTIC(NumScalarBO, "Number of scalar binops formed");
STATISTIC(NumScalarCmp, "Number of scalar compares formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();

class VectorCombine {
public:
  VectorCombine(Function &F, const TargetTransformInfo &TTI,
                const DominatorTree &DT)
      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT) {}

  bool run();

private:
  Function &F;
  IRBuilder<> Builder;
  const TargetTransformInfo &TTI;
  const DominatorTree &DT;

  ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
                                        ExtractElementInst *Ext1,
                                        unsigned PreferredExtractIndex) const;
  bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                             unsigned Opcode,
                             ExtractElementInst *&ConvertToShuffle,
                             unsigned PreferredExtractIndex);
  void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                     Instruction &I);
  void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                       Instruction &I);
  bool foldExtractExtract(Instruction &I);
  bool foldBitcastShuf(Instruction &I);
  bool scalarizeBinopOrCmp(Instruction &I);
  bool foldExtractedCmps(Instruction &I);
};

static void replaceValue(Value &Old, Value &New) {
  Old.replaceAllUsesWith(&New);
  New.takeName(&Old);
}

/// Determine which, if any, of the inputs should be replaced by a shuffle
/// followed by extract from a different index.
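/// For example (illustrative; vectors and lane numbers are made up): with
/// Ext0 = (extelt V0, 0), Ext1 = (extelt V1, 3), equal extract costs, and no
/// preferred index, the higher-index extract (Ext1) is returned, so lane 3 of
/// V1 would be shuffled over to lane 0 and extracted from there.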
ExtractElementInst *VectorCombine::getShuffleExtract(
    ExtractElementInst *Ext0, ExtractElementInst *Ext1,
    unsigned PreferredExtractIndex = InvalidIndex) const {
  assert(isa<ConstantInt>(Ext0->getIndexOperand()) &&
         isa<ConstantInt>(Ext1->getIndexOperand()) &&
         "Expected constant extract indexes");

  unsigned Index0 = cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue();
  unsigned Index1 = cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue();

  // If the extract indexes are identical, no shuffle is needed.
  if (Index0 == Index1)
    return nullptr;

  Type *VecTy = Ext0->getVectorOperand()->getType();
  assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
  int Cost0 = TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
  int Cost1 = TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);

  // We are extracting from 2 different indexes, so one operand must be shuffled
  // before performing a vector operation and/or extract. The more expensive
  // extract will be replaced by a shuffle.
  if (Cost0 > Cost1)
    return Ext0;
  if (Cost1 > Cost0)
    return Ext1;

  // If the costs are equal and there is a preferred extract index, shuffle the
  // opposite operand.
  if (PreferredExtractIndex == Index0)
    return Ext1;
  if (PreferredExtractIndex == Index1)
    return Ext0;

  // Otherwise, replace the extract with the higher index.
  return Index0 > Index1 ? Ext0 : Ext1;
}

/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction.
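/// For example (illustrative, assuming unit costs for every op, single-use
/// extracts, and matching extract indexes): the scalar sequence
///   binop (extelt V0, 0), (extelt V1, 0)
/// costs 2 extracts + 1 scalar op = 3, while the vector alternative
///   extelt (binop V0, V1), 0
/// costs 1 vector op + 1 extract = 2, so this returns false and the caller
/// may proceed with the fold.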
bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
                                          ExtractElementInst *Ext1,
                                          unsigned Opcode,
                                          ExtractElementInst *&ConvertToShuffle,
                                          unsigned PreferredExtractIndex) {
  assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
         isa<ConstantInt>(Ext1->getOperand(1)) &&
         "Expected constant extract indexes");
  Type *ScalarTy = Ext0->getType();
  auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();

  int Extract0Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext0Index);
  int Extract1Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  //       check the cost of creating a broadcast shuffle and shuffling both
  //       operands to element 0.
  int CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
  if (ConvertToShuffle) {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // undefined except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { undef, undef, 0, undef }
    // TODO: The cost model has an option for a "broadcast" shuffle
    //       (splat-from-element-0), but no option for a more general splat.
    NewCost +=
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Create a shuffle that translates (shifts) 1 element from the input vector
/// to a new element location.
static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
                                 unsigned NewIndex, IRBuilder<> &Builder) {
  // The shuffle mask is undefined except for 1 lane that is being translated
  // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
  // ShufMask = { 2, undef, undef, undef }
  auto *VecTy = cast<FixedVectorType>(Vec->getType());
  SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem);
  ShufMask[NewIndex] = OldIndex;
  Value *Undef = UndefValue::get(VecTy);
  return Builder.CreateShuffleVector(Vec, Undef, ShufMask, "shift");
}

/// Given an extract element instruction with constant index operand, shuffle
/// the source vector (shift the scalar element) to a NewIndex for extraction.
/// Return null if the input can be constant folded, so that we are not creating
/// unnecessary instructions.
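/// For example (illustrative; types and value names are made up), translating
/// (extelt <4 x float> %x, 2) to NewIndex == 0 produces:
///   %shift = shufflevector <4 x float> %x, <4 x float> undef,
///                          <4 x i32> <i32 2, i32 undef, i32 undef, i32 undef>
///   %e = extractelement <4 x float> %shift, i64 0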
static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
                                            unsigned NewIndex,
                                            IRBuilder<> &Builder) {
  // If the extract can be constant-folded, this code is unsimplified. Defer
  // to other passes to handle that.
  Value *X = ExtElt->getVectorOperand();
  Value *C = ExtElt->getIndexOperand();
  assert(isa<ConstantInt>(C) && "Expected a constant index operand");
  if (isa<Constant>(X))
    return nullptr;

  Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
                                   NewIndex, Builder);
  return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
                                  ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<CmpInst>(&I) && "Expected a compare");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
  Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
                                    ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Match an instruction with extracted vector operands.
bool VectorCombine::foldExtractExtract(Instruction &I) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *I0, *I1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
      !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
      !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  //       probably becomes unnecessary.
  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  uint64_t InsertIndex = InvalidIndex;
  if (I.hasOneUse())
    match(I.user_back(),
          m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));

  ExtractElementInst *ExtractToChange;
  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), ExtractToChange,
                            InsertIndex))
    return false;

  if (ExtractToChange) {
    unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
    ExtractElementInst *NewExtract =
        translateExtract(ExtractToChange, CheapExtractIdx, Builder);
    if (!NewExtract)
      return false;
    if (ExtractToChange == Ext0)
      Ext0 = NewExtract;
    else
      Ext1 = NewExtract;
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I);
  else
    foldExtExtBinop(Ext0, Ext1, I);

  return true;
}

/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
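/// For example (illustrative; types and mask chosen only to show the mask
/// narrowing), casting from <4 x i32> to <8 x i16>:
///   bitcast <8 x i16> (shuf <4 x i32> V, <1, 1, 0, 0>)
///   --> shuf (bitcast <8 x i16> V), <2, 3, 2, 3, 0, 1, 0, 1>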
bool VectorCombine::foldBitcastShuf(Instruction &I) {
  Value *V;
  ArrayRef<int> Mask;
  if (!match(&I, m_BitCast(
                     m_OneUse(m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))))))
    return false;

  // Disallow non-vector casts and length-changing shuffles.
  // TODO: We could allow any shuffle.
  auto *DestTy = dyn_cast<VectorType>(I.getType());
  auto *SrcTy = cast<VectorType>(V->getType());
  if (!DestTy || I.getOperand(0)->getType() != SrcTy)
    return false;

  // The new shuffle must not cost more than the old shuffle. The bitcast is
  // moved ahead of the shuffle, so assume that it has the same cost as before.
  if (TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, DestTy) >
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy))
    return false;

  unsigned DestNumElts = DestTy->getNumElements();
  unsigned SrcNumElts = SrcTy->getNumElements();
  SmallVector<int, 16> NewMask;
  if (SrcNumElts <= DestNumElts) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = DestNumElts / SrcNumElts;
    narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(SrcNumElts % DestNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = SrcNumElts / DestNumElts;
    if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
      return false;
  }
  // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
  ++NumShufOfBitcast;
  Value *CastV = Builder.CreateBitCast(V, DestTy);
  Value *Shuf =
      Builder.CreateShuffleVector(CastV, UndefValue::get(DestTy), NewMask);
  replaceValue(I, *Shuf);
  return true;
}

/// Match a vector binop or compare instruction with at least one inserted
/// scalar operand and convert to scalar binop/cmp followed by insertelement.
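/// For example (illustrative; values and value names are made up):
///   %i = insertelement <4 x i32> <i32 1, i32 2, i32 3, i32 4>, i32 %x, i32 0
///   %r = add <4 x i32> %i, <i32 10, i32 10, i32 10, i32 10>
///   -->
///   %s = add i32 %x, 10
///   %r = insertelement <4 x i32> <i32 11, i32 12, i32 13, i32 14>, i32 %s, i32 0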
bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  Value *Ins0, *Ins1;
  if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
      !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
    return false;

  // Do not convert the vector condition of a vector select into a scalar
  // condition. That may cause problems for codegen because of differences in
  // boolean formats and register-file transfers.
  // TODO: Can we account for that in the cost model?
  bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE;
  if (IsCmp)
    for (User *U : I.users())
      if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
        return false;

  // Match against one or both scalar values being inserted into constant
  // vectors:
  // vec_op VecC0, (inselt VecC1, V1, Index)
  // vec_op (inselt VecC0, V0, Index), VecC1
  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
  // TODO: Deal with mismatched index constants and variable indexes?
  Constant *VecC0 = nullptr, *VecC1 = nullptr;
  Value *V0 = nullptr, *V1 = nullptr;
  uint64_t Index0 = 0, Index1 = 0;
  if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
                               m_ConstantInt(Index0))) &&
      !match(Ins0, m_Constant(VecC0)))
    return false;
  if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
                               m_ConstantInt(Index1))) &&
      !match(Ins1, m_Constant(VecC1)))
    return false;

  bool IsConst0 = !V0;
  bool IsConst1 = !V1;
  if (IsConst0 && IsConst1)
    return false;
  if (!IsConst0 && !IsConst1 && Index0 != Index1)
    return false;

  // Bail for single insertion if it is a load.
  // TODO: Handle this once getVectorInstrCost can cost for load/stores.
  auto *I0 = dyn_cast_or_null<Instruction>(V0);
  auto *I1 = dyn_cast_or_null<Instruction>(V1);
  if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
      (IsConst1 && I0 && I0->mayReadFromMemory()))
    return false;

  uint64_t Index = IsConst0 ? Index1 : Index0;
  Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
  Type *VecTy = I.getType();
  assert(VecTy->isVectorTy() &&
         (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
         (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
          ScalarTy->isPointerTy()) &&
         "Unexpected types for insert element into binop or cmp");

  unsigned Opcode = I.getOpcode();
  int ScalarOpCost, VectorOpCost;
  if (IsCmp) {
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy);
  } else {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  }

  // Get cost estimate for the insert element. This cost will factor into
  // both sequences.
  int InsertCost =
      TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, Index);
  int OldCost = (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) +
                VectorOpCost;
  int NewCost = ScalarOpCost + InsertCost +
                (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
                (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost)
    return false;

  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
  // inselt NewVecC, (scalar_op V0, V1), Index
  if (IsCmp)
    ++NumScalarCmp;
  else
    ++NumScalarBO;

  // For constant cases, extract the scalar element; this should constant fold.
  if (IsConst0)
    V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
  if (IsConst1)
    V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));

  Value *Scalar =
      IsCmp ? Builder.CreateCmp(Pred, V0, V1)
            : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);

  Scalar->setName(I.getName() + ".scalar");

  // All IR flags are safe to back-propagate. There is no potential for extra
  // poison to be created by the scalar instruction.
  if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
    ScalarInst->copyIRFlags(&I);

  // Fold the vector constants in the original vectors into a new base vector.
  Constant *NewVecC = IsCmp ? ConstantExpr::getCompare(Pred, VecC0, VecC1)
                            : ConstantExpr::get(Opcode, VecC0, VecC1);
  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
  replaceValue(I, *Insert);
  return true;
}

/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
/// a vector into vector operations followed by extract. Note: The SLP pass
/// may miss this pattern because of implementation problems.
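/// For example (illustrative; assumes the lane-3 extract is the one converted
/// to a shuffle):
///   %e0 = extractelement <4 x i32> %x, i32 0
///   %e1 = extractelement <4 x i32> %x, i32 3
///   %c0 = icmp sgt i32 %e0, 42
///   %c1 = icmp sgt i32 %e1, -8
///   %r  = or i1 %c0, %c1
///   -->
///   %vc = icmp sgt <4 x i32> %x, <i32 42, i32 undef, i32 undef, i32 -8>
///   %s  = shufflevector <4 x i1> %vc, <4 x i1> undef,
///                       <4 x i32> <i32 3, i32 undef, i32 undef, i32 undef>
///   %vl = or <4 x i1> %vc, %s
///   %r  = extractelement <4 x i1> %vl, i32 0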
bool VectorCombine::foldExtractedCmps(Instruction &I) {
  // We are looking for a scalar binop of booleans.
  // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
  if (!I.isBinaryOp() || !I.getType()->isIntegerTy(1))
    return false;

  // The compare predicates should match, and each compare should have a
  // constant operand.
  // TODO: Relax the one-use constraints.
  Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
  Instruction *I0, *I1;
  Constant *C0, *C1;
  CmpInst::Predicate P0, P1;
  if (!match(B0, m_OneUse(m_Cmp(P0, m_Instruction(I0), m_Constant(C0)))) ||
      !match(B1, m_OneUse(m_Cmp(P1, m_Instruction(I1), m_Constant(C1)))) ||
      P0 != P1)
    return false;

  // The compare operands must be extracts of the same vector with constant
  // extract indexes.
  // TODO: Relax the one-use constraints.
  Value *X;
  uint64_t Index0, Index1;
  if (!match(I0, m_OneUse(m_ExtractElt(m_Value(X), m_ConstantInt(Index0)))) ||
      !match(I1, m_OneUse(m_ExtractElt(m_Specific(X), m_ConstantInt(Index1)))))
    return false;

  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1);
  if (!ConvertToShuf)
    return false;

  // The original scalar pattern is:
  // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
  CmpInst::Predicate Pred = P0;
  unsigned CmpOpcode = CmpInst::isFPPredicate(Pred) ? Instruction::FCmp
                                                    : Instruction::ICmp;
  auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
  if (!VecTy)
    return false;

  int OldCost = TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
  OldCost += TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);
  OldCost += TTI.getCmpSelInstrCost(CmpOpcode, I0->getType()) * 2;
  OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType());

  // The proposed vector pattern is:
  // vcmp = cmp Pred X, VecC
  // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
  int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
  int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
  auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
  int NewCost = TTI.getCmpSelInstrCost(CmpOpcode, X->getType());
  NewCost +=
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy);
  NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy);
  NewCost += TTI.getVectorInstrCost(Ext0->getOpcode(), CmpTy, CheapIndex);

  // Aggressively form vector ops if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  if (OldCost < NewCost)
    return false;

  // Create a vector constant from the 2 scalar constants.
  SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
                                   UndefValue::get(VecTy->getElementType()));
  CmpC[Index0] = C0;
  CmpC[Index1] = C1;
  Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));

  Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
  Value *VecLogic = Builder.CreateBinOp(cast<BinaryOperator>(I).getOpcode(),
                                        VCmp, Shuf);
  Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
  replaceValue(I, *NewExt);
  ++NumVecCmpBO;
  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
  if (DisableVectorCombine)
    return false;

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions under here and invalidate the iterator.
    // Walk the block forwards to enable simple iterative chains of transforms.
    // TODO: It could be more efficient to remove dead instructions
    //       iteratively in this loop rather than waiting until the end.
    for (Instruction &I : BB) {
      if (isa<DbgInfoIntrinsic>(I))
        continue;
      Builder.SetInsertPoint(&I);
      MadeChange |= foldExtractExtract(I);
      MadeChange |= foldBitcastShuf(I);
      MadeChange |= scalarizeBinopOrCmp(I);
      MadeChange |= foldExtractedCmps(I);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.addPreserved<AAResultsWrapperPass>();
    AU.addPreserved<BasicAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    VectorCombine Combiner(F, TTI, DT);
    return Combiner.run();
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  VectorCombine Combiner(F, TTI, DT);
  if (!Combiner.run())
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  PA.preserve<AAManager>();
  PA.preserve<BasicAA>();
  return PA;
}