1a17f03bdSSanjay Patel //===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2a17f03bdSSanjay Patel //
3a17f03bdSSanjay Patel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a17f03bdSSanjay Patel // See https://llvm.org/LICENSE.txt for license information.
5a17f03bdSSanjay Patel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a17f03bdSSanjay Patel //
7a17f03bdSSanjay Patel //===----------------------------------------------------------------------===//
8a17f03bdSSanjay Patel //
9a17f03bdSSanjay Patel // This pass optimizes scalar/vector interactions using target cost models. The
10a17f03bdSSanjay Patel // transforms implemented here may not fit in traditional loop-based or SLP
11a17f03bdSSanjay Patel // vectorization passes.
12a17f03bdSSanjay Patel //
13a17f03bdSSanjay Patel //===----------------------------------------------------------------------===//
14a17f03bdSSanjay Patel 
15a17f03bdSSanjay Patel #include "llvm/Transforms/Vectorize/VectorCombine.h"
16a17f03bdSSanjay Patel #include "llvm/ADT/Statistic.h"
175006e551SSimon Pilgrim #include "llvm/Analysis/BasicAliasAnalysis.h"
18a17f03bdSSanjay Patel #include "llvm/Analysis/GlobalsModRef.h"
19a17f03bdSSanjay Patel #include "llvm/Analysis/TargetTransformInfo.h"
2019b62b79SSanjay Patel #include "llvm/Analysis/ValueTracking.h"
21b6050ca1SSanjay Patel #include "llvm/Analysis/VectorUtils.h"
22a17f03bdSSanjay Patel #include "llvm/IR/Dominators.h"
23a17f03bdSSanjay Patel #include "llvm/IR/Function.h"
24a17f03bdSSanjay Patel #include "llvm/IR/IRBuilder.h"
25a17f03bdSSanjay Patel #include "llvm/IR/PatternMatch.h"
26a17f03bdSSanjay Patel #include "llvm/InitializePasses.h"
27a17f03bdSSanjay Patel #include "llvm/Pass.h"
2825c6544fSSanjay Patel #include "llvm/Support/CommandLine.h"
29a17f03bdSSanjay Patel #include "llvm/Transforms/Utils/Local.h"
305006e551SSimon Pilgrim #include "llvm/Transforms/Vectorize.h"
31a17f03bdSSanjay Patel 
32a17f03bdSSanjay Patel using namespace llvm;
33a17f03bdSSanjay Patel using namespace llvm::PatternMatch;
34a17f03bdSSanjay Patel 
35a17f03bdSSanjay Patel #define DEBUG_TYPE "vector-combine"
36a17f03bdSSanjay Patel STATISTIC(NumVecCmp, "Number of vector compares formed");
3719b62b79SSanjay Patel STATISTIC(NumVecBO, "Number of vector binops formed");
387aeb41b3SRoman Lebedev STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
390d2a0b44SSanjay Patel STATISTIC(NumScalarBO, "Number of scalar binops formed");
40ed67f5e7SSanjay Patel STATISTIC(NumScalarCmp, "Number of scalar compares formed");
41a17f03bdSSanjay Patel 
4225c6544fSSanjay Patel static cl::opt<bool> DisableVectorCombine(
4325c6544fSSanjay Patel     "disable-vector-combine", cl::init(false), cl::Hidden,
4425c6544fSSanjay Patel     cl::desc("Disable all vector combine transforms"));
4525c6544fSSanjay Patel 
46a69158c1SSanjay Patel static cl::opt<bool> DisableBinopExtractShuffle(
47a69158c1SSanjay Patel     "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
48a69158c1SSanjay Patel     cl::desc("Disable binop extract to shuffle transforms"));
49a69158c1SSanjay Patel 
51a69158c1SSanjay Patel /// Compare the relative costs of two extracts followed by a scalar operation
52a69158c1SSanjay Patel /// vs. a vector operation followed by an extract. Return true if the existing
53a69158c1SSanjay Patel /// instructions are cheaper than the vector alternative. Otherwise, return
54a69158c1SSanjay Patel /// false, and if one of the extracts should be transformed into a
55a69158c1SSanjay Patel /// shufflevector, set \p ConvertToShuffle to that extract instruction.
56*216a37bbSSanjay Patel static bool isExtractExtractCheap(ExtractElementInst *Ext0,
57*216a37bbSSanjay Patel                                   ExtractElementInst *Ext1, unsigned Opcode,
58a69158c1SSanjay Patel                                   const TargetTransformInfo &TTI,
59*216a37bbSSanjay Patel                                   ExtractElementInst *&ConvertToShuffle,
60ce97ce3aSSanjay Patel                                   unsigned PreferredExtractIndex) {
614fa63fd4SAustin Kerbow   assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
62a69158c1SSanjay Patel          isa<ConstantInt>(Ext1->getOperand(1)) &&
63a69158c1SSanjay Patel          "Expected constant extract indexes");
6434e34855SSanjay Patel   Type *ScalarTy = Ext0->getType();
65e3056ae9SSam Parker   auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
6634e34855SSanjay Patel   int ScalarOpCost, VectorOpCost;
6734e34855SSanjay Patel 
6834e34855SSanjay Patel   // Get cost estimates for scalar and vector versions of the operation.
6934e34855SSanjay Patel   bool IsBinOp = Instruction::isBinaryOp(Opcode);
7034e34855SSanjay Patel   if (IsBinOp) {
7134e34855SSanjay Patel     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
7234e34855SSanjay Patel     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
7334e34855SSanjay Patel   } else {
7434e34855SSanjay Patel     assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
7534e34855SSanjay Patel            "Expected a compare");
7634e34855SSanjay Patel     ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
7734e34855SSanjay Patel                                           CmpInst::makeCmpResultType(ScalarTy));
7834e34855SSanjay Patel     VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
7934e34855SSanjay Patel                                           CmpInst::makeCmpResultType(VecTy));
8034e34855SSanjay Patel   }
8134e34855SSanjay Patel 
82a69158c1SSanjay Patel   // Get cost estimates for the extract elements. These costs will factor into
8334e34855SSanjay Patel   // both sequences.
84a69158c1SSanjay Patel   unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
85a69158c1SSanjay Patel   unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();
86a69158c1SSanjay Patel 
87a69158c1SSanjay Patel   int Extract0Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
88a69158c1SSanjay Patel                                             VecTy, Ext0Index);
89a69158c1SSanjay Patel   int Extract1Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
90a69158c1SSanjay Patel                                             VecTy, Ext1Index);
91a69158c1SSanjay Patel 
92a69158c1SSanjay Patel   // A more expensive extract will always be replaced by a splat shuffle.
93a69158c1SSanjay Patel   // For example, if Ext0 is more expensive:
94a69158c1SSanjay Patel   // opcode (extelt V0, Ext0), (extelt V1, Ext1) -->
95a69158c1SSanjay Patel   // extelt (opcode (splat V0, Ext0), V1), Ext1
96a69158c1SSanjay Patel   // TODO: Evaluate whether that always results in lowest cost. Alternatively,
97a69158c1SSanjay Patel   //       check the cost of creating a broadcast shuffle and shuffling both
98a69158c1SSanjay Patel   //       operands to element 0.
99a69158c1SSanjay Patel   int CheapExtractCost = std::min(Extract0Cost, Extract1Cost);
10034e34855SSanjay Patel 
10134e34855SSanjay Patel   // Extra uses of the extracts mean that we include those costs in the
10234e34855SSanjay Patel   // vector total because those instructions will not be eliminated.
103e9c79a7aSSanjay Patel   int OldCost, NewCost;
104a69158c1SSanjay Patel   if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
105a69158c1SSanjay Patel     // Handle a special case. If the 2 extracts are identical, adjust the
10634e34855SSanjay Patel     // formulas to account for that. The extra use charge allows for either the
10734e34855SSanjay Patel     // CSE'd pattern or an unoptimized form with identical values:
10834e34855SSanjay Patel     // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
10934e34855SSanjay Patel     bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
11034e34855SSanjay Patel                                   : !Ext0->hasOneUse() || !Ext1->hasOneUse();
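    // Note: HasUseTax converts to 0 or 1 below, so it charges at most one
    // extra extract cost, and only when the extract(s) have uses besides this
    // binop/cmp and therefore cannot be erased after the transform.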
111a69158c1SSanjay Patel     OldCost = CheapExtractCost + ScalarOpCost;
112a69158c1SSanjay Patel     NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
11334e34855SSanjay Patel   } else {
11434e34855SSanjay Patel     // Handle the general case. Each extract is actually a different value:
115a69158c1SSanjay Patel     // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
116a69158c1SSanjay Patel     OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
117a69158c1SSanjay Patel     NewCost = VectorOpCost + CheapExtractCost +
118a69158c1SSanjay Patel               !Ext0->hasOneUse() * Extract0Cost +
119a69158c1SSanjay Patel               !Ext1->hasOneUse() * Extract1Cost;
12034e34855SSanjay Patel   }
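  // Worked example with made-up unit costs (illustration only): if each
  // extract, the scalar op, and the vector op all cost 1 and both extracts
  // have a single use, the general case gives OldCost = 3 and NewCost = 2,
  // so this function returns false and the caller forms the vector op, even
  // if a splat shuffle cost of 1 is added below (ties also favor the fold).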
121a69158c1SSanjay Patel 
122a69158c1SSanjay Patel   if (Ext0Index == Ext1Index) {
123a69158c1SSanjay Patel     // If the extract indexes are identical, no shuffle is needed.
124a69158c1SSanjay Patel     ConvertToShuffle = nullptr;
125a69158c1SSanjay Patel   } else {
126a69158c1SSanjay Patel     if (IsBinOp && DisableBinopExtractShuffle)
127a69158c1SSanjay Patel       return true;
128a69158c1SSanjay Patel 
129a69158c1SSanjay Patel     // If we are extracting from 2 different indexes, then one operand must be
130a69158c1SSanjay Patel     // shuffled before performing the vector operation. The shuffle mask is
131a69158c1SSanjay Patel     // undefined except for 1 lane that is being translated to the remaining
132a69158c1SSanjay Patel     // extraction lane. Therefore, it is a splat shuffle. Ex:
133a69158c1SSanjay Patel     // ShufMask = { undef, undef, 0, undef }
134a69158c1SSanjay Patel     // TODO: The cost model has an option for a "broadcast" shuffle
135a69158c1SSanjay Patel     //       (splat-from-element-0), but no option for a more general splat.
136a69158c1SSanjay Patel     NewCost +=
137a69158c1SSanjay Patel         TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);
138a69158c1SSanjay Patel 
139ce97ce3aSSanjay Patel     // The more expensive extract will be replaced by a shuffle. If the costs
140ce97ce3aSSanjay Patel     // are equal and there is a preferred extract index, shuffle the opposite
141ce97ce3aSSanjay Patel     // operand. Otherwise, replace the extract with the higher index.
142a69158c1SSanjay Patel     if (Extract0Cost > Extract1Cost)
143a69158c1SSanjay Patel       ConvertToShuffle = Ext0;
144a69158c1SSanjay Patel     else if (Extract1Cost > Extract0Cost)
145a69158c1SSanjay Patel       ConvertToShuffle = Ext1;
146ce97ce3aSSanjay Patel     else if (PreferredExtractIndex == Ext0Index)
147ce97ce3aSSanjay Patel       ConvertToShuffle = Ext1;
148ce97ce3aSSanjay Patel     else if (PreferredExtractIndex == Ext1Index)
149ce97ce3aSSanjay Patel       ConvertToShuffle = Ext0;
150a69158c1SSanjay Patel     else
151a69158c1SSanjay Patel       ConvertToShuffle = Ext0Index > Ext1Index ? Ext0 : Ext1;
152a69158c1SSanjay Patel   }
153a69158c1SSanjay Patel 
15410ea01d8SSanjay Patel   // Aggressively form a vector op if the cost is equal because the transform
15510ea01d8SSanjay Patel   // may enable further optimization.
15610ea01d8SSanjay Patel   // Codegen can reverse this transform (scalarize) if it was not profitable.
15710ea01d8SSanjay Patel   return OldCost < NewCost;
15834e34855SSanjay Patel }
15934e34855SSanjay Patel 
160*216a37bbSSanjay Patel /// Given an extract element instruction with constant index operand, shuffle
161*216a37bbSSanjay Patel /// the source vector (shift the scalar element) to a NewIndex for extraction.
162*216a37bbSSanjay Patel /// Return null if the input can be constant-folded, so that we do not create
163*216a37bbSSanjay Patel /// unnecessary instructions.
164*216a37bbSSanjay Patel static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
165*216a37bbSSanjay Patel                                             unsigned NewIndex) {
166*216a37bbSSanjay Patel   // If the extract can be constant-folded, this code is unsimplified. Defer
167*216a37bbSSanjay Patel   // to other passes to handle that.
168*216a37bbSSanjay Patel   Value *X = ExtElt->getVectorOperand();
169*216a37bbSSanjay Patel   Value *C = ExtElt->getIndexOperand();
170*216a37bbSSanjay Patel   if (isa<Constant>(X))
171*216a37bbSSanjay Patel     return nullptr;
172*216a37bbSSanjay Patel 
173*216a37bbSSanjay Patel   // The shuffle mask is undefined except for 1 lane that is being translated
174*216a37bbSSanjay Patel   // to the cheap extraction lane. Example:
175*216a37bbSSanjay Patel   // ShufMask = { 2, undef, undef, undef }
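  // As a fuller illustration (hypothetical IR, not from a test), retargeting
  //   %e = extractelement <4 x i32> %x, i32 3
  // to NewIndex 1 builds Mask = { undef, 3, undef, undef } and produces
  //   %s = shufflevector <4 x i32> %x, <4 x i32> undef,
  //                      <4 x i32> <i32 undef, i32 3, i32 undef, i32 undef>
  //   %e = extractelement <4 x i32> %s, i32 1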
176*216a37bbSSanjay Patel   auto *VecTy = cast<FixedVectorType>(X->getType());
177*216a37bbSSanjay Patel   SmallVector<int, 32> Mask(VecTy->getNumElements(), -1);
178*216a37bbSSanjay Patel   assert(isa<ConstantInt>(C) && "Expected a constant index operand");
179*216a37bbSSanjay Patel   Mask[NewIndex] = cast<ConstantInt>(C)->getZExtValue();
180*216a37bbSSanjay Patel 
181*216a37bbSSanjay Patel   // extelt X, C --> extelt (shuffle X), NewIndex
182*216a37bbSSanjay Patel   IRBuilder<> Builder(ExtElt);
183*216a37bbSSanjay Patel   Value *Shuf = Builder.CreateShuffleVector(X, UndefValue::get(VecTy), Mask);
184*216a37bbSSanjay Patel   return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
185*216a37bbSSanjay Patel }
186*216a37bbSSanjay Patel 
187fc445589SSanjay Patel /// Try to reduce extract element costs by converting scalar compares to vector
188fc445589SSanjay Patel /// compares followed by extract.
189e9c79a7aSSanjay Patel /// cmp (ext0 V0, C), (ext1 V1, C)
190*216a37bbSSanjay Patel static void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
191039ff29eSSanjay Patel                           Instruction &I) {
192fc445589SSanjay Patel   assert(isa<CmpInst>(&I) && "Expected a compare");
193*216a37bbSSanjay Patel   assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
194*216a37bbSSanjay Patel              cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
195*216a37bbSSanjay Patel          "Expected matching constant extract indexes");
196a17f03bdSSanjay Patel 
197a17f03bdSSanjay Patel   // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
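  // Concrete illustration (hypothetical values, not from a test):
  //   %e0 = extractelement <4 x i32> %a, i32 1
  //   %e1 = extractelement <4 x i32> %b, i32 1
  //   %c  = icmp sgt i32 %e0, %e1
  // becomes
  //   %v = icmp sgt <4 x i32> %a, %b
  //   %c = extractelement <4 x i1> %v, i32 1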
198a17f03bdSSanjay Patel   ++NumVecCmp;
199a17f03bdSSanjay Patel   IRBuilder<> Builder(&I);
200fc445589SSanjay Patel   CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
201*216a37bbSSanjay Patel   Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
20246a285adSSanjay Patel   Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
203*216a37bbSSanjay Patel   Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
204*216a37bbSSanjay Patel   I.replaceAllUsesWith(NewExt);
205a17f03bdSSanjay Patel }
206a17f03bdSSanjay Patel 
20719b62b79SSanjay Patel /// Try to reduce extract element costs by converting scalar binops to vector
20819b62b79SSanjay Patel /// binops followed by extract.
209e9c79a7aSSanjay Patel /// bo (ext0 V0, C), (ext1 V1, C)
210*216a37bbSSanjay Patel static void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
211039ff29eSSanjay Patel                             Instruction &I) {
212fc445589SSanjay Patel   assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
213*216a37bbSSanjay Patel   assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
214*216a37bbSSanjay Patel              cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
215*216a37bbSSanjay Patel          "Expected matching constant extract indexes");
21619b62b79SSanjay Patel 
21734e34855SSanjay Patel   // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
21819b62b79SSanjay Patel   ++NumVecBO;
21919b62b79SSanjay Patel   IRBuilder<> Builder(&I);
220*216a37bbSSanjay Patel   Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
221e9c79a7aSSanjay Patel   Value *VecBO =
22234e34855SSanjay Patel       Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);
223e9c79a7aSSanjay Patel 
22419b62b79SSanjay Patel   // All IR flags are safe to back-propagate because any potential poison
22519b62b79SSanjay Patel   // created in unused vector elements is discarded by the extract.
226e9c79a7aSSanjay Patel   if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
22719b62b79SSanjay Patel     VecBOInst->copyIRFlags(&I);
228e9c79a7aSSanjay Patel 
229*216a37bbSSanjay Patel   Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
230*216a37bbSSanjay Patel   I.replaceAllUsesWith(NewExt);
23119b62b79SSanjay Patel }
23219b62b79SSanjay Patel 
233fc445589SSanjay Patel /// Match an instruction with extracted vector operands.
234fc445589SSanjay Patel static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
235e9c79a7aSSanjay Patel   // It is not safe to transform things like div, urem, etc. because we may
236e9c79a7aSSanjay Patel   // create undefined behavior when executing those on unknown vector elements.
237e9c79a7aSSanjay Patel   if (!isSafeToSpeculativelyExecute(&I))
238e9c79a7aSSanjay Patel     return false;
239e9c79a7aSSanjay Patel 
240*216a37bbSSanjay Patel   Instruction *I0, *I1;
241fc445589SSanjay Patel   CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
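  // Note: Pred doubles as a marker; it stays BAD_ICMP_PREDICATE for the binop
  // form and is checked below to pick between the cmp and binop folds.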
242*216a37bbSSanjay Patel   if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
243*216a37bbSSanjay Patel       !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
244fc445589SSanjay Patel     return false;
245fc445589SSanjay Patel 
246fc445589SSanjay Patel   Value *V0, *V1;
247fc445589SSanjay Patel   uint64_t C0, C1;
248*216a37bbSSanjay Patel   if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
249*216a37bbSSanjay Patel       !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
250fc445589SSanjay Patel       V0->getType() != V1->getType())
251fc445589SSanjay Patel     return false;
252fc445589SSanjay Patel 
253ce97ce3aSSanjay Patel   // If the scalar value 'I' is going to be re-inserted into a vector, then try
254ce97ce3aSSanjay Patel   // to create an extract to that same element. The extract/insert can be
255ce97ce3aSSanjay Patel   // reduced to a "select shuffle".
256ce97ce3aSSanjay Patel   // TODO: If we add a larger pattern match that starts from an insert, this
257ce97ce3aSSanjay Patel   //       probably becomes unnecessary.
258*216a37bbSSanjay Patel   auto *Ext0 = cast<ExtractElementInst>(I0);
259*216a37bbSSanjay Patel   auto *Ext1 = cast<ExtractElementInst>(I1);
260ce97ce3aSSanjay Patel   uint64_t InsertIndex = std::numeric_limits<uint64_t>::max();
261ce97ce3aSSanjay Patel   if (I.hasOneUse())
2627eed772aSSanjay Patel     match(I.user_back(),
2637eed772aSSanjay Patel           m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));
264ce97ce3aSSanjay Patel 
265*216a37bbSSanjay Patel   ExtractElementInst *ExtractToChange;
266*216a37bbSSanjay Patel   if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI, ExtractToChange,
267ce97ce3aSSanjay Patel                             InsertIndex))
268fc445589SSanjay Patel     return false;
269e9c79a7aSSanjay Patel 
270*216a37bbSSanjay Patel   if (ExtractToChange) {
271*216a37bbSSanjay Patel     unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
272*216a37bbSSanjay Patel     ExtractElementInst *NewExtract =
273*216a37bbSSanjay Patel         translateExtract(ExtractToChange, CheapExtractIdx);
274*216a37bbSSanjay Patel     if (!NewExtract)
2756d864097SSanjay Patel       return false;
276*216a37bbSSanjay Patel     if (ExtractToChange == Ext0)
277*216a37bbSSanjay Patel       Ext0 = NewExtract;
278a69158c1SSanjay Patel     else
279*216a37bbSSanjay Patel       Ext1 = NewExtract;
280a69158c1SSanjay Patel   }
281e9c79a7aSSanjay Patel 
282e9c79a7aSSanjay Patel   if (Pred != CmpInst::BAD_ICMP_PREDICATE)
283039ff29eSSanjay Patel     foldExtExtCmp(Ext0, Ext1, I);
284e9c79a7aSSanjay Patel   else
285039ff29eSSanjay Patel     foldExtExtBinop(Ext0, Ext1, I);
286e9c79a7aSSanjay Patel 
287e9c79a7aSSanjay Patel   return true;
288fc445589SSanjay Patel }
289fc445589SSanjay Patel 
290bef6e67eSSanjay Patel /// If this is a bitcast of a shuffle, try to bitcast the source vector to the
291bef6e67eSSanjay Patel /// destination type followed by shuffle. This can enable further transforms by
292bef6e67eSSanjay Patel /// moving bitcasts or shuffles together.
293b6050ca1SSanjay Patel static bool foldBitcastShuf(Instruction &I, const TargetTransformInfo &TTI) {
294b6050ca1SSanjay Patel   Value *V;
295b6050ca1SSanjay Patel   ArrayRef<int> Mask;
2967eed772aSSanjay Patel   if (!match(&I, m_BitCast(
2977eed772aSSanjay Patel                      m_OneUse(m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))))))
298b6050ca1SSanjay Patel     return false;
299b6050ca1SSanjay Patel 
300bef6e67eSSanjay Patel   // Disallow non-vector casts and length-changing shuffles.
301bef6e67eSSanjay Patel   // TODO: We could allow any shuffle.
3023297e9b7SChristopher Tetreault   auto *DestTy = dyn_cast<VectorType>(I.getType());
3033297e9b7SChristopher Tetreault   auto *SrcTy = cast<VectorType>(V->getType());
3043297e9b7SChristopher Tetreault   if (!DestTy || I.getOperand(0)->getType() != SrcTy)
305b6050ca1SSanjay Patel     return false;
306b6050ca1SSanjay Patel 
307b6050ca1SSanjay Patel   // The new shuffle must not cost more than the old shuffle. The bitcast is
308b6050ca1SSanjay Patel   // moved ahead of the shuffle, so assume that it has the same cost as before.
309b6050ca1SSanjay Patel   if (TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, DestTy) >
310b6050ca1SSanjay Patel       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy))
311b6050ca1SSanjay Patel     return false;
312b6050ca1SSanjay Patel 
313bef6e67eSSanjay Patel   unsigned DestNumElts = DestTy->getNumElements();
314bef6e67eSSanjay Patel   unsigned SrcNumElts = SrcTy->getNumElements();
315b6050ca1SSanjay Patel   SmallVector<int, 16> NewMask;
316bef6e67eSSanjay Patel   if (SrcNumElts <= DestNumElts) {
317bef6e67eSSanjay Patel     // The bitcast is from wide to narrow/equal elements. The shuffle mask can
318bef6e67eSSanjay Patel     // always be expanded to the equivalent form choosing narrower elements.
319b6050ca1SSanjay Patel     assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask");
320b6050ca1SSanjay Patel     unsigned ScaleFactor = DestNumElts / SrcNumElts;
3211318ddbcSSanjay Patel     narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
322bef6e67eSSanjay Patel   } else {
323bef6e67eSSanjay Patel     // The bitcast is from narrow elements to wide elements. The shuffle mask
324bef6e67eSSanjay Patel     // must choose consecutive elements to allow casting first.
325bef6e67eSSanjay Patel     assert(SrcNumElts % DestNumElts == 0 && "Unexpected shuffle mask");
326bef6e67eSSanjay Patel     unsigned ScaleFactor = SrcNumElts / DestNumElts;
327bef6e67eSSanjay Patel     if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
328bef6e67eSSanjay Patel       return false;
329bef6e67eSSanjay Patel   }
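  // For example (assumed types, purely illustrative): a shuffle of <4 x i32>
  // with mask <1, undef, undef, undef> bitcast to <8 x i16> has ScaleFactor 2,
  // so NewMask becomes <2, 3, undef, undef, undef, undef, undef, undef>, i.e.
  // the two narrow lanes corresponding to original lane 1.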
330bef6e67eSSanjay Patel   // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
3317aeb41b3SRoman Lebedev   ++NumShufOfBitcast;
332bef6e67eSSanjay Patel   IRBuilder<> Builder(&I);
333bef6e67eSSanjay Patel   Value *CastV = Builder.CreateBitCast(V, DestTy);
3347eed772aSSanjay Patel   Value *Shuf =
3357eed772aSSanjay Patel       Builder.CreateShuffleVector(CastV, UndefValue::get(DestTy), NewMask);
336b6050ca1SSanjay Patel   I.replaceAllUsesWith(Shuf);
337b6050ca1SSanjay Patel   return true;
338b6050ca1SSanjay Patel }
339b6050ca1SSanjay Patel 
340ed67f5e7SSanjay Patel /// Match a vector binop or compare instruction with at least one inserted
341ed67f5e7SSanjay Patel /// scalar operand and convert to scalar binop/cmp followed by insertelement.
342ed67f5e7SSanjay Patel static bool scalarizeBinopOrCmp(Instruction &I,
343ed67f5e7SSanjay Patel                                 const TargetTransformInfo &TTI) {
344ed67f5e7SSanjay Patel   CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
3455dc4e7c2SSimon Pilgrim   Value *Ins0, *Ins1;
346ed67f5e7SSanjay Patel   if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
347ed67f5e7SSanjay Patel       !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
348ed67f5e7SSanjay Patel     return false;
349ed67f5e7SSanjay Patel 
350ed67f5e7SSanjay Patel   // Do not convert the vector condition of a vector select into a scalar
351ed67f5e7SSanjay Patel   // condition. That may cause problems for codegen because of differences in
352ed67f5e7SSanjay Patel   // boolean formats and register-file transfers.
353ed67f5e7SSanjay Patel   // TODO: Can we account for that in the cost model?
354ed67f5e7SSanjay Patel   bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE;
355ed67f5e7SSanjay Patel   if (IsCmp)
356ed67f5e7SSanjay Patel     for (User *U : I.users())
357ed67f5e7SSanjay Patel       if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
3580d2a0b44SSanjay Patel         return false;
3590d2a0b44SSanjay Patel 
3605dc4e7c2SSimon Pilgrim   // Match against one or both scalar values being inserted into constant
3615dc4e7c2SSimon Pilgrim   // vectors:
362ed67f5e7SSanjay Patel   // vec_op VecC0, (inselt VecC1, V1, Index)
363ed67f5e7SSanjay Patel   // vec_op (inselt VecC0, V0, Index), VecC1
364ed67f5e7SSanjay Patel   // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
3650d2a0b44SSanjay Patel   // TODO: Deal with mismatched index constants and variable indexes?
3665dc4e7c2SSimon Pilgrim   Constant *VecC0 = nullptr, *VecC1 = nullptr;
3675dc4e7c2SSimon Pilgrim   Value *V0 = nullptr, *V1 = nullptr;
3685dc4e7c2SSimon Pilgrim   uint64_t Index0 = 0, Index1 = 0;
3697eed772aSSanjay Patel   if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
3705dc4e7c2SSimon Pilgrim                                m_ConstantInt(Index0))) &&
3715dc4e7c2SSimon Pilgrim       !match(Ins0, m_Constant(VecC0)))
3725dc4e7c2SSimon Pilgrim     return false;
3735dc4e7c2SSimon Pilgrim   if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
3745dc4e7c2SSimon Pilgrim                                m_ConstantInt(Index1))) &&
3755dc4e7c2SSimon Pilgrim       !match(Ins1, m_Constant(VecC1)))
3760d2a0b44SSanjay Patel     return false;
3770d2a0b44SSanjay Patel 
3785dc4e7c2SSimon Pilgrim   bool IsConst0 = !V0;
3795dc4e7c2SSimon Pilgrim   bool IsConst1 = !V1;
3805dc4e7c2SSimon Pilgrim   if (IsConst0 && IsConst1)
3815dc4e7c2SSimon Pilgrim     return false;
3825dc4e7c2SSimon Pilgrim   if (!IsConst0 && !IsConst1 && Index0 != Index1)
3835dc4e7c2SSimon Pilgrim     return false;
3845dc4e7c2SSimon Pilgrim 
3855dc4e7c2SSimon Pilgrim   // Bail in the single-insertion case if the inserted scalar is a load.
3865dc4e7c2SSimon Pilgrim   // TODO: Handle this once getVectorInstrCost can model load/store costs.
3875dc4e7c2SSimon Pilgrim   auto *I0 = dyn_cast_or_null<Instruction>(V0);
3885dc4e7c2SSimon Pilgrim   auto *I1 = dyn_cast_or_null<Instruction>(V1);
3895dc4e7c2SSimon Pilgrim   if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
3905dc4e7c2SSimon Pilgrim       (IsConst1 && I0 && I0->mayReadFromMemory()))
3915dc4e7c2SSimon Pilgrim     return false;
3925dc4e7c2SSimon Pilgrim 
3935dc4e7c2SSimon Pilgrim   uint64_t Index = IsConst0 ? Index1 : Index0;
3945dc4e7c2SSimon Pilgrim   Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
3950d2a0b44SSanjay Patel   Type *VecTy = I.getType();
3965dc4e7c2SSimon Pilgrim   assert(VecTy->isVectorTy() &&
3975dc4e7c2SSimon Pilgrim          (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
3980d2a0b44SSanjay Patel          (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy()) &&
3990d2a0b44SSanjay Patel          "Unexpected types for insert into binop");
4000d2a0b44SSanjay Patel 
401ed67f5e7SSanjay Patel   unsigned Opcode = I.getOpcode();
402ed67f5e7SSanjay Patel   int ScalarOpCost, VectorOpCost;
403ed67f5e7SSanjay Patel   if (IsCmp) {
404ed67f5e7SSanjay Patel     ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy);
405ed67f5e7SSanjay Patel     VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy);
406ed67f5e7SSanjay Patel   } else {
407ed67f5e7SSanjay Patel     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
408ed67f5e7SSanjay Patel     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
409ed67f5e7SSanjay Patel   }
4100d2a0b44SSanjay Patel 
4110d2a0b44SSanjay Patel   // Get cost estimate for the insert element. This cost will factor into
4120d2a0b44SSanjay Patel   // both sequences.
4130d2a0b44SSanjay Patel   int InsertCost =
4140d2a0b44SSanjay Patel       TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, Index);
4155dc4e7c2SSimon Pilgrim   int OldCost = (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) +
4165dc4e7c2SSimon Pilgrim                 VectorOpCost;
4175f730b64SSanjay Patel   int NewCost = ScalarOpCost + InsertCost +
4185dc4e7c2SSimon Pilgrim                 (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
4195dc4e7c2SSimon Pilgrim                 (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);
4200d2a0b44SSanjay Patel 
4210d2a0b44SSanjay Patel   // We want to scalarize unless the vector variant actually has lower cost.
4220d2a0b44SSanjay Patel   if (OldCost < NewCost)
4230d2a0b44SSanjay Patel     return false;
4240d2a0b44SSanjay Patel 
425ed67f5e7SSanjay Patel   // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
426ed67f5e7SSanjay Patel   // inselt NewVecC, (scalar_op V0, V1), Index
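  // Illustrative IR (hypothetical values):
  //   %i = insertelement <2 x i32> <i32 1, i32 2>, i32 %x, i32 0
  //   %r = add <2 x i32> %i, <i32 7, i32 7>
  // becomes
  //   %s = add i32 %x, 7
  //   %r = insertelement <2 x i32> <i32 8, i32 9>, i32 %s, i32 0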
427ed67f5e7SSanjay Patel   if (IsCmp)
428ed67f5e7SSanjay Patel     ++NumScalarCmp;
429ed67f5e7SSanjay Patel   else
4300d2a0b44SSanjay Patel     ++NumScalarBO;
4315dc4e7c2SSimon Pilgrim 
4325dc4e7c2SSimon Pilgrim   // For constant cases, extract the scalar element; this should constant fold.
433ed67f5e7SSanjay Patel   IRBuilder<> Builder(&I);
4345dc4e7c2SSimon Pilgrim   if (IsConst0)
4355dc4e7c2SSimon Pilgrim     V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
4365dc4e7c2SSimon Pilgrim   if (IsConst1)
4375dc4e7c2SSimon Pilgrim     V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
4385dc4e7c2SSimon Pilgrim 
439ed67f5e7SSanjay Patel   Value *Scalar =
44046a285adSSanjay Patel       IsCmp ? Builder.CreateCmp(Pred, V0, V1)
441ed67f5e7SSanjay Patel             : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);
442ed67f5e7SSanjay Patel 
443ed67f5e7SSanjay Patel   Scalar->setName(I.getName() + ".scalar");
4440d2a0b44SSanjay Patel 
4450d2a0b44SSanjay Patel   // All IR flags are safe to back-propagate. There is no potential for extra
4460d2a0b44SSanjay Patel   // poison to be created by the scalar instruction.
4470d2a0b44SSanjay Patel   if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
4480d2a0b44SSanjay Patel     ScalarInst->copyIRFlags(&I);
4490d2a0b44SSanjay Patel 
4500d2a0b44SSanjay Patel   // Fold the vector constants in the original vectors into a new base vector.
451ed67f5e7SSanjay Patel   Constant *NewVecC = IsCmp ? ConstantExpr::getCompare(Pred, VecC0, VecC1)
452ed67f5e7SSanjay Patel                             : ConstantExpr::get(Opcode, VecC0, VecC1);
4530d2a0b44SSanjay Patel   Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
4540d2a0b44SSanjay Patel   I.replaceAllUsesWith(Insert);
4550d2a0b44SSanjay Patel   Insert->takeName(&I);
4560d2a0b44SSanjay Patel   return true;
4570d2a0b44SSanjay Patel }
4580d2a0b44SSanjay Patel 
459a17f03bdSSanjay Patel /// This is the entry point for all transforms. Pass manager differences are
460a17f03bdSSanjay Patel /// handled in the callers of this function.
461a17f03bdSSanjay Patel static bool runImpl(Function &F, const TargetTransformInfo &TTI,
462a17f03bdSSanjay Patel                     const DominatorTree &DT) {
46325c6544fSSanjay Patel   if (DisableVectorCombine)
46425c6544fSSanjay Patel     return false;
46525c6544fSSanjay Patel 
466a17f03bdSSanjay Patel   bool MadeChange = false;
467a17f03bdSSanjay Patel   for (BasicBlock &BB : F) {
468a17f03bdSSanjay Patel     // Ignore unreachable basic blocks.
469a17f03bdSSanjay Patel     if (!DT.isReachableFromEntry(&BB))
470a17f03bdSSanjay Patel       continue;
471a17f03bdSSanjay Patel     // Do not delete instructions here; that would invalidate the iterator.
47281e9ede3SSanjay Patel     // Walk the block forwards to enable simple iterative chains of transforms.
473a17f03bdSSanjay Patel     // TODO: It could be more efficient to remove dead instructions
474a17f03bdSSanjay Patel     //       iteratively in this loop rather than waiting until the end.
47581e9ede3SSanjay Patel     for (Instruction &I : BB) {
476fc3cc8a4SSanjay Patel       if (isa<DbgInfoIntrinsic>(I))
477fc3cc8a4SSanjay Patel         continue;
478fc445589SSanjay Patel       MadeChange |= foldExtractExtract(I, TTI);
479b6050ca1SSanjay Patel       MadeChange |= foldBitcastShuf(I, TTI);
480ed67f5e7SSanjay Patel       MadeChange |= scalarizeBinopOrCmp(I, TTI);
481a17f03bdSSanjay Patel     }
482fc3cc8a4SSanjay Patel   }
483a17f03bdSSanjay Patel 
484a17f03bdSSanjay Patel   // We're done with transforms, so remove dead instructions.
485a17f03bdSSanjay Patel   if (MadeChange)
486a17f03bdSSanjay Patel     for (BasicBlock &BB : F)
487a17f03bdSSanjay Patel       SimplifyInstructionsInBlock(&BB);
488a17f03bdSSanjay Patel 
489a17f03bdSSanjay Patel   return MadeChange;
490a17f03bdSSanjay Patel }
491a17f03bdSSanjay Patel 
492a17f03bdSSanjay Patel // Pass manager boilerplate below here.
493a17f03bdSSanjay Patel 
494a17f03bdSSanjay Patel namespace {
495a17f03bdSSanjay Patel class VectorCombineLegacyPass : public FunctionPass {
496a17f03bdSSanjay Patel public:
497a17f03bdSSanjay Patel   static char ID;
498a17f03bdSSanjay Patel   VectorCombineLegacyPass() : FunctionPass(ID) {
499a17f03bdSSanjay Patel     initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
500a17f03bdSSanjay Patel   }
501a17f03bdSSanjay Patel 
502a17f03bdSSanjay Patel   void getAnalysisUsage(AnalysisUsage &AU) const override {
503a17f03bdSSanjay Patel     AU.addRequired<DominatorTreeWrapperPass>();
504a17f03bdSSanjay Patel     AU.addRequired<TargetTransformInfoWrapperPass>();
505a17f03bdSSanjay Patel     AU.setPreservesCFG();
506a17f03bdSSanjay Patel     AU.addPreserved<DominatorTreeWrapperPass>();
507a17f03bdSSanjay Patel     AU.addPreserved<GlobalsAAWrapperPass>();
508024098aeSSanjay Patel     AU.addPreserved<AAResultsWrapperPass>();
509024098aeSSanjay Patel     AU.addPreserved<BasicAAWrapperPass>();
510a17f03bdSSanjay Patel     FunctionPass::getAnalysisUsage(AU);
511a17f03bdSSanjay Patel   }
512a17f03bdSSanjay Patel 
513a17f03bdSSanjay Patel   bool runOnFunction(Function &F) override {
514a17f03bdSSanjay Patel     if (skipFunction(F))
515a17f03bdSSanjay Patel       return false;
516a17f03bdSSanjay Patel     auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
517a17f03bdSSanjay Patel     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
518a17f03bdSSanjay Patel     return runImpl(F, TTI, DT);
519a17f03bdSSanjay Patel   }
520a17f03bdSSanjay Patel };
521a17f03bdSSanjay Patel } // namespace
522a17f03bdSSanjay Patel 
523a17f03bdSSanjay Patel char VectorCombineLegacyPass::ID = 0;
524a17f03bdSSanjay Patel INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
525a17f03bdSSanjay Patel                       "Optimize scalar/vector ops", false, false)
527a17f03bdSSanjay Patel INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
528a17f03bdSSanjay Patel INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
529a17f03bdSSanjay Patel                     "Optimize scalar/vector ops", false, false)
530a17f03bdSSanjay Patel Pass *llvm::createVectorCombinePass() {
531a17f03bdSSanjay Patel   return new VectorCombineLegacyPass();
532a17f03bdSSanjay Patel }
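// Illustrative usage note (the new-PM pass name is registered outside this
// file, so treat that spelling as an assumption): the legacy pass runs in
// isolation with "opt -vector-combine -S", and the new pass manager form is
// typically "opt -passes=vector-combine -S".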
533a17f03bdSSanjay Patel 
534a17f03bdSSanjay Patel PreservedAnalyses VectorCombinePass::run(Function &F,
535a17f03bdSSanjay Patel                                          FunctionAnalysisManager &FAM) {
536a17f03bdSSanjay Patel   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
537a17f03bdSSanjay Patel   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
538a17f03bdSSanjay Patel   if (!runImpl(F, TTI, DT))
539a17f03bdSSanjay Patel     return PreservedAnalyses::all();
540a17f03bdSSanjay Patel   PreservedAnalyses PA;
541a17f03bdSSanjay Patel   PA.preserveSet<CFGAnalyses>();
542a17f03bdSSanjay Patel   PA.preserve<GlobalsAA>();
543024098aeSSanjay Patel   PA.preserve<AAManager>();
544024098aeSSanjay Patel   PA.preserve<BasicAA>();
545a17f03bdSSanjay Patel   return PA;
546a17f03bdSSanjay Patel }