//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include <numeric>
#include <queue>

#define DEBUG_TYPE "vector-combine"
#include "llvm/Transforms/Utils/InstructionWorklist.h"

using namespace llvm;
using namespace llvm::PatternMatch;

STATISTIC(NumVecLoad, "Number of vector loads formed");
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
STATISTIC(NumScalarBO, "Number of scalar binops formed");
STATISTIC(NumScalarCmp, "Number of scalar compares formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

static cl::opt<unsigned> MaxInstrsToScan(
    "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
    cl::desc("Max number of instructions to scan for vector combining."));

static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();

namespace {
class VectorCombine {
public:
  VectorCombine(Function &F, const TargetTransformInfo &TTI,
                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
                bool TryEarlyFoldsOnly)
      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC),
        TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}

  bool run();

private:
  Function &F;
  IRBuilder<> Builder;
  const TargetTransformInfo &TTI;
  const DominatorTree &DT;
  AAResults &AA;
  AssumptionCache &AC;

  /// If true, only perform beneficial early IR transforms. Do not introduce new
  /// vector operations.
  bool TryEarlyFoldsOnly;

  InstructionWorklist Worklist;

  // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
  // parameter. That should be updated to specific sub-classes because the
  // run loop was changed to dispatch on opcode.
  bool vectorizeLoadInsert(Instruction &I);
  bool widenSubvectorLoad(Instruction &I);
  ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
                                        ExtractElementInst *Ext1,
                                        unsigned PreferredExtractIndex) const;
  bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                             const Instruction &I,
                             ExtractElementInst *&ConvertToShuffle,
                             unsigned PreferredExtractIndex);
  void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                     Instruction &I);
  void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                       Instruction &I);
  bool foldExtractExtract(Instruction &I);
  bool foldInsExtFNeg(Instruction &I);
  bool foldBitcastShuffle(Instruction &I);
  bool scalarizeBinopOrCmp(Instruction &I);
  bool scalarizeVPIntrinsic(Instruction &I);
  bool foldExtractedCmps(Instruction &I);
  bool foldSingleElementStore(Instruction &I);
  bool scalarizeLoadExtract(Instruction &I);
  bool foldShuffleOfBinops(Instruction &I);
  bool foldShuffleFromReductions(Instruction &I);
  bool foldSelectShuffle(Instruction &I, bool FromReduction = false);

  void replaceValue(Value &Old, Value &New) {
    Old.replaceAllUsesWith(&New);
    if (auto *NewI = dyn_cast<Instruction>(&New)) {
      New.takeName(&Old);
      Worklist.pushUsersToWorkList(*NewI);
      Worklist.pushValue(NewI);
    }
    Worklist.pushValue(&Old);
  }

  void eraseInstruction(Instruction &I) {
    for (Value *Op : I.operands())
      Worklist.pushValue(Op);
    Worklist.remove(&I);
    I.eraseFromParent();
  }
};
} // namespace

static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
  // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
  // The widened load may load data from dirty regions or create data races
  // non-existent in the source.
  if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
      Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
      mustSuppressSpeculation(*Load))
    return false;

  // We are potentially transforming byte-sized (8-bit) memory accesses, so make
  // sure we have all of our type-based constraints in place for this target.
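  // For example, with a 128-bit minimum vector register, an i32 load can be
  // widened (128 % 32 == 0 and 32 % 8 == 0), but an i1 load cannot.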
  Type *ScalarTy = Load->getType()->getScalarType();
  uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
  unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
  if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
      ScalarSize % 8 != 0)
    return false;

  return true;
}

bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
  // Match insert into fixed vector of scalar value.
  // TODO: Handle non-zero insert index.
  Value *Scalar;
  if (!match(&I, m_InsertElt(m_Undef(), m_Value(Scalar), m_ZeroInt())) ||
      !Scalar->hasOneUse())
    return false;

  // Optionally match an extract from another vector.
  Value *X;
  bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
  if (!HasExtract)
    X = Scalar;

  auto *Load = dyn_cast<LoadInst>(X);
  if (!canWidenLoad(Load, TTI))
    return false;

  Type *ScalarTy = Scalar->getType();
  uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
  unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();

  // Check safety of replacing the scalar load with a larger vector load.
  // We use minimal alignment (maximum flexibility) because we only care about
  // the dereferenceable region. When calculating cost and creating a new op,
  // we may use a larger value based on alignment attributes.
  const DataLayout &DL = I.getModule()->getDataLayout();
  Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");

  unsigned MinVecNumElts = MinVectorSize / ScalarSize;
  auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
  unsigned OffsetEltIndex = 0;
  Align Alignment = Load->getAlign();
  if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), DL, Load, &AC,
                                   &DT)) {
    // It is not safe to load directly from the pointer, but we can still peek
    // through gep offsets and check if it is safe to load from a base address
    // with updated alignment. If it is, we can shuffle the element(s) into
    // place after loading.
    unsigned OffsetBitWidth = DL.getIndexTypeSizeInBits(SrcPtr->getType());
    APInt Offset(OffsetBitWidth, 0);
    SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(DL, Offset);

    // We want to shuffle the result down from a high element of a vector, so
    // the offset must be positive.
    if (Offset.isNegative())
      return false;

    // The offset must be a multiple of the scalar element to shuffle cleanly
    // in the element's size.
    uint64_t ScalarSizeInBytes = ScalarSize / 8;
    if (Offset.urem(ScalarSizeInBytes) != 0)
      return false;

    // If we load MinVecNumElts, will our target element still be loaded?
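    // For example, an i32 load at a constant byte offset of 8 within a region
    // that can hold a <4 x i32> load gives OffsetEltIndex == 2; lane 2 is
    // shuffled down to lane 0 below.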
    OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
    if (OffsetEltIndex >= MinVecNumElts)
      return false;

    if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), DL, Load, &AC,
                                     &DT))
      return false;

    // Update alignment with offset value. Note that the offset could be negated
    // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
    // negation does not change the result of the alignment calculation.
    Alignment = commonAlignment(Alignment, Offset.getZExtValue());
  }

  // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
  // Use the greater of the alignment on the load or its source pointer.
  Alignment = std::max(SrcPtr->getPointerAlignment(DL), Alignment);
  Type *LoadTy = Load->getType();
  unsigned AS = Load->getPointerAddressSpace();
  InstructionCost OldCost =
      TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS);
  APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  OldCost +=
      TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
                                   /* Insert */ true, HasExtract, CostKind);

  // New pattern: load VecPtr
  InstructionCost NewCost =
      TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS);
  // Optionally, we are shuffling the loaded vector element(s) into place.
  // For the mask, set everything but element 0 to undef to prevent poison from
  // propagating from the extra loaded memory. This will also optionally
  // shrink/grow the vector from the loaded size to the output size.
  // We assume this operation has no cost in codegen if there was no offset.
  // Note that we could use freeze to avoid poison problems, but then we might
  // still need a shuffle to change the vector size.
  auto *Ty = cast<FixedVectorType>(I.getType());
  unsigned OutputNumElts = Ty->getNumElements();
  SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
  assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
  Mask[0] = OffsetEltIndex;
  if (OffsetEltIndex)
    NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, MinVecTy, Mask);

  // We can aggressively convert to the vector form because the backend can
  // invert this transform if it does not result in a performance win.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // It is safe and potentially profitable to load a vector directly:
  // inselt undef, load Scalar, 0 --> load VecPtr
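  // Illustrative IR for a target with 128-bit vectors and no pointer offset:
  //   %s = load float, ptr %p
  //   %i = insertelement <4 x float> undef, float %s, i32 0
  // -->
  //   %w = load <4 x float>, ptr %p
  //   %i = shufflevector <4 x float> %w, <4 x float> poison,
  //                      <4 x i32> <i32 0, i32 poison, i32 poison, i32 poison>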
  IRBuilder<> Builder(Load);
  Value *CastedPtr =
      Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
  Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
  VecLd = Builder.CreateShuffleVector(VecLd, Mask);

  replaceValue(I, *VecLd);
  ++NumVecLoad;
  return true;
}

/// If we are loading a vector and then inserting it into a larger vector with
/// undefined elements, try to load the larger vector and eliminate the insert.
/// This removes a shuffle in IR and may allow combining of other loaded values.
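/// A minimal illustrative example (element counts depend on the matched IR):
///   %v = load <2 x float>, ptr %p
///   %i = shufflevector <2 x float> %v, <2 x float> poison,
///                      <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
/// -->
///   %i = load <4 x float>, ptr %p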
bool VectorCombine::widenSubvectorLoad(Instruction &I) {
  // Match subvector insert of fixed vector.
  auto *Shuf = cast<ShuffleVectorInst>(&I);
  if (!Shuf->isIdentityWithPadding())
    return false;

  // Allow a non-canonical shuffle mask that is choosing elements from op1.
  unsigned NumOpElts =
      cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
  unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) {
    return M >= (int)(NumOpElts);
  });

  auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex));
  if (!canWidenLoad(Load, TTI))
    return false;

  // We use minimal alignment (maximum flexibility) because we only care about
  // the dereferenceable region. When calculating cost and creating a new op,
  // we may use a larger value based on alignment attributes.
  auto *Ty = cast<FixedVectorType>(I.getType());
  const DataLayout &DL = I.getModule()->getDataLayout();
  Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
  Align Alignment = Load->getAlign();
  if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), DL, Load, &AC, &DT))
    return false;

  Alignment = std::max(SrcPtr->getPointerAlignment(DL), Alignment);
  Type *LoadTy = Load->getType();
  unsigned AS = Load->getPointerAddressSpace();

  // Original pattern: insert_subvector (load PtrOp)
  // This conservatively assumes that the cost of a subvector insert into an
  // undef value is 0. We could add that cost if the cost model accurately
  // reflects the real cost of that operation.
  InstructionCost OldCost =
      TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS);

  // New pattern: load PtrOp
  InstructionCost NewCost =
      TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS);

  // We can aggressively convert to the vector form because the backend can
  // invert this transform if it does not result in a performance win.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  IRBuilder<> Builder(Load);
  Value *CastedPtr =
      Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
  Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment);
  replaceValue(I, *VecLd);
  ++NumVecLoad;
  return true;
}

/// Determine which, if any, of the inputs should be replaced by a shuffle
/// followed by extract from a different index.
ExtractElementInst *VectorCombine::getShuffleExtract(
    ExtractElementInst *Ext0, ExtractElementInst *Ext1,
    unsigned PreferredExtractIndex = InvalidIndex) const {
  auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
  auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
  assert(Index0C && Index1C && "Expected constant extract indexes");

  unsigned Index0 = Index0C->getZExtValue();
  unsigned Index1 = Index1C->getZExtValue();

  // If the extract indexes are identical, no shuffle is needed.
  if (Index0 == Index1)
    return nullptr;

  Type *VecTy = Ext0->getVectorOperand()->getType();
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
  InstructionCost Cost0 =
      TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
  InstructionCost Cost1 =
      TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);

  // If both costs are invalid, no shuffle is needed.
  if (!Cost0.isValid() && !Cost1.isValid())
    return nullptr;

  // We are extracting from 2 different indexes, so one operand must be shuffled
  // before performing a vector operation and/or extract. The more expensive
  // extract will be replaced by a shuffle.
  if (Cost0 > Cost1)
    return Ext0;
  if (Cost1 > Cost0)
    return Ext1;

  // If the costs are equal and there is a preferred extract index, shuffle the
  // opposite operand.
  if (PreferredExtractIndex == Index0)
    return Ext1;
  if (PreferredExtractIndex == Index1)
    return Ext0;

  // Otherwise, replace the extract with the higher index.
  return Index0 > Index1 ? Ext0 : Ext1;
}

/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction.
bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
                                          ExtractElementInst *Ext1,
                                          const Instruction &I,
                                          ExtractElementInst *&ConvertToShuffle,
                                          unsigned PreferredExtractIndex) {
  auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getOperand(1));
  auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getOperand(1));
  assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");

  unsigned Opcode = I.getOpcode();
  Type *ScalarTy = Ext0->getType();
  auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
  InstructionCost ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
    ScalarOpCost = TTI.getCmpSelInstrCost(
        Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred);
    VectorOpCost = TTI.getCmpSelInstrCost(
        Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred);
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = Ext0IndexC->getZExtValue();
  unsigned Ext1Index = Ext1IndexC->getZExtValue();
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;

  InstructionCost Extract0Cost =
      TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index);
  InstructionCost Extract1Cost =
      TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  // check the cost of creating a broadcast shuffle and shuffling both
  // operands to element 0.
  InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  InstructionCost OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
  if (ConvertToShuffle) {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // poison except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { poison, poison, 0, poison }
    // TODO: The cost model has an option for a "broadcast" shuffle
    // (splat-from-element-0), but no option for a more general splat.
    NewCost +=
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Create a shuffle that translates (shifts) 1 element from the input vector
/// to a new element location.
static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
                                 unsigned NewIndex, IRBuilder<> &Builder) {
  // The shuffle mask is poison except for 1 lane that is being translated
  // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
  // ShufMask = { 2, poison, poison, poison }
  auto *VecTy = cast<FixedVectorType>(Vec->getType());
  SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
  ShufMask[NewIndex] = OldIndex;
  return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
}

/// Given an extract element instruction with constant index operand, shuffle
/// the source vector (shift the scalar element) to a NewIndex for extraction.
/// Return null if the input can be constant folded, so that we are not creating
/// unnecessary instructions.
static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
                                            unsigned NewIndex,
                                            IRBuilder<> &Builder) {
  // Shufflevectors can only be created for fixed-width vectors.
  if (!isa<FixedVectorType>(ExtElt->getOperand(0)->getType()))
    return nullptr;

  // If the extract can be constant-folded, this code is unsimplified. Defer
  // to other passes to handle that.
  Value *X = ExtElt->getVectorOperand();
  Value *C = ExtElt->getIndexOperand();
  assert(isa<ConstantInt>(C) && "Expected a constant index operand");
  if (isa<Constant>(X))
    return nullptr;

  Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
                                   NewIndex, Builder);
  return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
                                  ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<CmpInst>(&I) && "Expected a compare");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
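  // Illustrative IR, with both extracts at index 1:
  //   %e0 = extractelement <4 x i32> %x, i32 1
  //   %e1 = extractelement <4 x i32> %y, i32 1
  //   %c = icmp sgt i32 %e0, %e1
  // -->
  //   %vc = icmp sgt <4 x i32> %x, %y
  //   %c = extractelement <4 x i1> %vc, i32 1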
  ++NumVecCmp;
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
  Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
                                    ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Match an instruction with extracted vector operands.
bool VectorCombine::foldExtractExtract(Instruction &I) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *I0, *I1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
      !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
      !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  // probably becomes unnecessary.
  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  uint64_t InsertIndex = InvalidIndex;
  if (I.hasOneUse())
    match(I.user_back(),
          m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));

  ExtractElementInst *ExtractToChange;
  if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
    return false;

  if (ExtractToChange) {
    unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
    ExtractElementInst *NewExtract =
        translateExtract(ExtractToChange, CheapExtractIdx, Builder);
    if (!NewExtract)
      return false;
    if (ExtractToChange == Ext0)
      Ext0 = NewExtract;
    else
      Ext1 = NewExtract;
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I);
  else
    foldExtExtBinop(Ext0, Ext1, I);

  Worklist.push(Ext0);
  Worklist.push(Ext1);
  return true;
}

/// Try to replace an extract + scalar fneg + insert with a vector fneg +
/// shuffle.
bool VectorCombine::foldInsExtFNeg(Instruction &I) {
  // Match an insert (op (extract)) pattern.
  Value *DestVec;
  uint64_t Index;
  Instruction *FNeg;
  if (!match(&I, m_InsertElt(m_Value(DestVec), m_OneUse(m_Instruction(FNeg)),
                             m_ConstantInt(Index))))
    return false;

  // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
  Value *SrcVec;
  Instruction *Extract;
  if (!match(FNeg, m_FNeg(m_CombineAnd(
                       m_Instruction(Extract),
                       m_ExtractElt(m_Value(SrcVec), m_SpecificInt(Index))))))
    return false;

  // TODO: We could handle this with a length-changing shuffle.
  auto *VecTy = cast<FixedVectorType>(I.getType());
  if (SrcVec->getType() != VecTy)
    return false;

  // Ignore bogus insert/extract index.
  unsigned NumElts = VecTy->getNumElements();
  if (Index >= NumElts)
    return false;

  // We are inserting the negated element into the same lane that we extracted
  // from. This is equivalent to a select-shuffle that chooses all but the
  // negated element from the destination vector.
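  // For example, with NumElts == 4 and Index == 1 the mask is {0, 5, 2, 3}:
  // lane 1 is taken from the negated source vector and all other lanes are
  // taken from DestVec.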
  SmallVector<int> Mask(NumElts);
  std::iota(Mask.begin(), Mask.end(), 0);
  Mask[Index] = Index + NumElts;

  Type *ScalarTy = VecTy->getScalarType();
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  InstructionCost OldCost =
      TTI.getArithmeticInstrCost(Instruction::FNeg, ScalarTy) +
      TTI.getVectorInstrCost(I, VecTy, CostKind, Index);

  // If the extract has one use, it will be eliminated, so count it in the
  // original cost. If it has more than one use, ignore the cost because it will
  // be the same before/after.
  if (Extract->hasOneUse())
    OldCost += TTI.getVectorInstrCost(*Extract, VecTy, CostKind, Index);

  InstructionCost NewCost =
      TTI.getArithmeticInstrCost(Instruction::FNeg, VecTy) +
      TTI.getShuffleCost(TargetTransformInfo::SK_Select, VecTy, Mask);

  if (NewCost > OldCost)
    return false;

  // insertelt DestVec, (fneg (extractelt SrcVec, Index)), Index -->
  // shuffle DestVec, (fneg SrcVec), Mask
  Value *VecFNeg = Builder.CreateFNegFMF(SrcVec, FNeg);
  Value *Shuf = Builder.CreateShuffleVector(DestVec, VecFNeg, Mask);
  replaceValue(I, *Shuf);
  return true;
}

/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
bool VectorCombine::foldBitcastShuffle(Instruction &I) {
  Value *V;
  ArrayRef<int> Mask;
  if (!match(&I, m_BitCast(
                     m_OneUse(m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))))))
    return false;

  // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
  // scalable type is unknown; Second, we cannot reason if the narrowed shuffle
  // mask for scalable type is a splat or not.
  // 2) Disallow non-vector casts.
  // TODO: We could allow any shuffle.
  auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
  auto *SrcTy = dyn_cast<FixedVectorType>(V->getType());
  if (!DestTy || !SrcTy)
    return false;

  unsigned DestEltSize = DestTy->getScalarSizeInBits();
  unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
  if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
    return false;

  SmallVector<int, 16> NewMask;
  if (DestEltSize <= SrcEltSize) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(SrcEltSize % DestEltSize == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = SrcEltSize / DestEltSize;
    narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(DestEltSize % SrcEltSize == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = DestEltSize / SrcEltSize;
    if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
      return false;
  }

  // Bitcast the shuffle src - keep its original width but use the destination
  // scalar type.
  unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
  auto *ShuffleTy = FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);

  // The new shuffle must not cost more than the old shuffle. The bitcast is
  // moved ahead of the shuffle, so assume that it has the same cost as before.
  InstructionCost DestCost = TTI.getShuffleCost(
      TargetTransformInfo::SK_PermuteSingleSrc, ShuffleTy, NewMask);
  InstructionCost SrcCost =
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy, Mask);
  if (DestCost > SrcCost || !DestCost.isValid())
    return false;

  // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
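  // Illustrative IR for a narrowing bitcast (i64 elements to i32 elements):
  //   %s = shufflevector <2 x i64> %v, <2 x i64> poison, <2 x i32> <i32 1, i32 0>
  //   %b = bitcast <2 x i64> %s to <4 x i32>
  // -->
  //   %c = bitcast <2 x i64> %v to <4 x i32>
  //   %b = shufflevector <4 x i32> %c, <4 x i32> poison,
  //                      <4 x i32> <i32 2, i32 3, i32 0, i32 1>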
  ++NumShufOfBitcast;
  Value *CastV = Builder.CreateBitCast(V, ShuffleTy);
  Value *Shuf = Builder.CreateShuffleVector(CastV, NewMask);
  replaceValue(I, *Shuf);
  return true;
}

/// VP Intrinsics whose vector operands are both splat values may be simplified
/// into the scalar version of the operation and the result splatted. This
/// can lead to scalarization down the line.
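/// A minimal illustrative example, where %xs and %ys are splats of %x and %y
/// and the mask is all-true:
///   %r = call <4 x i32> @llvm.vp.add.v4i32(<4 x i32> %xs, <4 x i32> %ys,
///                                          <4 x i1> %m, i32 %evl)
/// becomes "%s = add i32 %x, %y" followed by a splat of %s into %r.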
scalarizeVPIntrinsic(Instruction & I)747*c9157d92SDimitry Andric bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
748*c9157d92SDimitry Andric if (!isa<VPIntrinsic>(I))
749*c9157d92SDimitry Andric return false;
750*c9157d92SDimitry Andric VPIntrinsic &VPI = cast<VPIntrinsic>(I);
751*c9157d92SDimitry Andric Value *Op0 = VPI.getArgOperand(0);
752*c9157d92SDimitry Andric Value *Op1 = VPI.getArgOperand(1);
753*c9157d92SDimitry Andric
754*c9157d92SDimitry Andric if (!isSplatValue(Op0) || !isSplatValue(Op1))
755*c9157d92SDimitry Andric return false;
756*c9157d92SDimitry Andric
757*c9157d92SDimitry Andric // Check getSplatValue early in this function, to avoid doing unnecessary
758*c9157d92SDimitry Andric // work.
759*c9157d92SDimitry Andric Value *ScalarOp0 = getSplatValue(Op0);
760*c9157d92SDimitry Andric Value *ScalarOp1 = getSplatValue(Op1);
761*c9157d92SDimitry Andric if (!ScalarOp0 || !ScalarOp1)
762*c9157d92SDimitry Andric return false;
763*c9157d92SDimitry Andric
764*c9157d92SDimitry Andric // For the binary VP intrinsics supported here, the result on disabled lanes
765*c9157d92SDimitry Andric // is a poison value. For now, only do this simplification if all lanes
766*c9157d92SDimitry Andric // are active.
767*c9157d92SDimitry Andric // TODO: Relax the condition that all lanes are active by using insertelement
768*c9157d92SDimitry Andric // on inactive lanes.
769*c9157d92SDimitry Andric auto IsAllTrueMask = [](Value *MaskVal) {
770*c9157d92SDimitry Andric if (Value *SplattedVal = getSplatValue(MaskVal))
771*c9157d92SDimitry Andric if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
772*c9157d92SDimitry Andric return ConstValue->isAllOnesValue();
773*c9157d92SDimitry Andric return false;
774*c9157d92SDimitry Andric };
775*c9157d92SDimitry Andric if (!IsAllTrueMask(VPI.getArgOperand(2)))
776*c9157d92SDimitry Andric return false;
777*c9157d92SDimitry Andric
778*c9157d92SDimitry Andric // Check to make sure we support scalarization of the intrinsic
779*c9157d92SDimitry Andric Intrinsic::ID IntrID = VPI.getIntrinsicID();
780*c9157d92SDimitry Andric if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
781*c9157d92SDimitry Andric return false;
782*c9157d92SDimitry Andric
783*c9157d92SDimitry Andric // Calculate cost of splatting both operands into vectors and the vector
784*c9157d92SDimitry Andric // intrinsic
785*c9157d92SDimitry Andric VectorType *VecTy = cast<VectorType>(VPI.getType());
786*c9157d92SDimitry Andric TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
787*c9157d92SDimitry Andric InstructionCost SplatCost =
788*c9157d92SDimitry Andric TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
789*c9157d92SDimitry Andric TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy);
790*c9157d92SDimitry Andric
791*c9157d92SDimitry Andric // Calculate the cost of the VP Intrinsic
792*c9157d92SDimitry Andric SmallVector<Type *, 4> Args;
793*c9157d92SDimitry Andric for (Value *V : VPI.args())
794*c9157d92SDimitry Andric Args.push_back(V->getType());
795*c9157d92SDimitry Andric IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
796*c9157d92SDimitry Andric InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
797*c9157d92SDimitry Andric InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
798*c9157d92SDimitry Andric
799*c9157d92SDimitry Andric // Determine scalar opcode
800*c9157d92SDimitry Andric std::optional<unsigned> FunctionalOpcode =
801*c9157d92SDimitry Andric VPI.getFunctionalOpcode();
802*c9157d92SDimitry Andric std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
803*c9157d92SDimitry Andric if (!FunctionalOpcode) {
804*c9157d92SDimitry Andric ScalarIntrID = VPI.getFunctionalIntrinsicID();
805*c9157d92SDimitry Andric if (!ScalarIntrID)
806*c9157d92SDimitry Andric return false;
807*c9157d92SDimitry Andric }
808*c9157d92SDimitry Andric
809*c9157d92SDimitry Andric // Calculate cost of scalarizing
810*c9157d92SDimitry Andric InstructionCost ScalarOpCost = 0;
811*c9157d92SDimitry Andric if (ScalarIntrID) {
812*c9157d92SDimitry Andric IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
813*c9157d92SDimitry Andric ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
814*c9157d92SDimitry Andric } else {
815*c9157d92SDimitry Andric ScalarOpCost =
816*c9157d92SDimitry Andric TTI.getArithmeticInstrCost(*FunctionalOpcode, VecTy->getScalarType());
817*c9157d92SDimitry Andric }
818*c9157d92SDimitry Andric
819*c9157d92SDimitry Andric // The existing splats may be kept around if other instructions use them.
820*c9157d92SDimitry Andric InstructionCost CostToKeepSplats =
821*c9157d92SDimitry Andric (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
822*c9157d92SDimitry Andric InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
823*c9157d92SDimitry Andric
824*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
825*c9157d92SDimitry Andric << "\n");
826*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
827*c9157d92SDimitry Andric << ", Cost of scalarizing:" << NewCost << "\n");
828*c9157d92SDimitry Andric
829*c9157d92SDimitry Andric // We want to scalarize unless the vector variant actually has lower cost.
830*c9157d92SDimitry Andric if (OldCost < NewCost || !NewCost.isValid())
831*c9157d92SDimitry Andric return false;
832*c9157d92SDimitry Andric
833*c9157d92SDimitry Andric // Scalarize the intrinsic
834*c9157d92SDimitry Andric ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
835*c9157d92SDimitry Andric Value *EVL = VPI.getArgOperand(3);
836*c9157d92SDimitry Andric const DataLayout &DL = VPI.getModule()->getDataLayout();
837*c9157d92SDimitry Andric
838*c9157d92SDimitry Andric // If the VP op might introduce UB or poison, we can scalarize it provided
839*c9157d92SDimitry Andric // that we know the EVL > 0: If the EVL is zero, then the original VP op
840*c9157d92SDimitry Andric // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by
841*c9157d92SDimitry Andric // scalarizing it.
842*c9157d92SDimitry Andric bool SafeToSpeculate;
843*c9157d92SDimitry Andric if (ScalarIntrID)
844*c9157d92SDimitry Andric SafeToSpeculate = Intrinsic::getAttributes(I.getContext(), *ScalarIntrID)
845*c9157d92SDimitry Andric .hasFnAttr(Attribute::AttrKind::Speculatable);
846*c9157d92SDimitry Andric else
847*c9157d92SDimitry Andric SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode(
848*c9157d92SDimitry Andric *FunctionalOpcode, &VPI, nullptr, &AC, &DT);
849*c9157d92SDimitry Andric if (!SafeToSpeculate && !isKnownNonZero(EVL, DL, 0, &AC, &VPI, &DT))
850*c9157d92SDimitry Andric return false;
851*c9157d92SDimitry Andric
852*c9157d92SDimitry Andric Value *ScalarVal =
853*c9157d92SDimitry Andric ScalarIntrID
854*c9157d92SDimitry Andric ? Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID,
855*c9157d92SDimitry Andric {ScalarOp0, ScalarOp1})
856*c9157d92SDimitry Andric : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
857*c9157d92SDimitry Andric ScalarOp0, ScalarOp1);
858*c9157d92SDimitry Andric
859*c9157d92SDimitry Andric replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
860*c9157d92SDimitry Andric return true;
861*c9157d92SDimitry Andric }
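// A minimal sketch of this rewrite, assuming a <4 x i32> vp.add whose operands
// are splats of %a and %b (names and types here are illustrative only, not
// taken from a particular test):
//   %r = call <4 x i32> @llvm.vp.add.v4i32(<4 x i32> %splat.a, <4 x i32> %splat.b,
//                                          <4 x i1> %m, i32 %evl)
// becomes, when the cost model agrees and the op is speculatable or the EVL is
// known non-zero:
//   %scalar = add i32 %a, %b
//   %r = splat of %scalar (insertelement + shufflevector)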
862*c9157d92SDimitry Andric
8635ffd83dbSDimitry Andric /// Match a vector binop or compare instruction with at least one inserted
8645ffd83dbSDimitry Andric /// scalar operand and convert to scalar binop/cmp followed by insertelement.
8655ffd83dbSDimitry Andric bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
8665ffd83dbSDimitry Andric CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
8675ffd83dbSDimitry Andric Value *Ins0, *Ins1;
8685ffd83dbSDimitry Andric if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
8695ffd83dbSDimitry Andric !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
8705ffd83dbSDimitry Andric return false;
8715ffd83dbSDimitry Andric
8725ffd83dbSDimitry Andric // Do not convert the vector condition of a vector select into a scalar
8735ffd83dbSDimitry Andric // condition. That may cause problems for codegen because of differences in
8745ffd83dbSDimitry Andric // boolean formats and register-file transfers.
8755ffd83dbSDimitry Andric // TODO: Can we account for that in the cost model?
8765ffd83dbSDimitry Andric bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE;
8775ffd83dbSDimitry Andric if (IsCmp)
8785ffd83dbSDimitry Andric for (User *U : I.users())
8795ffd83dbSDimitry Andric if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
8805ffd83dbSDimitry Andric return false;
8815ffd83dbSDimitry Andric
8825ffd83dbSDimitry Andric // Match against one or both scalar values being inserted into constant
8835ffd83dbSDimitry Andric // vectors:
8845ffd83dbSDimitry Andric // vec_op VecC0, (inselt VecC1, V1, Index)
8855ffd83dbSDimitry Andric // vec_op (inselt VecC0, V0, Index), VecC1
8865ffd83dbSDimitry Andric // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
8875ffd83dbSDimitry Andric // TODO: Deal with mismatched index constants and variable indexes?
8885ffd83dbSDimitry Andric Constant *VecC0 = nullptr, *VecC1 = nullptr;
8895ffd83dbSDimitry Andric Value *V0 = nullptr, *V1 = nullptr;
8905ffd83dbSDimitry Andric uint64_t Index0 = 0, Index1 = 0;
8915ffd83dbSDimitry Andric if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
8925ffd83dbSDimitry Andric m_ConstantInt(Index0))) &&
8935ffd83dbSDimitry Andric !match(Ins0, m_Constant(VecC0)))
8945ffd83dbSDimitry Andric return false;
8955ffd83dbSDimitry Andric if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
8965ffd83dbSDimitry Andric m_ConstantInt(Index1))) &&
8975ffd83dbSDimitry Andric !match(Ins1, m_Constant(VecC1)))
8985ffd83dbSDimitry Andric return false;
8995ffd83dbSDimitry Andric
9005ffd83dbSDimitry Andric bool IsConst0 = !V0;
9015ffd83dbSDimitry Andric bool IsConst1 = !V1;
9025ffd83dbSDimitry Andric if (IsConst0 && IsConst1)
9035ffd83dbSDimitry Andric return false;
9045ffd83dbSDimitry Andric if (!IsConst0 && !IsConst1 && Index0 != Index1)
9055ffd83dbSDimitry Andric return false;
9065ffd83dbSDimitry Andric
9075ffd83dbSDimitry Andric // Bail for single insertion if it is a load.
9085ffd83dbSDimitry Andric // TODO: Handle this once getVectorInstrCost can cost for load/stores.
9095ffd83dbSDimitry Andric auto *I0 = dyn_cast_or_null<Instruction>(V0);
9105ffd83dbSDimitry Andric auto *I1 = dyn_cast_or_null<Instruction>(V1);
9115ffd83dbSDimitry Andric if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
9125ffd83dbSDimitry Andric (IsConst1 && I0 && I0->mayReadFromMemory()))
9135ffd83dbSDimitry Andric return false;
9145ffd83dbSDimitry Andric
9155ffd83dbSDimitry Andric uint64_t Index = IsConst0 ? Index1 : Index0;
9165ffd83dbSDimitry Andric Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
9175ffd83dbSDimitry Andric Type *VecTy = I.getType();
9185ffd83dbSDimitry Andric assert(VecTy->isVectorTy() &&
9195ffd83dbSDimitry Andric (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
9205ffd83dbSDimitry Andric (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
9215ffd83dbSDimitry Andric ScalarTy->isPointerTy()) &&
9225ffd83dbSDimitry Andric "Unexpected types for insert element into binop or cmp");
9235ffd83dbSDimitry Andric
9245ffd83dbSDimitry Andric unsigned Opcode = I.getOpcode();
925e8d8bef9SDimitry Andric InstructionCost ScalarOpCost, VectorOpCost;
9265ffd83dbSDimitry Andric if (IsCmp) {
927349cc55cSDimitry Andric CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
928349cc55cSDimitry Andric ScalarOpCost = TTI.getCmpSelInstrCost(
929349cc55cSDimitry Andric Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred);
930349cc55cSDimitry Andric VectorOpCost = TTI.getCmpSelInstrCost(
931349cc55cSDimitry Andric Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred);
9325ffd83dbSDimitry Andric } else {
9335ffd83dbSDimitry Andric ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
9345ffd83dbSDimitry Andric VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
9355ffd83dbSDimitry Andric }
9365ffd83dbSDimitry Andric
9375ffd83dbSDimitry Andric // Get cost estimate for the insert element. This cost will factor into
9385ffd83dbSDimitry Andric // both sequences.
939bdd1243dSDimitry Andric TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
940bdd1243dSDimitry Andric InstructionCost InsertCost = TTI.getVectorInstrCost(
941bdd1243dSDimitry Andric Instruction::InsertElement, VecTy, CostKind, Index);
942e8d8bef9SDimitry Andric InstructionCost OldCost =
943e8d8bef9SDimitry Andric (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost;
944e8d8bef9SDimitry Andric InstructionCost NewCost = ScalarOpCost + InsertCost +
9455ffd83dbSDimitry Andric (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
9465ffd83dbSDimitry Andric (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);
9475ffd83dbSDimitry Andric
9485ffd83dbSDimitry Andric // We want to scalarize unless the vector variant actually has lower cost.
949e8d8bef9SDimitry Andric if (OldCost < NewCost || !NewCost.isValid())
9505ffd83dbSDimitry Andric return false;
9515ffd83dbSDimitry Andric
9525ffd83dbSDimitry Andric // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
9535ffd83dbSDimitry Andric // inselt NewVecC, (scalar_op V0, V1), Index
9545ffd83dbSDimitry Andric if (IsCmp)
9555ffd83dbSDimitry Andric ++NumScalarCmp;
9565ffd83dbSDimitry Andric else
9575ffd83dbSDimitry Andric ++NumScalarBO;
9585ffd83dbSDimitry Andric
9595ffd83dbSDimitry Andric // For constant cases, extract the scalar element; this should constant fold.
9605ffd83dbSDimitry Andric if (IsConst0)
9615ffd83dbSDimitry Andric V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
9625ffd83dbSDimitry Andric if (IsConst1)
9635ffd83dbSDimitry Andric V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
9645ffd83dbSDimitry Andric
9655ffd83dbSDimitry Andric Value *Scalar =
9665ffd83dbSDimitry Andric IsCmp ? Builder.CreateCmp(Pred, V0, V1)
9675ffd83dbSDimitry Andric : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);
9685ffd83dbSDimitry Andric
9695ffd83dbSDimitry Andric Scalar->setName(I.getName() + ".scalar");
9705ffd83dbSDimitry Andric
9715ffd83dbSDimitry Andric // All IR flags are safe to back-propagate. There is no potential for extra
9725ffd83dbSDimitry Andric // poison to be created by the scalar instruction.
9735ffd83dbSDimitry Andric if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
9745ffd83dbSDimitry Andric ScalarInst->copyIRFlags(&I);
9755ffd83dbSDimitry Andric
9765ffd83dbSDimitry Andric // Fold the vector constants in the original vectors into a new base vector.
97781ad6265SDimitry Andric Value *NewVecC =
97881ad6265SDimitry Andric IsCmp ? Builder.CreateCmp(Pred, VecC0, VecC1)
97981ad6265SDimitry Andric : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
9805ffd83dbSDimitry Andric Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
9815ffd83dbSDimitry Andric replaceValue(I, *Insert);
9825ffd83dbSDimitry Andric return true;
9835ffd83dbSDimitry Andric }
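// A small worked example of this transform, with assumed operand values:
//   %ins = insertelement <4 x i32> <i32 1, i32 2, i32 3, i32 4>, i32 %x, i32 0
//   %bo  = add <4 x i32> %ins, <i32 10, i32 20, i32 30, i32 40>
// scalarizes, if the cost model agrees, to:
//   %bo.scalar = add i32 %x, 10
//   %bo = insertelement <4 x i32> <i32 11, i32 22, i32 33, i32 44>, i32 %bo.scalar, i32 0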
9845ffd83dbSDimitry Andric
9855ffd83dbSDimitry Andric /// Try to combine a scalar binop + 2 scalar compares of extracted elements of
9865ffd83dbSDimitry Andric /// a vector into vector operations followed by extract. Note: The SLP pass
9875ffd83dbSDimitry Andric /// may miss this pattern because of implementation problems.
9885ffd83dbSDimitry Andric bool VectorCombine::foldExtractedCmps(Instruction &I) {
9895ffd83dbSDimitry Andric // We are looking for a scalar binop of booleans.
9905ffd83dbSDimitry Andric // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
9915ffd83dbSDimitry Andric if (!I.isBinaryOp() || !I.getType()->isIntegerTy(1))
9925ffd83dbSDimitry Andric return false;
9935ffd83dbSDimitry Andric
9945ffd83dbSDimitry Andric // The compare predicates should match, and each compare should have a
9955ffd83dbSDimitry Andric // constant operand.
9965ffd83dbSDimitry Andric // TODO: Relax the one-use constraints.
9975ffd83dbSDimitry Andric Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
9985ffd83dbSDimitry Andric Instruction *I0, *I1;
9995ffd83dbSDimitry Andric Constant *C0, *C1;
10005ffd83dbSDimitry Andric CmpInst::Predicate P0, P1;
10015ffd83dbSDimitry Andric if (!match(B0, m_OneUse(m_Cmp(P0, m_Instruction(I0), m_Constant(C0)))) ||
10025ffd83dbSDimitry Andric !match(B1, m_OneUse(m_Cmp(P1, m_Instruction(I1), m_Constant(C1)))) ||
10035ffd83dbSDimitry Andric P0 != P1)
10045ffd83dbSDimitry Andric return false;
10055ffd83dbSDimitry Andric
10065ffd83dbSDimitry Andric // The compare operands must be extracts of the same vector with constant
10075ffd83dbSDimitry Andric // extract indexes.
10085ffd83dbSDimitry Andric // TODO: Relax the one-use constraints.
10095ffd83dbSDimitry Andric Value *X;
10105ffd83dbSDimitry Andric uint64_t Index0, Index1;
10115ffd83dbSDimitry Andric if (!match(I0, m_OneUse(m_ExtractElt(m_Value(X), m_ConstantInt(Index0)))) ||
10125ffd83dbSDimitry Andric !match(I1, m_OneUse(m_ExtractElt(m_Specific(X), m_ConstantInt(Index1)))))
10135ffd83dbSDimitry Andric return false;
10145ffd83dbSDimitry Andric
10155ffd83dbSDimitry Andric auto *Ext0 = cast<ExtractElementInst>(I0);
10165ffd83dbSDimitry Andric auto *Ext1 = cast<ExtractElementInst>(I1);
10175ffd83dbSDimitry Andric ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1);
10185ffd83dbSDimitry Andric if (!ConvertToShuf)
10195ffd83dbSDimitry Andric return false;
10205ffd83dbSDimitry Andric
10215ffd83dbSDimitry Andric // The original scalar pattern is:
10225ffd83dbSDimitry Andric // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
10235ffd83dbSDimitry Andric CmpInst::Predicate Pred = P0;
10245ffd83dbSDimitry Andric unsigned CmpOpcode = CmpInst::isFPPredicate(Pred) ? Instruction::FCmp
10255ffd83dbSDimitry Andric : Instruction::ICmp;
10265ffd83dbSDimitry Andric auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
10275ffd83dbSDimitry Andric if (!VecTy)
10285ffd83dbSDimitry Andric return false;
10295ffd83dbSDimitry Andric
1030bdd1243dSDimitry Andric TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1031e8d8bef9SDimitry Andric InstructionCost OldCost =
1032bdd1243dSDimitry Andric TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
1033bdd1243dSDimitry Andric OldCost += TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
1034349cc55cSDimitry Andric OldCost +=
1035349cc55cSDimitry Andric TTI.getCmpSelInstrCost(CmpOpcode, I0->getType(),
1036349cc55cSDimitry Andric CmpInst::makeCmpResultType(I0->getType()), Pred) *
1037349cc55cSDimitry Andric 2;
10385ffd83dbSDimitry Andric OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType());
10395ffd83dbSDimitry Andric
10405ffd83dbSDimitry Andric // The proposed vector pattern is:
10415ffd83dbSDimitry Andric // vcmp = cmp Pred X, VecC
10425ffd83dbSDimitry Andric // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
10435ffd83dbSDimitry Andric int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
10445ffd83dbSDimitry Andric int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
10455ffd83dbSDimitry Andric auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
1046349cc55cSDimitry Andric InstructionCost NewCost = TTI.getCmpSelInstrCost(
1047349cc55cSDimitry Andric CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred);
1048fe013be4SDimitry Andric SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
1049fe6060f1SDimitry Andric ShufMask[CheapIndex] = ExpensiveIndex;
1050fe6060f1SDimitry Andric NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,
1051fe6060f1SDimitry Andric ShufMask);
10525ffd83dbSDimitry Andric NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy);
1053bdd1243dSDimitry Andric NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex);
10545ffd83dbSDimitry Andric
10555ffd83dbSDimitry Andric // Aggressively form vector ops if the cost is equal because the transform
10565ffd83dbSDimitry Andric // may enable further optimization.
10575ffd83dbSDimitry Andric // Codegen can reverse this transform (scalarize) if it was not profitable.
1058e8d8bef9SDimitry Andric if (OldCost < NewCost || !NewCost.isValid())
10595ffd83dbSDimitry Andric return false;
10605ffd83dbSDimitry Andric
10615ffd83dbSDimitry Andric // Create a vector constant from the 2 scalar constants.
10625ffd83dbSDimitry Andric SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
1063fe013be4SDimitry Andric PoisonValue::get(VecTy->getElementType()));
10645ffd83dbSDimitry Andric CmpC[Index0] = C0;
10655ffd83dbSDimitry Andric CmpC[Index1] = C1;
10665ffd83dbSDimitry Andric Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));
10675ffd83dbSDimitry Andric
10685ffd83dbSDimitry Andric Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
10695ffd83dbSDimitry Andric Value *VecLogic = Builder.CreateBinOp(cast<BinaryOperator>(I).getOpcode(),
10705ffd83dbSDimitry Andric VCmp, Shuf);
10715ffd83dbSDimitry Andric Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
10725ffd83dbSDimitry Andric replaceValue(I, *NewExt);
10735ffd83dbSDimitry Andric ++NumVecCmpBO;
10745ffd83dbSDimitry Andric return true;
10755ffd83dbSDimitry Andric }
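// An illustrative example of the rewrite (operand names and constants are
// assumed; which of the two lanes survives as the final extract depends on
// which extract getShuffleExtract decides to turn into a shuffle):
//   %e0 = extractelement <4 x i32> %x, i32 0
//   %c0 = icmp sgt i32 %e0, 42
//   %e1 = extractelement <4 x i32> %x, i32 1
//   %c1 = icmp sgt i32 %e1, 7
//   %r  = and i1 %c0, %c1
// -->
//   %vcmp = icmp sgt <4 x i32> %x, <i32 42, i32 7, i32 poison, i32 poison>
//   %shuf = shufflevector <4 x i1> %vcmp, <4 x i1> poison, <4 x i32> <i32 1, i32 poison, i32 poison, i32 poison>
//   %vand = and <4 x i1> %vcmp, %shuf
//   %r    = extractelement <4 x i1> %vand, i32 0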
10765ffd83dbSDimitry Andric
1077fe6060f1SDimitry Andric // Check if the memory location is modified between two instructions in the
// same BB; conservatively reports it as modified once MaxInstrsToScan
// instructions have been inspected.
1078fe6060f1SDimitry Andric static bool isMemModifiedBetween(BasicBlock::iterator Begin,
1079fe6060f1SDimitry Andric BasicBlock::iterator End,
1080fe6060f1SDimitry Andric const MemoryLocation &Loc, AAResults &AA) {
1081fe6060f1SDimitry Andric unsigned NumScanned = 0;
1082fe6060f1SDimitry Andric return std::any_of(Begin, End, [&](const Instruction &Instr) {
1083fe6060f1SDimitry Andric return isModSet(AA.getModRefInfo(&Instr, Loc)) ||
1084fe6060f1SDimitry Andric ++NumScanned > MaxInstrsToScan;
1085fe6060f1SDimitry Andric });
1086fe6060f1SDimitry Andric }
1087fe6060f1SDimitry Andric
1088bdd1243dSDimitry Andric namespace {
1089349cc55cSDimitry Andric /// Helper class to indicate whether a vector index can be safely scalarized and
1090349cc55cSDimitry Andric /// whether a freeze needs to be inserted.
1091349cc55cSDimitry Andric class ScalarizationResult {
1092349cc55cSDimitry Andric enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
1093349cc55cSDimitry Andric
1094349cc55cSDimitry Andric StatusTy Status;
1095349cc55cSDimitry Andric Value *ToFreeze;
1096349cc55cSDimitry Andric
1097349cc55cSDimitry Andric ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
1098349cc55cSDimitry Andric : Status(Status), ToFreeze(ToFreeze) {}
1099349cc55cSDimitry Andric
1100349cc55cSDimitry Andric public:
1101349cc55cSDimitry Andric ScalarizationResult(const ScalarizationResult &Other) = default;
1102349cc55cSDimitry Andric ~ScalarizationResult() {
1103349cc55cSDimitry Andric assert(!ToFreeze && "freeze() not called with ToFreeze being set");
1104349cc55cSDimitry Andric }
1105349cc55cSDimitry Andric
1106349cc55cSDimitry Andric static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
1107349cc55cSDimitry Andric static ScalarizationResult safe() { return {StatusTy::Safe}; }
1108349cc55cSDimitry Andric static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
1109349cc55cSDimitry Andric return {StatusTy::SafeWithFreeze, ToFreeze};
1110349cc55cSDimitry Andric }
1111349cc55cSDimitry Andric
1112349cc55cSDimitry Andric /// Returns true if the index can be scalarized without requiring a freeze.
1113349cc55cSDimitry Andric bool isSafe() const { return Status == StatusTy::Safe; }
1114349cc55cSDimitry Andric /// Returns true if the index cannot be scalarized.
1115349cc55cSDimitry Andric bool isUnsafe() const { return Status == StatusTy::Unsafe; }
1116349cc55cSDimitry Andric /// Returns true if the index can be scalarized, but requires inserting a
1117349cc55cSDimitry Andric /// freeze.
1118349cc55cSDimitry Andric bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
1119349cc55cSDimitry Andric
1120349cc55cSDimitry Andric /// Reset the state to Unsafe and clear ToFreeze if set.
1121349cc55cSDimitry Andric void discard() {
1122349cc55cSDimitry Andric ToFreeze = nullptr;
1123349cc55cSDimitry Andric Status = StatusTy::Unsafe;
1124349cc55cSDimitry Andric }
1125349cc55cSDimitry Andric
1126349cc55cSDimitry Andric /// Freeze ToFreeze and update the use in \p UserI to use it.
1127349cc55cSDimitry Andric void freeze(IRBuilder<> &Builder, Instruction &UserI) {
1128349cc55cSDimitry Andric assert(isSafeWithFreeze() &&
1129349cc55cSDimitry Andric "should only be used when freezing is required");
1130349cc55cSDimitry Andric assert(is_contained(ToFreeze->users(), &UserI) &&
1131349cc55cSDimitry Andric "UserI must be a user of ToFreeze");
1132349cc55cSDimitry Andric IRBuilder<>::InsertPointGuard Guard(Builder);
1133349cc55cSDimitry Andric Builder.SetInsertPoint(cast<Instruction>(&UserI));
1134349cc55cSDimitry Andric Value *Frozen =
1135349cc55cSDimitry Andric Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
1136349cc55cSDimitry Andric for (Use &U : make_early_inc_range((UserI.operands())))
1137349cc55cSDimitry Andric if (U.get() == ToFreeze)
1138349cc55cSDimitry Andric U.set(Frozen);
1139349cc55cSDimitry Andric
1140349cc55cSDimitry Andric ToFreeze = nullptr;
1141349cc55cSDimitry Andric }
1142349cc55cSDimitry Andric };
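// A sketch of the protocol the callers below follow (this mirrors the code in
// foldSingleElementStore and scalarizeLoadExtract; it is not additional API):
//   auto Res = canScalarizeAccess(VecTy, Idx, CtxI, AC, DT);
//   if (Res.isUnsafe())
//     return false;                        // index may be out of range
//   if (Res.isSafeWithFreeze())
//     Res.freeze(Builder, *UserInst);      // or Res.discard() when bailing out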
1143bdd1243dSDimitry Andric } // namespace
1144349cc55cSDimitry Andric
1145fe6060f1SDimitry Andric /// Check if it is legal to scalarize a memory access to \p VecTy at index \p
1146fe6060f1SDimitry Andric /// Idx. \p Idx must access a valid vector element.
1147*c9157d92SDimitry Andric static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
1148*c9157d92SDimitry Andric Instruction *CtxI,
1149349cc55cSDimitry Andric AssumptionCache &AC,
1150349cc55cSDimitry Andric const DominatorTree &DT) {
1151*c9157d92SDimitry Andric // We do checks for both fixed vector types and scalable vector types.
1152*c9157d92SDimitry Andric // This is the number of elements of fixed vector types,
1153*c9157d92SDimitry Andric // or the minimum number of elements of scalable vector types.
1154*c9157d92SDimitry Andric uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();
1155*c9157d92SDimitry Andric
1156349cc55cSDimitry Andric if (auto *C = dyn_cast<ConstantInt>(Idx)) {
1157*c9157d92SDimitry Andric if (C->getValue().ult(NumElements))
1158349cc55cSDimitry Andric return ScalarizationResult::safe();
1159349cc55cSDimitry Andric return ScalarizationResult::unsafe();
1160349cc55cSDimitry Andric }
1161fe6060f1SDimitry Andric
1162349cc55cSDimitry Andric unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
1163349cc55cSDimitry Andric APInt Zero(IntWidth, 0);
1164*c9157d92SDimitry Andric APInt MaxElts(IntWidth, NumElements);
1165fe6060f1SDimitry Andric ConstantRange ValidIndices(Zero, MaxElts);
1166349cc55cSDimitry Andric ConstantRange IdxRange(IntWidth, true);
1167349cc55cSDimitry Andric
1168349cc55cSDimitry Andric if (isGuaranteedNotToBePoison(Idx, &AC)) {
116904eeddc0SDimitry Andric if (ValidIndices.contains(computeConstantRange(Idx, /* ForSigned */ false,
117004eeddc0SDimitry Andric true, &AC, CtxI, &DT)))
1171349cc55cSDimitry Andric return ScalarizationResult::safe();
1172349cc55cSDimitry Andric return ScalarizationResult::unsafe();
1173349cc55cSDimitry Andric }
1174349cc55cSDimitry Andric
1175349cc55cSDimitry Andric // If the index may be poison, check if we can insert a freeze on the base
1176349cc55cSDimitry Andric // value before the operation that restricts the index's range (and / urem).
1177349cc55cSDimitry Andric Value *IdxBase;
1178349cc55cSDimitry Andric ConstantInt *CI;
1179349cc55cSDimitry Andric if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
1180349cc55cSDimitry Andric IdxRange = IdxRange.binaryAnd(CI->getValue());
1181349cc55cSDimitry Andric } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
1182349cc55cSDimitry Andric IdxRange = IdxRange.urem(CI->getValue());
1183349cc55cSDimitry Andric }
1184349cc55cSDimitry Andric
1185349cc55cSDimitry Andric if (ValidIndices.contains(IdxRange))
1186349cc55cSDimitry Andric return ScalarizationResult::safeWithFreeze(IdxBase);
1187349cc55cSDimitry Andric return ScalarizationResult::unsafe();
1188fe6060f1SDimitry Andric }
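// Some worked examples of the checks above, assuming a <4 x i32> access and
// that nothing else is known about %i:
//   Idx = i64 2                  --> safe (constant in range)
//   Idx = i64 7                  --> unsafe (constant out of range)
//   Idx = and i64 %i, 3          --> safe with a freeze of %i (range [0,4))
//   Idx = urem i64 %i, 4         --> safe with a freeze of %i (range [0,4))
//   Idx = %i, no known range     --> unsafe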
1189fe6060f1SDimitry Andric
1190fe6060f1SDimitry Andric /// The memory operation on a vector of \p ScalarType had alignment of
1191fe6060f1SDimitry Andric /// \p VectorAlignment. Compute the maximal, but conservatively correct,
1192fe6060f1SDimitry Andric /// alignment that will be valid for the memory operation on a single scalar
1193fe6060f1SDimitry Andric /// element of the same type with index \p Idx.
1194fe6060f1SDimitry Andric static Align computeAlignmentAfterScalarization(Align VectorAlignment,
1195fe6060f1SDimitry Andric Type *ScalarType, Value *Idx,
1196fe6060f1SDimitry Andric const DataLayout &DL) {
1197fe6060f1SDimitry Andric if (auto *C = dyn_cast<ConstantInt>(Idx))
1198fe6060f1SDimitry Andric return commonAlignment(VectorAlignment,
1199fe6060f1SDimitry Andric C->getZExtValue() * DL.getTypeStoreSize(ScalarType));
1200fe6060f1SDimitry Andric return commonAlignment(VectorAlignment, DL.getTypeStoreSize(ScalarType));
1201fe6060f1SDimitry Andric }
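// Worked example with assumed values: for a <4 x i32> access with vector
// alignment 16, a constant index 0 keeps alignment 16, index 1 gives
// commonAlignment(16, 1 * 4) = 4, and index 2 gives commonAlignment(16, 8) = 8.
// With a variable index only the element size can be assumed, giving
// commonAlignment(16, 4) = 4.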
1202fe6060f1SDimitry Andric
1203fe6060f1SDimitry Andric // Combine patterns like:
1204fe6060f1SDimitry Andric // %0 = load <4 x i32>, <4 x i32>* %a
1205fe6060f1SDimitry Andric // %1 = insertelement <4 x i32> %0, i32 %b, i32 1
1206fe6060f1SDimitry Andric // store <4 x i32> %1, <4 x i32>* %a
1207fe6060f1SDimitry Andric // to:
1208fe6060f1SDimitry Andric // %0 = getelementptr inbounds <4 x i32>, <4 x i32>* %a, i64 0, i64 1
1209fe6060f1SDimitry Andric // store i32 %b, i32* %0
1211fe6060f1SDimitry Andric bool VectorCombine::foldSingleElementStore(Instruction &I) {
1212bdd1243dSDimitry Andric auto *SI = cast<StoreInst>(&I);
1213*c9157d92SDimitry Andric if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType()))
1214fe6060f1SDimitry Andric return false;
1215fe6060f1SDimitry Andric
1216fe6060f1SDimitry Andric // TODO: Combine more complicated patterns (multiple insert) by referencing
1217fe6060f1SDimitry Andric // TargetTransformInfo.
1218fe6060f1SDimitry Andric Instruction *Source;
1219fe6060f1SDimitry Andric Value *NewElement;
1220fe6060f1SDimitry Andric Value *Idx;
1221fe6060f1SDimitry Andric if (!match(SI->getValueOperand(),
1222fe6060f1SDimitry Andric m_InsertElt(m_Instruction(Source), m_Value(NewElement),
1223fe6060f1SDimitry Andric m_Value(Idx))))
1224fe6060f1SDimitry Andric return false;
1225fe6060f1SDimitry Andric
1226fe6060f1SDimitry Andric if (auto *Load = dyn_cast<LoadInst>(Source)) {
1227*c9157d92SDimitry Andric auto VecTy = cast<VectorType>(SI->getValueOperand()->getType());
1228fe6060f1SDimitry Andric const DataLayout &DL = I.getModule()->getDataLayout();
1229fe6060f1SDimitry Andric Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
1230fe6060f1SDimitry Andric // Don't optimize for atomic/volatile loads or stores. Ensure memory is not
1231fe6060f1SDimitry Andric // modified in between, the type size matches the store size, and the index is in bounds.
1232fe6060f1SDimitry Andric if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
1233*c9157d92SDimitry Andric !DL.typeSizeEqualsStoreSize(Load->getType()->getScalarType()) ||
1234349cc55cSDimitry Andric SrcAddr != SI->getPointerOperand()->stripPointerCasts())
1235349cc55cSDimitry Andric return false;
1236349cc55cSDimitry Andric
1237349cc55cSDimitry Andric auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT);
1238349cc55cSDimitry Andric if (ScalarizableIdx.isUnsafe() ||
1239fe6060f1SDimitry Andric isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
1240fe6060f1SDimitry Andric MemoryLocation::get(SI), AA))
1241fe6060f1SDimitry Andric return false;
1242fe6060f1SDimitry Andric
1243349cc55cSDimitry Andric if (ScalarizableIdx.isSafeWithFreeze())
1244349cc55cSDimitry Andric ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
1245fe6060f1SDimitry Andric Value *GEP = Builder.CreateInBoundsGEP(
1246fe6060f1SDimitry Andric SI->getValueOperand()->getType(), SI->getPointerOperand(),
1247fe6060f1SDimitry Andric {ConstantInt::get(Idx->getType(), 0), Idx});
1248fe6060f1SDimitry Andric StoreInst *NSI = Builder.CreateStore(NewElement, GEP);
1249fe6060f1SDimitry Andric NSI->copyMetadata(*SI);
1250fe6060f1SDimitry Andric Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1251fe6060f1SDimitry Andric std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx,
1252fe6060f1SDimitry Andric DL);
1253fe6060f1SDimitry Andric NSI->setAlignment(ScalarOpAlignment);
1254fe6060f1SDimitry Andric replaceValue(I, *NSI);
1255349cc55cSDimitry Andric eraseInstruction(I);
1256fe6060f1SDimitry Andric return true;
1257fe6060f1SDimitry Andric }
1258fe6060f1SDimitry Andric
1259fe6060f1SDimitry Andric return false;
1260fe6060f1SDimitry Andric }
1261fe6060f1SDimitry Andric
1262fe6060f1SDimitry Andric /// Try to scalarize vector loads feeding extractelement instructions.
1263fe6060f1SDimitry Andric bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
1264fe6060f1SDimitry Andric Value *Ptr;
1265349cc55cSDimitry Andric if (!match(&I, m_Load(m_Value(Ptr))))
1266fe6060f1SDimitry Andric return false;
1267fe6060f1SDimitry Andric
1268*c9157d92SDimitry Andric auto *VecTy = cast<VectorType>(I.getType());
1269349cc55cSDimitry Andric auto *LI = cast<LoadInst>(&I);
1270fe6060f1SDimitry Andric const DataLayout &DL = I.getModule()->getDataLayout();
1271*c9157d92SDimitry Andric if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(VecTy->getScalarType()))
1272fe6060f1SDimitry Andric return false;
1273fe6060f1SDimitry Andric
12740eae32dcSDimitry Andric InstructionCost OriginalCost =
1275*c9157d92SDimitry Andric TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
1276fe6060f1SDimitry Andric LI->getPointerAddressSpace());
1277fe6060f1SDimitry Andric InstructionCost ScalarizedCost = 0;
1278fe6060f1SDimitry Andric
1279fe6060f1SDimitry Andric Instruction *LastCheckedInst = LI;
1280fe6060f1SDimitry Andric unsigned NumInstChecked = 0;
1281*c9157d92SDimitry Andric DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
1282*c9157d92SDimitry Andric auto FailureGuard = make_scope_exit([&]() {
1283*c9157d92SDimitry Andric // If the transform is aborted, discard the ScalarizationResults.
1284*c9157d92SDimitry Andric for (auto &Pair : NeedFreeze)
1285*c9157d92SDimitry Andric Pair.second.discard();
1286*c9157d92SDimitry Andric });
1287*c9157d92SDimitry Andric
1288fe6060f1SDimitry Andric // Check if all users of the load are extracts with no memory modifications
1289fe6060f1SDimitry Andric // between the load and the extract. Compute the cost of both the original
1290fe6060f1SDimitry Andric // code and the scalarized version.
1291fe6060f1SDimitry Andric for (User *U : LI->users()) {
1292fe6060f1SDimitry Andric auto *UI = dyn_cast<ExtractElementInst>(U);
1293fe6060f1SDimitry Andric if (!UI || UI->getParent() != LI->getParent())
1294fe6060f1SDimitry Andric return false;
1295fe6060f1SDimitry Andric
1296fe6060f1SDimitry Andric // Check if any instruction between the load and the extract may modify
1297fe6060f1SDimitry Andric // memory.
1298fe6060f1SDimitry Andric if (LastCheckedInst->comesBefore(UI)) {
1299fe6060f1SDimitry Andric for (Instruction &I :
1300fe6060f1SDimitry Andric make_range(std::next(LI->getIterator()), UI->getIterator())) {
1301fe6060f1SDimitry Andric // Bail out if we reached the check limit or the instruction may write
1302fe6060f1SDimitry Andric // to memory.
1303fe6060f1SDimitry Andric if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
1304fe6060f1SDimitry Andric return false;
1305fe6060f1SDimitry Andric NumInstChecked++;
1306fe6060f1SDimitry Andric }
130781ad6265SDimitry Andric LastCheckedInst = UI;
1308fe6060f1SDimitry Andric }
1309fe6060f1SDimitry Andric
1310*c9157d92SDimitry Andric auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
1311*c9157d92SDimitry Andric if (ScalarIdx.isUnsafe())
1312fe6060f1SDimitry Andric return false;
1313*c9157d92SDimitry Andric if (ScalarIdx.isSafeWithFreeze()) {
1314*c9157d92SDimitry Andric NeedFreeze.try_emplace(UI, ScalarIdx);
1315*c9157d92SDimitry Andric ScalarIdx.discard();
1316349cc55cSDimitry Andric }
1317fe6060f1SDimitry Andric
1318fe6060f1SDimitry Andric auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
1319bdd1243dSDimitry Andric TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1320fe6060f1SDimitry Andric OriginalCost +=
1321*c9157d92SDimitry Andric TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
1322fe6060f1SDimitry Andric Index ? Index->getZExtValue() : -1);
1323fe6060f1SDimitry Andric ScalarizedCost +=
1324*c9157d92SDimitry Andric TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
1325fe6060f1SDimitry Andric Align(1), LI->getPointerAddressSpace());
1326*c9157d92SDimitry Andric ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
1327fe6060f1SDimitry Andric }
1328fe6060f1SDimitry Andric
1329fe6060f1SDimitry Andric if (ScalarizedCost >= OriginalCost)
1330fe6060f1SDimitry Andric return false;
1331fe6060f1SDimitry Andric
1332fe6060f1SDimitry Andric // Replace extracts with narrow scalar loads.
1333fe6060f1SDimitry Andric for (User *U : LI->users()) {
1334fe6060f1SDimitry Andric auto *EI = cast<ExtractElementInst>(U);
1335fe6060f1SDimitry Andric Value *Idx = EI->getOperand(1);
1336*c9157d92SDimitry Andric
1337*c9157d92SDimitry Andric // Insert 'freeze' for poison indexes.
1338*c9157d92SDimitry Andric auto It = NeedFreeze.find(EI);
1339*c9157d92SDimitry Andric if (It != NeedFreeze.end())
1340*c9157d92SDimitry Andric It->second.freeze(Builder, *cast<Instruction>(Idx));
1341*c9157d92SDimitry Andric
1342*c9157d92SDimitry Andric Builder.SetInsertPoint(EI);
1343fe6060f1SDimitry Andric Value *GEP =
1344*c9157d92SDimitry Andric Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
1345fe6060f1SDimitry Andric auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
1346*c9157d92SDimitry Andric VecTy->getElementType(), GEP, EI->getName() + ".scalar"));
1347fe6060f1SDimitry Andric
1348fe6060f1SDimitry Andric Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1349*c9157d92SDimitry Andric LI->getAlign(), VecTy->getElementType(), Idx, DL);
1350fe6060f1SDimitry Andric NewLoad->setAlignment(ScalarOpAlignment);
1351fe6060f1SDimitry Andric
1352fe6060f1SDimitry Andric replaceValue(*EI, *NewLoad);
1353fe6060f1SDimitry Andric }
1354fe6060f1SDimitry Andric
1355*c9157d92SDimitry Andric FailureGuard.release();
1356fe6060f1SDimitry Andric return true;
1357fe6060f1SDimitry Andric }
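// An illustrative sketch of this rewrite (types and names assumed):
//   %v  = load <4 x float>, ptr %p, align 16
//   %e2 = extractelement <4 x float> %v, i64 2
// becomes, when every user is such an extract and the cost model agrees:
//   %gep       = getelementptr inbounds <4 x float>, ptr %p, i32 0, i64 2
//   %e2.scalar = load float, ptr %gep, align 8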
1358fe6060f1SDimitry Andric
1359349cc55cSDimitry Andric /// Try to convert "shuffle (binop), (binop)" with a shared binop operand into
1360349cc55cSDimitry Andric /// "binop (shuffle), (shuffle)".
1361349cc55cSDimitry Andric bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
1362bdd1243dSDimitry Andric auto *VecTy = cast<FixedVectorType>(I.getType());
1363349cc55cSDimitry Andric BinaryOperator *B0, *B1;
1364349cc55cSDimitry Andric ArrayRef<int> Mask;
1365349cc55cSDimitry Andric if (!match(&I, m_Shuffle(m_OneUse(m_BinOp(B0)), m_OneUse(m_BinOp(B1)),
1366349cc55cSDimitry Andric m_Mask(Mask))) ||
1367349cc55cSDimitry Andric B0->getOpcode() != B1->getOpcode() || B0->getType() != VecTy)
1368349cc55cSDimitry Andric return false;
1369349cc55cSDimitry Andric
1370349cc55cSDimitry Andric // Try to replace a binop with a shuffle if the shuffle is not costly.
1371349cc55cSDimitry Andric // The new shuffle will choose from a single, common operand, so it may be
1372349cc55cSDimitry Andric // cheaper than the existing two-operand shuffle.
1373349cc55cSDimitry Andric SmallVector<int> UnaryMask = createUnaryMask(Mask, Mask.size());
1374349cc55cSDimitry Andric Instruction::BinaryOps Opcode = B0->getOpcode();
1375349cc55cSDimitry Andric InstructionCost BinopCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
1376349cc55cSDimitry Andric InstructionCost ShufCost = TTI.getShuffleCost(
1377349cc55cSDimitry Andric TargetTransformInfo::SK_PermuteSingleSrc, VecTy, UnaryMask);
1378349cc55cSDimitry Andric if (ShufCost > BinopCost)
1379349cc55cSDimitry Andric return false;
1380349cc55cSDimitry Andric
1381349cc55cSDimitry Andric // If we have something like "add X, Y" and "add Z, X", swap ops to match.
1382349cc55cSDimitry Andric Value *X = B0->getOperand(0), *Y = B0->getOperand(1);
1383349cc55cSDimitry Andric Value *Z = B1->getOperand(0), *W = B1->getOperand(1);
1384349cc55cSDimitry Andric if (BinaryOperator::isCommutative(Opcode) && X != Z && Y != W)
1385349cc55cSDimitry Andric std::swap(X, Y);
1386349cc55cSDimitry Andric
1387349cc55cSDimitry Andric Value *Shuf0, *Shuf1;
1388349cc55cSDimitry Andric if (X == Z) {
1389349cc55cSDimitry Andric // shuf (bo X, Y), (bo X, W) --> bo (shuf X), (shuf Y, W)
1390349cc55cSDimitry Andric Shuf0 = Builder.CreateShuffleVector(X, UnaryMask);
1391349cc55cSDimitry Andric Shuf1 = Builder.CreateShuffleVector(Y, W, Mask);
1392349cc55cSDimitry Andric } else if (Y == W) {
1393349cc55cSDimitry Andric // shuf (bo X, Y), (bo Z, Y) --> bo (shuf X, Z), (shuf Y)
1394349cc55cSDimitry Andric Shuf0 = Builder.CreateShuffleVector(X, Z, Mask);
1395349cc55cSDimitry Andric Shuf1 = Builder.CreateShuffleVector(Y, UnaryMask);
1396349cc55cSDimitry Andric } else {
1397349cc55cSDimitry Andric return false;
1398349cc55cSDimitry Andric }
1399349cc55cSDimitry Andric
1400349cc55cSDimitry Andric Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1);
1401349cc55cSDimitry Andric // Intersect flags from the old binops.
1402349cc55cSDimitry Andric if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
1403349cc55cSDimitry Andric NewInst->copyIRFlags(B0);
1404349cc55cSDimitry Andric NewInst->andIRFlags(B1);
1405349cc55cSDimitry Andric }
1406349cc55cSDimitry Andric replaceValue(I, *NewBO);
1407349cc55cSDimitry Andric return true;
1408349cc55cSDimitry Andric }
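// A small example of the X == Z case, with assumed values (the single-source
// identity shuffle of %x will typically simplify away):
//   %b0 = add <4 x i32> %x, %y
//   %b1 = add <4 x i32> %x, %w
//   %r  = shufflevector <4 x i32> %b0, <4 x i32> %b1, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
// -->
//   %shuf.x  = shufflevector <4 x i32> %x, <4 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
//   %shuf.yw = shufflevector <4 x i32> %y, <4 x i32> %w, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
//   %r       = add <4 x i32> %shuf.x, %shuf.yw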
1409349cc55cSDimitry Andric
141081ad6265SDimitry Andric /// Given a commutative reduction, the order of the input lanes does not alter
141181ad6265SDimitry Andric /// the results. We can use this to remove certain shuffles feeding the
141281ad6265SDimitry Andric /// reduction, removing the need to shuffle at all.
141381ad6265SDimitry Andric bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
141481ad6265SDimitry Andric auto *II = dyn_cast<IntrinsicInst>(&I);
141581ad6265SDimitry Andric if (!II)
141681ad6265SDimitry Andric return false;
141781ad6265SDimitry Andric switch (II->getIntrinsicID()) {
141881ad6265SDimitry Andric case Intrinsic::vector_reduce_add:
141981ad6265SDimitry Andric case Intrinsic::vector_reduce_mul:
142081ad6265SDimitry Andric case Intrinsic::vector_reduce_and:
142181ad6265SDimitry Andric case Intrinsic::vector_reduce_or:
142281ad6265SDimitry Andric case Intrinsic::vector_reduce_xor:
142381ad6265SDimitry Andric case Intrinsic::vector_reduce_smin:
142481ad6265SDimitry Andric case Intrinsic::vector_reduce_smax:
142581ad6265SDimitry Andric case Intrinsic::vector_reduce_umin:
142681ad6265SDimitry Andric case Intrinsic::vector_reduce_umax:
142781ad6265SDimitry Andric break;
142881ad6265SDimitry Andric default:
142981ad6265SDimitry Andric return false;
143081ad6265SDimitry Andric }
143181ad6265SDimitry Andric
143281ad6265SDimitry Andric // Find all the inputs when looking through operations that do not alter the
143381ad6265SDimitry Andric // lane order (binops, for example). Currently we look for a single shuffle,
143481ad6265SDimitry Andric // and can ignore splat values.
143581ad6265SDimitry Andric std::queue<Value *> Worklist;
143681ad6265SDimitry Andric SmallPtrSet<Value *, 4> Visited;
143781ad6265SDimitry Andric ShuffleVectorInst *Shuffle = nullptr;
143881ad6265SDimitry Andric if (auto *Op = dyn_cast<Instruction>(I.getOperand(0)))
143981ad6265SDimitry Andric Worklist.push(Op);
144081ad6265SDimitry Andric
144181ad6265SDimitry Andric while (!Worklist.empty()) {
144281ad6265SDimitry Andric Value *CV = Worklist.front();
144381ad6265SDimitry Andric Worklist.pop();
144481ad6265SDimitry Andric if (Visited.contains(CV))
144581ad6265SDimitry Andric continue;
144681ad6265SDimitry Andric
144781ad6265SDimitry Andric // Splats don't change the order, so can be safely ignored.
144881ad6265SDimitry Andric if (isSplatValue(CV))
144981ad6265SDimitry Andric continue;
145081ad6265SDimitry Andric
145181ad6265SDimitry Andric Visited.insert(CV);
145281ad6265SDimitry Andric
145381ad6265SDimitry Andric if (auto *CI = dyn_cast<Instruction>(CV)) {
145481ad6265SDimitry Andric if (CI->isBinaryOp()) {
145581ad6265SDimitry Andric for (auto *Op : CI->operand_values())
145681ad6265SDimitry Andric Worklist.push(Op);
145781ad6265SDimitry Andric continue;
145881ad6265SDimitry Andric } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) {
145981ad6265SDimitry Andric if (Shuffle && Shuffle != SV)
146081ad6265SDimitry Andric return false;
146181ad6265SDimitry Andric Shuffle = SV;
146281ad6265SDimitry Andric continue;
146381ad6265SDimitry Andric }
146481ad6265SDimitry Andric }
146581ad6265SDimitry Andric
146681ad6265SDimitry Andric // Anything else is currently an unknown node.
146781ad6265SDimitry Andric return false;
146881ad6265SDimitry Andric }
146981ad6265SDimitry Andric
147081ad6265SDimitry Andric if (!Shuffle)
147181ad6265SDimitry Andric return false;
147281ad6265SDimitry Andric
147381ad6265SDimitry Andric // Check all uses of the binary ops and shuffles are also included in the
147481ad6265SDimitry Andric // lane-invariant operations (Visited should be the list of lanewise
147581ad6265SDimitry Andric // instructions, including the shuffle that we found).
147681ad6265SDimitry Andric for (auto *V : Visited)
147781ad6265SDimitry Andric for (auto *U : V->users())
147881ad6265SDimitry Andric if (!Visited.contains(U) && U != &I)
147981ad6265SDimitry Andric return false;
148081ad6265SDimitry Andric
148181ad6265SDimitry Andric FixedVectorType *VecType =
148281ad6265SDimitry Andric dyn_cast<FixedVectorType>(II->getOperand(0)->getType());
148381ad6265SDimitry Andric if (!VecType)
148481ad6265SDimitry Andric return false;
148581ad6265SDimitry Andric FixedVectorType *ShuffleInputType =
148681ad6265SDimitry Andric dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType());
148781ad6265SDimitry Andric if (!ShuffleInputType)
148881ad6265SDimitry Andric return false;
1489*c9157d92SDimitry Andric unsigned NumInputElts = ShuffleInputType->getNumElements();
149081ad6265SDimitry Andric
149181ad6265SDimitry Andric // Find the mask from sorting the lanes into order. This is most likely to
149281ad6265SDimitry Andric // become an identity or concat mask. Undef elements are pushed to the end.
149381ad6265SDimitry Andric SmallVector<int> ConcatMask;
149481ad6265SDimitry Andric Shuffle->getShuffleMask(ConcatMask);
149581ad6265SDimitry Andric sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
1496*c9157d92SDimitry Andric // In the case of a truncating shuffle it's possible for the mask
1497*c9157d92SDimitry Andric // to have an index greater than the size of the resulting vector.
1498*c9157d92SDimitry Andric // This requires special handling.
1499*c9157d92SDimitry Andric bool IsTruncatingShuffle = VecType->getNumElements() < NumInputElts;
150081ad6265SDimitry Andric bool UsesSecondVec =
1501*c9157d92SDimitry Andric any_of(ConcatMask, [&](int M) { return M >= (int)NumInputElts; });
1502*c9157d92SDimitry Andric
1503*c9157d92SDimitry Andric FixedVectorType *VecTyForCost =
1504*c9157d92SDimitry Andric (UsesSecondVec && !IsTruncatingShuffle) ? VecType : ShuffleInputType;
150581ad6265SDimitry Andric InstructionCost OldCost = TTI.getShuffleCost(
1506*c9157d92SDimitry Andric UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
1507*c9157d92SDimitry Andric VecTyForCost, Shuffle->getShuffleMask());
150881ad6265SDimitry Andric InstructionCost NewCost = TTI.getShuffleCost(
1509*c9157d92SDimitry Andric UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
1510*c9157d92SDimitry Andric VecTyForCost, ConcatMask);
151181ad6265SDimitry Andric
151281ad6265SDimitry Andric LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
151381ad6265SDimitry Andric << "\n");
151481ad6265SDimitry Andric LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost
151581ad6265SDimitry Andric << "\n");
151681ad6265SDimitry Andric if (NewCost < OldCost) {
151781ad6265SDimitry Andric Builder.SetInsertPoint(Shuffle);
151881ad6265SDimitry Andric Value *NewShuffle = Builder.CreateShuffleVector(
151981ad6265SDimitry Andric Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask);
152081ad6265SDimitry Andric LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
152181ad6265SDimitry Andric replaceValue(*Shuffle, *NewShuffle);
152281ad6265SDimitry Andric }
152381ad6265SDimitry Andric
152481ad6265SDimitry Andric // See if we can re-use foldSelectShuffle, getting it to reduce the size of
152581ad6265SDimitry Andric // the shuffle into a nicer order, as it can ignore the order of the shuffles.
152681ad6265SDimitry Andric return foldSelectShuffle(*Shuffle, true);
152781ad6265SDimitry Andric }
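// An illustrative example, with an assumed reverse shuffle feeding an add
// reduction:
//   %s = shufflevector <4 x i32> %v, <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
//   %r = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %s)
// The reduction result does not depend on lane order, so if the target rates
// the sorted mask <0, 1, 2, 3> as cheaper, the reversing shuffle is rewritten
// to an identity shuffle, which can then fold away entirely.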
152881ad6265SDimitry Andric
152981ad6265SDimitry Andric /// This method looks for groups of shuffles acting on binops, of the form:
153081ad6265SDimitry Andric /// %x = shuffle ...
153181ad6265SDimitry Andric /// %y = shuffle ...
153281ad6265SDimitry Andric /// %a = binop %x, %y
153381ad6265SDimitry Andric /// %b = binop %x, %y
153481ad6265SDimitry Andric /// shuffle %a, %b, selectmask
153581ad6265SDimitry Andric /// We may, especially if the shuffle is wider than legal, be able to convert
153681ad6265SDimitry Andric /// the shuffle to a form where only parts of a and b need to be computed. On
153781ad6265SDimitry Andric /// architectures with no obvious "select" shuffle, this can reduce the total
153881ad6265SDimitry Andric /// number of operations if the target reports them as cheaper.
153981ad6265SDimitry Andric bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
1540bdd1243dSDimitry Andric auto *SVI = cast<ShuffleVectorInst>(&I);
1541bdd1243dSDimitry Andric auto *VT = cast<FixedVectorType>(I.getType());
154281ad6265SDimitry Andric auto *Op0 = dyn_cast<Instruction>(SVI->getOperand(0));
154381ad6265SDimitry Andric auto *Op1 = dyn_cast<Instruction>(SVI->getOperand(1));
154481ad6265SDimitry Andric if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() ||
154581ad6265SDimitry Andric VT != Op0->getType())
154681ad6265SDimitry Andric return false;
1547bdd1243dSDimitry Andric
1548753f127fSDimitry Andric auto *SVI0A = dyn_cast<Instruction>(Op0->getOperand(0));
1549753f127fSDimitry Andric auto *SVI0B = dyn_cast<Instruction>(Op0->getOperand(1));
1550753f127fSDimitry Andric auto *SVI1A = dyn_cast<Instruction>(Op1->getOperand(0));
1551753f127fSDimitry Andric auto *SVI1B = dyn_cast<Instruction>(Op1->getOperand(1));
1552753f127fSDimitry Andric SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B});
155381ad6265SDimitry Andric auto checkSVNonOpUses = [&](Instruction *I) {
155481ad6265SDimitry Andric if (!I || I->getOperand(0)->getType() != VT)
155581ad6265SDimitry Andric return true;
1556753f127fSDimitry Andric return any_of(I->users(), [&](User *U) {
1557753f127fSDimitry Andric return U != Op0 && U != Op1 &&
1558753f127fSDimitry Andric !(isa<ShuffleVectorInst>(U) &&
1559753f127fSDimitry Andric (InputShuffles.contains(cast<Instruction>(U)) ||
1560753f127fSDimitry Andric isInstructionTriviallyDead(cast<Instruction>(U))));
1561753f127fSDimitry Andric });
156281ad6265SDimitry Andric };
156381ad6265SDimitry Andric if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) ||
156481ad6265SDimitry Andric checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B))
156581ad6265SDimitry Andric return false;
156681ad6265SDimitry Andric
156781ad6265SDimitry Andric // Collect all the uses that are shuffles that we can transform together. We
156881ad6265SDimitry Andric // may not have a single shuffle, but a group that can all be transformed
156981ad6265SDimitry Andric // together profitably.
157081ad6265SDimitry Andric SmallVector<ShuffleVectorInst *> Shuffles;
157181ad6265SDimitry Andric auto collectShuffles = [&](Instruction *I) {
157281ad6265SDimitry Andric for (auto *U : I->users()) {
157381ad6265SDimitry Andric auto *SV = dyn_cast<ShuffleVectorInst>(U);
157481ad6265SDimitry Andric if (!SV || SV->getType() != VT)
157581ad6265SDimitry Andric return false;
1576753f127fSDimitry Andric if ((SV->getOperand(0) != Op0 && SV->getOperand(0) != Op1) ||
1577753f127fSDimitry Andric (SV->getOperand(1) != Op0 && SV->getOperand(1) != Op1))
1578753f127fSDimitry Andric return false;
157981ad6265SDimitry Andric if (!llvm::is_contained(Shuffles, SV))
158081ad6265SDimitry Andric Shuffles.push_back(SV);
158181ad6265SDimitry Andric }
158281ad6265SDimitry Andric return true;
158381ad6265SDimitry Andric };
158481ad6265SDimitry Andric if (!collectShuffles(Op0) || !collectShuffles(Op1))
158581ad6265SDimitry Andric return false;
158681ad6265SDimitry Andric // From a reduction, we need to be processing a single shuffle, otherwise the
158781ad6265SDimitry Andric // other uses will not be lane-invariant.
158881ad6265SDimitry Andric if (FromReduction && Shuffles.size() > 1)
158981ad6265SDimitry Andric return false;
159081ad6265SDimitry Andric
1591753f127fSDimitry Andric // Add any shuffle uses for the shuffles we have found, to include them in our
1592753f127fSDimitry Andric // cost calculations.
1593753f127fSDimitry Andric if (!FromReduction) {
1594753f127fSDimitry Andric for (ShuffleVectorInst *SV : Shuffles) {
1595bdd1243dSDimitry Andric for (auto *U : SV->users()) {
1596753f127fSDimitry Andric ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(U);
1597fcaf7f86SDimitry Andric if (SSV && isa<UndefValue>(SSV->getOperand(1)) && SSV->getType() == VT)
1598753f127fSDimitry Andric Shuffles.push_back(SSV);
1599753f127fSDimitry Andric }
1600753f127fSDimitry Andric }
1601753f127fSDimitry Andric }
1602753f127fSDimitry Andric
160381ad6265SDimitry Andric // For each of the output shuffles, we try to sort all the first vector
160481ad6265SDimitry Andric // elements to the beginning, followed by the second vector's elements at the
160581ad6265SDimitry Andric // end. If the binops are legalized to smaller vectors, this may reduce total
160681ad6265SDimitry Andric // number of binops. We compute the ReconstructMask mask needed to convert
160781ad6265SDimitry Andric // back to the original lane order.
1608753f127fSDimitry Andric SmallVector<std::pair<int, int>> V1, V2;
1609753f127fSDimitry Andric SmallVector<SmallVector<int>> OrigReconstructMasks;
161081ad6265SDimitry Andric int MaxV1Elt = 0, MaxV2Elt = 0;
161181ad6265SDimitry Andric unsigned NumElts = VT->getNumElements();
161281ad6265SDimitry Andric for (ShuffleVectorInst *SVN : Shuffles) {
161381ad6265SDimitry Andric SmallVector<int> Mask;
161481ad6265SDimitry Andric SVN->getShuffleMask(Mask);
161581ad6265SDimitry Andric
161681ad6265SDimitry Andric // Check the operands are the same as the original, or reversed (in which
161781ad6265SDimitry Andric // case we need to commute the mask).
161881ad6265SDimitry Andric Value *SVOp0 = SVN->getOperand(0);
161981ad6265SDimitry Andric Value *SVOp1 = SVN->getOperand(1);
1620753f127fSDimitry Andric if (isa<UndefValue>(SVOp1)) {
1621753f127fSDimitry Andric auto *SSV = cast<ShuffleVectorInst>(SVOp0);
1622753f127fSDimitry Andric SVOp0 = SSV->getOperand(0);
1623753f127fSDimitry Andric SVOp1 = SSV->getOperand(1);
1624753f127fSDimitry Andric for (unsigned I = 0, E = Mask.size(); I != E; I++) {
1625753f127fSDimitry Andric if (Mask[I] >= static_cast<int>(SSV->getShuffleMask().size()))
1626753f127fSDimitry Andric return false;
1627753f127fSDimitry Andric Mask[I] = Mask[I] < 0 ? Mask[I] : SSV->getMaskValue(Mask[I]);
1628753f127fSDimitry Andric }
1629753f127fSDimitry Andric }
163081ad6265SDimitry Andric if (SVOp0 == Op1 && SVOp1 == Op0) {
163181ad6265SDimitry Andric std::swap(SVOp0, SVOp1);
163281ad6265SDimitry Andric ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
163381ad6265SDimitry Andric }
163481ad6265SDimitry Andric if (SVOp0 != Op0 || SVOp1 != Op1)
163581ad6265SDimitry Andric return false;
163681ad6265SDimitry Andric
163781ad6265SDimitry Andric // Calculate the reconstruction mask for this shuffle, i.e. the mask needed to
163881ad6265SDimitry Andric // take the packed values from Op0/Op1 and reconstruct the original lane
163981ad6265SDimitry Andric // order.
164081ad6265SDimitry Andric SmallVector<int> ReconstructMask;
164181ad6265SDimitry Andric for (unsigned I = 0; I < Mask.size(); I++) {
164281ad6265SDimitry Andric if (Mask[I] < 0) {
164381ad6265SDimitry Andric ReconstructMask.push_back(-1);
164481ad6265SDimitry Andric } else if (Mask[I] < static_cast<int>(NumElts)) {
164581ad6265SDimitry Andric MaxV1Elt = std::max(MaxV1Elt, Mask[I]);
1646753f127fSDimitry Andric auto It = find_if(V1, [&](const std::pair<int, int> &A) {
1647753f127fSDimitry Andric return Mask[I] == A.first;
1648753f127fSDimitry Andric });
164981ad6265SDimitry Andric if (It != V1.end())
165081ad6265SDimitry Andric ReconstructMask.push_back(It - V1.begin());
165181ad6265SDimitry Andric else {
165281ad6265SDimitry Andric ReconstructMask.push_back(V1.size());
1653753f127fSDimitry Andric V1.emplace_back(Mask[I], V1.size());
165481ad6265SDimitry Andric }
165581ad6265SDimitry Andric } else {
165681ad6265SDimitry Andric MaxV2Elt = std::max<int>(MaxV2Elt, Mask[I] - NumElts);
1657753f127fSDimitry Andric auto It = find_if(V2, [&](const std::pair<int, int> &A) {
1658753f127fSDimitry Andric return Mask[I] - static_cast<int>(NumElts) == A.first;
1659753f127fSDimitry Andric });
166081ad6265SDimitry Andric if (It != V2.end())
166181ad6265SDimitry Andric ReconstructMask.push_back(NumElts + It - V2.begin());
166281ad6265SDimitry Andric else {
166381ad6265SDimitry Andric ReconstructMask.push_back(NumElts + V2.size());
1664753f127fSDimitry Andric V2.emplace_back(Mask[I] - NumElts, NumElts + V2.size());
166581ad6265SDimitry Andric }
166681ad6265SDimitry Andric }
166781ad6265SDimitry Andric }
166881ad6265SDimitry Andric
166981ad6265SDimitry Andric // For reductions, we know that the output lane ordering doesn't alter the
167081ad6265SDimitry Andric // result. Sorting into order can help simplify the shuffle away.
167181ad6265SDimitry Andric if (FromReduction)
167281ad6265SDimitry Andric sort(ReconstructMask);
1673753f127fSDimitry Andric OrigReconstructMasks.push_back(std::move(ReconstructMask));
167481ad6265SDimitry Andric }
167581ad6265SDimitry Andric
167681ad6265SDimitry Andric // If the maximum elements used from V1 and V2 are not larger than the new
167781ad6265SDimitry Andric // vectors, the vectors are already packed and performing the optimization
167881ad6265SDimitry Andric // again will likely not help any further. This also prevents us from getting
167981ad6265SDimitry Andric // stuck in a cycle in case the costs do not also rule it out.
168081ad6265SDimitry Andric if (V1.empty() || V2.empty() ||
168181ad6265SDimitry Andric (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
168281ad6265SDimitry Andric MaxV2Elt == static_cast<int>(V2.size()) - 1))
168381ad6265SDimitry Andric return false;
168481ad6265SDimitry Andric
1685753f127fSDimitry Andric // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
1686753f127fSDimitry Andric // shuffle of another shuffle, or not a shuffle (that is treated like an
1687753f127fSDimitry Andric // identity shuffle).
1688753f127fSDimitry Andric auto GetBaseMaskValue = [&](Instruction *I, int M) {
1689753f127fSDimitry Andric auto *SV = dyn_cast<ShuffleVectorInst>(I);
1690753f127fSDimitry Andric if (!SV)
1691753f127fSDimitry Andric return M;
1692753f127fSDimitry Andric if (isa<UndefValue>(SV->getOperand(1)))
1693753f127fSDimitry Andric if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
1694753f127fSDimitry Andric if (InputShuffles.contains(SSV))
1695753f127fSDimitry Andric return SSV->getMaskValue(SV->getMaskValue(M));
1696753f127fSDimitry Andric return SV->getMaskValue(M);
1697753f127fSDimitry Andric };
1698753f127fSDimitry Andric
1699753f127fSDimitry Andric // Attempt to sort the inputs by ascending mask values to make simpler input
1700753f127fSDimitry Andric // shuffles and push complex shuffles down to the uses. We sort on the first
1701753f127fSDimitry Andric // of the two input shuffle orders, to try and get at least one input into a
1702753f127fSDimitry Andric // nice order.
1703753f127fSDimitry Andric auto SortBase = [&](Instruction *A, std::pair<int, int> X,
1704753f127fSDimitry Andric std::pair<int, int> Y) {
1705753f127fSDimitry Andric int MXA = GetBaseMaskValue(A, X.first);
1706753f127fSDimitry Andric int MYA = GetBaseMaskValue(A, Y.first);
1707753f127fSDimitry Andric return MXA < MYA;
1708753f127fSDimitry Andric };
1709753f127fSDimitry Andric stable_sort(V1, [&](std::pair<int, int> A, std::pair<int, int> B) {
1710753f127fSDimitry Andric return SortBase(SVI0A, A, B);
1711753f127fSDimitry Andric });
1712753f127fSDimitry Andric stable_sort(V2, [&](std::pair<int, int> A, std::pair<int, int> B) {
1713753f127fSDimitry Andric return SortBase(SVI1A, A, B);
1714753f127fSDimitry Andric });
1715753f127fSDimitry Andric // Calculate our ReconstructMasks from the OrigReconstructMasks and the
1716753f127fSDimitry Andric // modified order of the input shuffles.
1717753f127fSDimitry Andric SmallVector<SmallVector<int>> ReconstructMasks;
1718fe013be4SDimitry Andric for (const auto &Mask : OrigReconstructMasks) {
1719753f127fSDimitry Andric SmallVector<int> ReconstructMask;
1720753f127fSDimitry Andric for (int M : Mask) {
1721753f127fSDimitry Andric auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
1722753f127fSDimitry Andric auto It = find_if(V, [M](auto A) { return A.second == M; });
1723753f127fSDimitry Andric assert(It != V.end() && "Expected all entries in Mask");
1724753f127fSDimitry Andric return std::distance(V.begin(), It);
1725753f127fSDimitry Andric };
1726753f127fSDimitry Andric if (M < 0)
1727753f127fSDimitry Andric ReconstructMask.push_back(-1);
1728753f127fSDimitry Andric else if (M < static_cast<int>(NumElts)) {
1729753f127fSDimitry Andric ReconstructMask.push_back(FindIndex(V1, M));
1730753f127fSDimitry Andric } else {
1731753f127fSDimitry Andric ReconstructMask.push_back(NumElts + FindIndex(V2, M));
1732753f127fSDimitry Andric }
1733753f127fSDimitry Andric }
1734753f127fSDimitry Andric ReconstructMasks.push_back(std::move(ReconstructMask));
1735753f127fSDimitry Andric }
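// Hypothetical illustration: with NumElts == 4 and a sorted V1 whose entries
// carry the original mask values 0, 2 and 3, an original mask element of 2 is
// found at position 1 of V1, so the rebuilt mask entry becomes 1 (a lane of
// the first new binop); an original element of 5 would instead be looked up
// in V2 (assuming V2 contains it) and offset by NumElts so that it selects
// from the second new binop.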
1736753f127fSDimitry Andric
173781ad6265SDimitry Andric // Calculate the masks needed for the new input shuffles, which get padded
173881ad6265SDimitry Andric // with poison elements.
173981ad6265SDimitry Andric SmallVector<int> V1A, V1B, V2A, V2B;
174081ad6265SDimitry Andric for (unsigned I = 0; I < V1.size(); I++) {
1741753f127fSDimitry Andric V1A.push_back(GetBaseMaskValue(SVI0A, V1[I].first));
1742753f127fSDimitry Andric V1B.push_back(GetBaseMaskValue(SVI0B, V1[I].first));
174381ad6265SDimitry Andric }
174481ad6265SDimitry Andric for (unsigned I = 0; I < V2.size(); I++) {
1745753f127fSDimitry Andric V2A.push_back(GetBaseMaskValue(SVI1A, V2[I].first));
1746753f127fSDimitry Andric V2B.push_back(GetBaseMaskValue(SVI1B, V2[I].first));
174781ad6265SDimitry Andric }
174881ad6265SDimitry Andric while (V1A.size() < NumElts) {
1749fe013be4SDimitry Andric V1A.push_back(PoisonMaskElem);
1750fe013be4SDimitry Andric V1B.push_back(PoisonMaskElem);
175181ad6265SDimitry Andric }
175281ad6265SDimitry Andric while (V2A.size() < NumElts) {
1753fe013be4SDimitry Andric V2A.push_back(PoisonMaskElem);
1754fe013be4SDimitry Andric V2B.push_back(PoisonMaskElem);
175581ad6265SDimitry Andric }
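// Hypothetical illustration of the padding: if V1.size() == 3 and
// NumElts == 4, V1A might end up as <1, 5, 7, poison>: the mask stays at the
// full vector width while the unused trailing lane is marked as poison.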
175681ad6265SDimitry Andric
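// AddShuffleCost accumulates the target cost of an existing shuffle (treated
// as a single-source permute when its second operand is undef, otherwise as
// a two-source permute); AddShuffleMaskCost accumulates the two-source
// permute cost of a newly built mask.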
1757753f127fSDimitry Andric auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
1758753f127fSDimitry Andric auto *SV = dyn_cast<ShuffleVectorInst>(I);
1759753f127fSDimitry Andric if (!SV)
1760753f127fSDimitry Andric return C;
1761753f127fSDimitry Andric return C + TTI.getShuffleCost(isa<UndefValue>(SV->getOperand(1))
1762753f127fSDimitry Andric ? TTI::SK_PermuteSingleSrc
1763753f127fSDimitry Andric : TTI::SK_PermuteTwoSrc,
1764753f127fSDimitry Andric VT, SV->getShuffleMask());
176581ad6265SDimitry Andric };
176681ad6265SDimitry Andric auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
176781ad6265SDimitry Andric return C + TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, VT, Mask);
176881ad6265SDimitry Andric };
176981ad6265SDimitry Andric
177081ad6265SDimitry Andric // Get the costs of the shuffles + binops before and after with the new
177181ad6265SDimitry Andric // shuffle masks.
177281ad6265SDimitry Andric InstructionCost CostBefore =
177381ad6265SDimitry Andric TTI.getArithmeticInstrCost(Op0->getOpcode(), VT) +
177481ad6265SDimitry Andric TTI.getArithmeticInstrCost(Op1->getOpcode(), VT);
177581ad6265SDimitry Andric CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(),
177681ad6265SDimitry Andric InstructionCost(0), AddShuffleCost);
177781ad6265SDimitry Andric CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
177881ad6265SDimitry Andric InstructionCost(0), AddShuffleCost);
177981ad6265SDimitry Andric
178081ad6265SDimitry Andric // The new binops will be unused for lanes past the used shuffle lengths.
178181ad6265SDimitry Andric // The narrower vector types below ask the target for a cost that reflects this.
178281ad6265SDimitry Andric FixedVectorType *Op0SmallVT =
178381ad6265SDimitry Andric FixedVectorType::get(VT->getScalarType(), V1.size());
178481ad6265SDimitry Andric FixedVectorType *Op1SmallVT =
178581ad6265SDimitry Andric FixedVectorType::get(VT->getScalarType(), V2.size());
178681ad6265SDimitry Andric InstructionCost CostAfter =
178781ad6265SDimitry Andric TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT) +
178881ad6265SDimitry Andric TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT);
178981ad6265SDimitry Andric CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(),
179081ad6265SDimitry Andric InstructionCost(0), AddShuffleMaskCost);
179181ad6265SDimitry Andric std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
179281ad6265SDimitry Andric CostAfter +=
179381ad6265SDimitry Andric std::accumulate(OutputShuffleMasks.begin(), OutputShuffleMasks.end(),
179481ad6265SDimitry Andric InstructionCost(0), AddShuffleMaskCost);
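// Hypothetical cost illustration: for a pair of <4 x i32> binops where only
// three lanes are live, CostBefore sums two full-width binops plus all of the
// existing output and input shuffles, while CostAfter sums two <3 x i32>
// binops, the rebuilt reconstruction masks and the (deduplicated) new input
// shuffle masks; the transform only proceeds when CostAfter is strictly
// cheaper.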
179581ad6265SDimitry Andric
1796753f127fSDimitry Andric LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
1797753f127fSDimitry Andric LLVM_DEBUG(dbgs() << " CostBefore: " << CostBefore
1798753f127fSDimitry Andric << " vs CostAfter: " << CostAfter << "\n");
179981ad6265SDimitry Andric if (CostBefore <= CostAfter)
180081ad6265SDimitry Andric return false;
180181ad6265SDimitry Andric
180281ad6265SDimitry Andric // The cost model has passed, create the new instructions.
1803753f127fSDimitry Andric auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
1804753f127fSDimitry Andric auto *SV = dyn_cast<ShuffleVectorInst>(I);
1805753f127fSDimitry Andric if (!SV)
1806753f127fSDimitry Andric return I;
1807753f127fSDimitry Andric if (isa<UndefValue>(SV->getOperand(1)))
1808753f127fSDimitry Andric if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
1809753f127fSDimitry Andric if (InputShuffles.contains(SSV))
1810753f127fSDimitry Andric return SSV->getOperand(Op);
1811753f127fSDimitry Andric return SV->getOperand(Op);
1812753f127fSDimitry Andric };
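// Illustrative (hypothetical) example of GetShuffleOperand: if SVI0A is a
// single-source shuffle of another tracked input shuffle %ssv whose operands
// are %a and %b, GetShuffleOperand(SVI0A, 0) looks through SVI0A and returns
// %ssv's first operand %a, so the rebuilt shuffles read directly from the
// original sources.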
1813*c9157d92SDimitry Andric Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
1814753f127fSDimitry Andric Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0),
1815753f127fSDimitry Andric GetShuffleOperand(SVI0A, 1), V1A);
1816*c9157d92SDimitry Andric Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
1817753f127fSDimitry Andric Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0),
1818753f127fSDimitry Andric GetShuffleOperand(SVI0B, 1), V1B);
1819*c9157d92SDimitry Andric Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
1820753f127fSDimitry Andric Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0),
1821753f127fSDimitry Andric GetShuffleOperand(SVI1A, 1), V2A);
1822*c9157d92SDimitry Andric Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
1823753f127fSDimitry Andric Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0),
1824753f127fSDimitry Andric GetShuffleOperand(SVI1B, 1), V2B);
182581ad6265SDimitry Andric Builder.SetInsertPoint(Op0);
182681ad6265SDimitry Andric Value *NOp0 = Builder.CreateBinOp((Instruction::BinaryOps)Op0->getOpcode(),
182781ad6265SDimitry Andric NSV0A, NSV0B);
182881ad6265SDimitry Andric if (auto *I = dyn_cast<Instruction>(NOp0))
182981ad6265SDimitry Andric I->copyIRFlags(Op0, true);
183081ad6265SDimitry Andric Builder.SetInsertPoint(Op1);
183181ad6265SDimitry Andric Value *NOp1 = Builder.CreateBinOp((Instruction::BinaryOps)Op1->getOpcode(),
183281ad6265SDimitry Andric NSV1A, NSV1B);
183381ad6265SDimitry Andric if (auto *I = dyn_cast<Instruction>(NOp1))
183481ad6265SDimitry Andric I->copyIRFlags(Op1, true);
183581ad6265SDimitry Andric
183681ad6265SDimitry Andric for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
183781ad6265SDimitry Andric Builder.SetInsertPoint(Shuffles[S]);
183881ad6265SDimitry Andric Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]);
183981ad6265SDimitry Andric replaceValue(*Shuffles[S], *NSV);
184081ad6265SDimitry Andric }
184181ad6265SDimitry Andric
184281ad6265SDimitry Andric Worklist.pushValue(NSV0A);
184381ad6265SDimitry Andric Worklist.pushValue(NSV0B);
184481ad6265SDimitry Andric Worklist.pushValue(NSV1A);
184581ad6265SDimitry Andric Worklist.pushValue(NSV1B);
184681ad6265SDimitry Andric for (auto *S : Shuffles)
184781ad6265SDimitry Andric Worklist.add(S);
184881ad6265SDimitry Andric return true;
184981ad6265SDimitry Andric }
185081ad6265SDimitry Andric
18515ffd83dbSDimitry Andric /// This is the entry point for all transforms. Pass manager differences are
18525ffd83dbSDimitry Andric /// handled in the callers of this function.
18535ffd83dbSDimitry Andric bool VectorCombine::run() {
18545ffd83dbSDimitry Andric if (DisableVectorCombine)
18555ffd83dbSDimitry Andric return false;
18565ffd83dbSDimitry Andric
1857e8d8bef9SDimitry Andric // Don't attempt vectorization if the target does not support vectors.
1858e8d8bef9SDimitry Andric if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
1859e8d8bef9SDimitry Andric return false;
1860e8d8bef9SDimitry Andric
18615ffd83dbSDimitry Andric bool MadeChange = false;
1862349cc55cSDimitry Andric auto FoldInst = [this, &MadeChange](Instruction &I) {
1863349cc55cSDimitry Andric Builder.SetInsertPoint(&I);
1864bdd1243dSDimitry Andric bool IsFixedVectorType = isa<FixedVectorType>(I.getType());
1865bdd1243dSDimitry Andric auto Opcode = I.getOpcode();
1866bdd1243dSDimitry Andric
1867bdd1243dSDimitry Andric // These folds should be beneficial regardless of when this pass is run
1868bdd1243dSDimitry Andric // in the optimization pipeline.
1869bdd1243dSDimitry Andric // The type checking is for run-time efficiency. We can avoid wasting time
1870bdd1243dSDimitry Andric // dispatching to folding functions if there's no chance of matching.
1871bdd1243dSDimitry Andric if (IsFixedVectorType) {
1872bdd1243dSDimitry Andric switch (Opcode) {
1873bdd1243dSDimitry Andric case Instruction::InsertElement:
1874349cc55cSDimitry Andric MadeChange |= vectorizeLoadInsert(I);
1875bdd1243dSDimitry Andric break;
1876bdd1243dSDimitry Andric case Instruction::ShuffleVector:
1877bdd1243dSDimitry Andric MadeChange |= widenSubvectorLoad(I);
1878bdd1243dSDimitry Andric break;
1879bdd1243dSDimitry Andric default:
1880bdd1243dSDimitry Andric break;
1881bdd1243dSDimitry Andric }
1882bdd1243dSDimitry Andric }
1883bdd1243dSDimitry Andric
1884bdd1243dSDimitry Andric // These transforms work with scalable and fixed vectors.
1885bdd1243dSDimitry Andric // TODO: Identify and allow other scalable transforms
1886*c9157d92SDimitry Andric if (isa<VectorType>(I.getType())) {
1887bdd1243dSDimitry Andric MadeChange |= scalarizeBinopOrCmp(I);
1888*c9157d92SDimitry Andric MadeChange |= scalarizeLoadExtract(I);
1889*c9157d92SDimitry Andric MadeChange |= scalarizeVPIntrinsic(I);
1890*c9157d92SDimitry Andric }
1891bdd1243dSDimitry Andric
1892bdd1243dSDimitry Andric if (Opcode == Instruction::Store)
1893349cc55cSDimitry Andric MadeChange |= foldSingleElementStore(I);
1894bdd1243dSDimitry Andric
1895bdd1243dSDimitry Andric // If this is an early pipeline invocation of this pass, we are done.
1896bdd1243dSDimitry Andric if (TryEarlyFoldsOnly)
1897bdd1243dSDimitry Andric return;
1898bdd1243dSDimitry Andric
1899bdd1243dSDimitry Andric // Otherwise, try folds that improve codegen but may interfere with
1900bdd1243dSDimitry Andric // early IR canonicalizations.
1901bdd1243dSDimitry Andric // The type checking is for run-time efficiency. We can avoid wasting time
1902bdd1243dSDimitry Andric // dispatching to folding functions if there's no chance of matching.
1903bdd1243dSDimitry Andric if (IsFixedVectorType) {
1904bdd1243dSDimitry Andric switch (Opcode) {
1905bdd1243dSDimitry Andric case Instruction::InsertElement:
1906bdd1243dSDimitry Andric MadeChange |= foldInsExtFNeg(I);
1907bdd1243dSDimitry Andric break;
1908bdd1243dSDimitry Andric case Instruction::ShuffleVector:
1909bdd1243dSDimitry Andric MadeChange |= foldShuffleOfBinops(I);
1910bdd1243dSDimitry Andric MadeChange |= foldSelectShuffle(I);
1911bdd1243dSDimitry Andric break;
1912bdd1243dSDimitry Andric case Instruction::BitCast:
1913*c9157d92SDimitry Andric MadeChange |= foldBitcastShuffle(I);
1914bdd1243dSDimitry Andric break;
1915bdd1243dSDimitry Andric }
1916bdd1243dSDimitry Andric } else {
1917bdd1243dSDimitry Andric switch (Opcode) {
1918bdd1243dSDimitry Andric case Instruction::Call:
1919bdd1243dSDimitry Andric MadeChange |= foldShuffleFromReductions(I);
1920bdd1243dSDimitry Andric break;
1921bdd1243dSDimitry Andric case Instruction::ICmp:
1922bdd1243dSDimitry Andric case Instruction::FCmp:
1923bdd1243dSDimitry Andric MadeChange |= foldExtractExtract(I);
1924bdd1243dSDimitry Andric break;
1925bdd1243dSDimitry Andric default:
1926bdd1243dSDimitry Andric if (Instruction::isBinaryOp(Opcode)) {
1927bdd1243dSDimitry Andric MadeChange |= foldExtractExtract(I);
1928bdd1243dSDimitry Andric MadeChange |= foldExtractedCmps(I);
1929bdd1243dSDimitry Andric }
1930bdd1243dSDimitry Andric break;
1931bdd1243dSDimitry Andric }
1932bdd1243dSDimitry Andric }
1933349cc55cSDimitry Andric };
1934bdd1243dSDimitry Andric
19355ffd83dbSDimitry Andric for (BasicBlock &BB : F) {
19365ffd83dbSDimitry Andric // Ignore unreachable basic blocks.
19375ffd83dbSDimitry Andric if (!DT.isReachableFromEntry(&BB))
19385ffd83dbSDimitry Andric continue;
1939fe6060f1SDimitry Andric // Use early increment range so that we can erase instructions in loop.
1940fe6060f1SDimitry Andric for (Instruction &I : make_early_inc_range(BB)) {
1941349cc55cSDimitry Andric if (I.isDebugOrPseudoInst())
19425ffd83dbSDimitry Andric continue;
1943349cc55cSDimitry Andric FoldInst(I);
19445ffd83dbSDimitry Andric }
19455ffd83dbSDimitry Andric }
19465ffd83dbSDimitry Andric
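// Drain the worklist: instructions that have become trivially dead are
// erased, and anything else that was pushed gets the folds re-applied.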
1947349cc55cSDimitry Andric while (!Worklist.isEmpty()) {
1948349cc55cSDimitry Andric Instruction *I = Worklist.removeOne();
1949349cc55cSDimitry Andric if (!I)
1950349cc55cSDimitry Andric continue;
1951349cc55cSDimitry Andric
1952349cc55cSDimitry Andric if (isInstructionTriviallyDead(I)) {
1953349cc55cSDimitry Andric eraseInstruction(*I);
1954349cc55cSDimitry Andric continue;
1955349cc55cSDimitry Andric }
1956349cc55cSDimitry Andric
1957349cc55cSDimitry Andric FoldInst(*I);
1958349cc55cSDimitry Andric }
19595ffd83dbSDimitry Andric
19605ffd83dbSDimitry Andric return MadeChange;
19615ffd83dbSDimitry Andric }
19625ffd83dbSDimitry Andric
19635ffd83dbSDimitry Andric PreservedAnalyses VectorCombinePass::run(Function &F,
19645ffd83dbSDimitry Andric FunctionAnalysisManager &FAM) {
1965fe6060f1SDimitry Andric auto &AC = FAM.getResult<AssumptionAnalysis>(F);
19665ffd83dbSDimitry Andric TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
19675ffd83dbSDimitry Andric DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
1968fe6060f1SDimitry Andric AAResults &AA = FAM.getResult<AAManager>(F);
1969bdd1243dSDimitry Andric VectorCombine Combiner(F, TTI, DT, AA, AC, TryEarlyFoldsOnly);
19705ffd83dbSDimitry Andric if (!Combiner.run())
19715ffd83dbSDimitry Andric return PreservedAnalyses::all();
19725ffd83dbSDimitry Andric PreservedAnalyses PA;
19735ffd83dbSDimitry Andric PA.preserveSet<CFGAnalyses>();
19745ffd83dbSDimitry Andric return PA;
19755ffd83dbSDimitry Andric }
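// Usage sketch (illustrative, not part of the pass): with the new pass
// manager this can typically be exercised in isolation via
//   opt -passes=vector-combine -S input.ll
// while the early-pipeline invocation constructs the pass with
// TryEarlyFoldsOnly set so that only the always-beneficial folds run.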
1976