//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
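//
// For example, a scalar operation on two extracted vector elements may be
// rewritten as one vector operation followed by a single extract, when the
// target's cost model reports the vector form as no more expensive.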
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"
#include <limits>

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
STATISTIC(NumScalarBO, "Number of scalar binops formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

/// Compare the relative costs of 2 extracts followed by a scalar operation vs.
/// a vector operation followed by an extract. Return true if the existing
/// instructions are cheaper than the vector alternative. Otherwise, return
/// false, and if one of the extracts should be transformed to a shufflevector,
/// set \p ConvertToShuffle to that extract instruction.
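///
/// A minimal sketch of the trade-off, with hypothetical values and types:
///   %e0 = extractelement <4 x i32> %x, i32 0
///   %e1 = extractelement <4 x i32> %y, i32 2
///   %r  = add i32 %e0, %e1
/// is weighed against the vector alternative:
///   %s  = shufflevector <4 x i32> %y, <4 x i32> undef,
///                       <4 x i32> <i32 2, i32 undef, i32 undef, i32 undef>
///   %v  = add <4 x i32> %x, %s
///   %r  = extractelement <4 x i32> %v, i32 0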
static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
                                  unsigned Opcode,
                                  const TargetTransformInfo &TTI,
                                  Instruction *&ConvertToShuffle,
                                  unsigned PreferredExtractIndex) {
  assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
         isa<ConstantInt>(Ext1->getOperand(1)) &&
         "Expected constant extract indexes");
  Type *ScalarTy = Ext0->getType();
  auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();

  int Extract0Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                            VecTy, Ext0Index);
  int Extract1Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                            VecTy, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, C0), (extelt V1, C1) -->
  // extelt (opcode (splat V0, C0), V1), C1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  //       check the cost of creating a broadcast shuffle and shuffling both
  //       operands to element 0.
  int CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  if (Ext0Index == Ext1Index) {
    // If the extract indexes are identical, no shuffle is needed.
    ConvertToShuffle = nullptr;
  } else {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // undefined except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { undef, undef, 0, undef }
    // TODO: The cost model has an option for a "broadcast" shuffle
    //       (splat-from-element-0), but no option for a more general splat.
    NewCost +=
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);

    // The more expensive extract will be replaced by a shuffle. If the costs
    // are equal and there is a preferred extract index, shuffle the opposite
    // operand. Otherwise, replace the extract with the higher index.
    if (Extract0Cost > Extract1Cost)
      ConvertToShuffle = Ext0;
    else if (Extract1Cost > Extract0Cost)
      ConvertToShuffle = Ext1;
    else if (PreferredExtractIndex == Ext0Index)
      ConvertToShuffle = Ext1;
    else if (PreferredExtractIndex == Ext1Index)
      ConvertToShuffle = Ext0;
    else
      ConvertToShuffle = Ext0Index > Ext1Index ? Ext0 : Ext1;
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1,
                          Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<CmpInst>(&I) && "Expected a compare");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  IRBuilder<> Builder(&I);
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecCmp =
      Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1)
                                           : Builder.CreateICmp(Pred, V0, V1);
  Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1,
                            Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  IRBuilder<> Builder(&I);
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Match an instruction with extracted vector operands.
static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *Ext0, *Ext1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) &&
      !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) ||
      !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  //       probably becomes unnecessary.
  uint64_t InsertIndex = std::numeric_limits<uint64_t>::max();
  if (I.hasOneUse())
    match(I.user_back(), m_InsertElement(m_Value(), m_Value(),
                                         m_ConstantInt(InsertIndex)));
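  // For example, if this instruction's only user is a hypothetical
  // 'insertelement <4 x i32> %q, i32 %r, i32 1', prefer extracting from
  // lane 1 so the later extract/insert pair can fold to a select shuffle.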

  Instruction *ConvertToShuffle;
  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI, ConvertToShuffle,
                            InsertIndex))
    return false;

  if (ConvertToShuffle) {
    // The shuffle mask is undefined except for 1 lane that is being translated
    // to the cheap extraction lane. Example:
    // ShufMask = { 2, undef, undef, undef }
    uint64_t SplatIndex = ConvertToShuffle == Ext0 ? C0 : C1;
    uint64_t CheapExtIndex = ConvertToShuffle == Ext0 ? C1 : C0;
    auto *VecTy = cast<VectorType>(V0->getType());
    SmallVector<int, 32> ShufMask(VecTy->getNumElements(), -1);
    ShufMask[CheapExtIndex] = SplatIndex;
    IRBuilder<> Builder(ConvertToShuffle);

    // extelt X, C --> extelt (splat X), C'
    Value *Shuf = Builder.CreateShuffleVector(ConvertToShuffle->getOperand(0),
                                              UndefValue::get(VecTy), ShufMask);
    Value *NewExt = Builder.CreateExtractElement(Shuf, CheapExtIndex);
    if (ConvertToShuffle == Ext0)
      Ext0 = cast<Instruction>(NewExt);
    else
      Ext1 = cast<Instruction>(NewExt);
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I, TTI);
  else
    foldExtExtBinop(Ext0, Ext1, I, TTI);

  return true;
}

/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
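///
/// A sketch with illustrative types only:
///   %shuf = shufflevector <2 x i64> %v, <2 x i64> undef, <2 x i32> <i32 1, i32 0>
///   %cast = bitcast <2 x i64> %shuf to <4 x i32>
/// -->
///   %cast = bitcast <2 x i64> %v to <4 x i32>
///   %shuf = shufflevector <4 x i32> %cast, <4 x i32> undef,
///                         <4 x i32> <i32 2, i32 3, i32 0, i32 1>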
static bool foldBitcastShuf(Instruction &I, const TargetTransformInfo &TTI) {
  Value *V;
  ArrayRef<int> Mask;
  if (!match(&I, m_BitCast(m_OneUse(m_ShuffleVector(m_Value(V), m_Undef(),
                                                    m_Mask(Mask))))))
    return false;

  // Disallow non-vector casts and length-changing shuffles.
  // TODO: We could allow any shuffle.
  auto *DestTy = dyn_cast<VectorType>(I.getType());
  auto *SrcTy = cast<VectorType>(V->getType());
  if (!DestTy || I.getOperand(0)->getType() != SrcTy)
    return false;

  // The new shuffle must not cost more than the old shuffle. The bitcast is
  // moved ahead of the shuffle, so assume that it has the same cost as before.
  if (TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, DestTy) >
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy))
    return false;

  unsigned DestNumElts = DestTy->getNumElements();
  unsigned SrcNumElts = SrcTy->getNumElements();
  SmallVector<int, 16> NewMask;
  if (SrcNumElts <= DestNumElts) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = DestNumElts / SrcNumElts;
    narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(SrcNumElts % DestNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = SrcNumElts / DestNumElts;
    if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
      return false;
  }
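  // With ScaleFactor 2, for example, narrowing expands mask <1,0> to
  // <2,3,0,1>, while widening can only contract masks that select whole
  // adjacent pairs, e.g. <2,3,0,1> back to <1,0>.
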
  // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
  IRBuilder<> Builder(&I);
  Value *CastV = Builder.CreateBitCast(V, DestTy);
  Value *Shuf = Builder.CreateShuffleVector(CastV, UndefValue::get(DestTy),
                                            NewMask);
  I.replaceAllUsesWith(Shuf);
  return true;
}

/// Match a vector binop instruction with inserted scalar operands and convert
/// it to a scalar binop followed by an insertelement.
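///
/// A minimal example (constants chosen only for illustration):
///   %i0 = insertelement <2 x i64> <i64 42, i64 undef>, i64 %x, i32 1
///   %i1 = insertelement <2 x i64> <i64 7, i64 undef>, i64 %y, i32 1
///   %r  = add <2 x i64> %i0, %i1
/// -->
///   %s = add i64 %x, %y
///   %r = insertelement <2 x i64> <i64 49, i64 undef>, i64 %s, i32 1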
static bool scalarizeBinop(Instruction &I, const TargetTransformInfo &TTI) {
  Instruction *Ins0, *Ins1;
  if (!match(&I, m_BinOp(m_Instruction(Ins0), m_Instruction(Ins1))))
    return false;

  // TODO: Deal with mismatched index constants and variable indexes?
  Constant *VecC0, *VecC1;
  Value *V0, *V1;
  uint64_t Index;
  if (!match(Ins0, m_InsertElement(m_Constant(VecC0), m_Value(V0),
                                   m_ConstantInt(Index))) ||
      !match(Ins1, m_InsertElement(m_Constant(VecC1), m_Value(V1),
                                   m_SpecificInt(Index))))
    return false;

  Type *ScalarTy = V0->getType();
  Type *VecTy = I.getType();
  assert(VecTy->isVectorTy() && ScalarTy == V1->getType() &&
         (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy()) &&
         "Unexpected types for insert into binop");

  Instruction::BinaryOps Opcode = cast<BinaryOperator>(&I)->getOpcode();
  int ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
  int VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);

  // Get cost estimate for the insert element. This cost will factor into
  // both sequences.
  int InsertCost =
      TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, Index);
  int OldCost = InsertCost + InsertCost + VectorOpCost;
  int NewCost = ScalarOpCost + InsertCost +
                !Ins0->hasOneUse() * InsertCost +
                !Ins1->hasOneUse() * InsertCost;

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost)
    return false;

  // vec_bo (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
  // inselt NewVecC, (scalar_bo V0, V1), Index
  ++NumScalarBO;
  IRBuilder<> Builder(&I);
  Value *Scalar = Builder.CreateBinOp(Opcode, V0, V1, I.getName() + ".scalar");

  // All IR flags are safe to back-propagate. There is no potential for extra
  // poison to be created by the scalar instruction.
  if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
    ScalarInst->copyIRFlags(&I);

  // Fold the vector constants in the original vectors into a new base vector.
  Constant *NewVecC = ConstantExpr::get(Opcode, VecC0, VecC1);
  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
  I.replaceAllUsesWith(Insert);
  Insert->takeName(&I);
  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    const DominatorTree &DT) {
  if (DisableVectorCombine)
    return false;

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions inside this loop; that would invalidate the
    // iterator. Walk the block forwards to enable simple iterative chains of
    // transforms.
    // TODO: It could be more efficient to remove dead instructions
    //       iteratively in this loop rather than waiting until the end.
    for (Instruction &I : BB) {
      if (isa<DbgInfoIntrinsic>(I))
        continue;
      MadeChange |= foldExtractExtract(I, TTI);
      MadeChange |= foldBitcastShuf(I, TTI);
      MadeChange |= scalarizeBinop(I, TTI);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.addPreserved<AAResultsWrapperPass>();
    AU.addPreserved<BasicAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  PA.preserve<AAManager>();
  PA.preserve<BasicAA>();
  return PA;
}