1 //===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "AArch64TargetTransformInfo.h"
10 #include "AArch64ExpandImm.h"
11 #include "MCTargetDesc/AArch64AddressingModes.h"
12 #include "llvm/Analysis/IVDescriptors.h"
13 #include "llvm/Analysis/LoopInfo.h"
14 #include "llvm/Analysis/TargetTransformInfo.h"
15 #include "llvm/CodeGen/BasicTTIImpl.h"
16 #include "llvm/CodeGen/CostTable.h"
17 #include "llvm/CodeGen/TargetLowering.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/IntrinsicInst.h"
20 #include "llvm/IR/IntrinsicsAArch64.h"
21 #include "llvm/IR/PatternMatch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Transforms/InstCombine/InstCombiner.h"
24 #include <algorithm>
25 using namespace llvm;
26 using namespace llvm::PatternMatch;
27 
28 #define DEBUG_TYPE "aarch64tti"
29 
// Toggles the Falkor hardware-prefetcher-aware unrolling fix (on by default).
static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix",
                                               cl::init(true), cl::Hidden);

// Tunable cost overhead applied to SVE gather operations (default 10).
static cl::opt<unsigned> SVEGatherOverhead("sve-gather-overhead", cl::init(10),
                                           cl::Hidden);

// Tunable cost overhead applied to SVE scatter operations (default 10).
static cl::opt<unsigned> SVEScatterOverhead("sve-scatter-overhead",
                                            cl::init(10), cl::Hidden);
38 
39 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
40                                          const Function *Callee) const {
41   const TargetMachine &TM = getTLI()->getTargetMachine();
42 
43   const FeatureBitset &CallerBits =
44       TM.getSubtargetImpl(*Caller)->getFeatureBits();
45   const FeatureBitset &CalleeBits =
46       TM.getSubtargetImpl(*Callee)->getFeatureBits();
47 
48   // Inline a callee if its target-features are a subset of the callers
49   // target-features.
50   return (CallerBits & CalleeBits) == CalleeBits;
51 }
52 
53 /// Calculate the cost of materializing a 64-bit value. This helper
54 /// method might only calculate a fraction of a larger immediate. Therefore it
55 /// is valid to return a cost of ZERO.
56 InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) {
57   // Check if the immediate can be encoded within an instruction.
58   if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
59     return 0;
60 
61   if (Val < 0)
62     Val = ~Val;
63 
64   // Calculate how many moves we will need to materialize this constant.
65   SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
66   AArch64_IMM::expandMOVImm(Val, 64, Insn);
67   return Insn.size();
68 }
69 
70 /// Calculate the cost of materializing the given constant.
71 InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
72                                               TTI::TargetCostKind CostKind) {
73   assert(Ty->isIntegerTy());
74 
75   unsigned BitSize = Ty->getPrimitiveSizeInBits();
76   if (BitSize == 0)
77     return ~0U;
78 
79   // Sign-extend all constants to a multiple of 64-bit.
80   APInt ImmVal = Imm;
81   if (BitSize & 0x3f)
82     ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);
83 
84   // Split the constant into 64-bit chunks and calculate the cost for each
85   // chunk.
86   InstructionCost Cost = 0;
87   for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
88     APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
89     int64_t Val = Tmp.getSExtValue();
90     Cost += getIntImmCost(Val);
91   }
92   // We need at least one instruction to materialze the constant.
93   return std::max<InstructionCost>(1, Cost);
94 }
95 
/// Return the cost of an integer immediate when it appears as operand \p Idx
/// of instruction \p Opcode. Returning TCC_Free tells constant hoisting to
/// leave the immediate in place; a higher cost makes it a hoisting candidate.
InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
                                                  const APInt &Imm, Type *Ty,
                                                  TTI::TargetCostKind CostKind,
                                                  Instruction *Inst) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  // There is no cost model for constants with a bit size of 0. Return TCC_Free
  // here, so that constant hoisting will ignore this constant.
  if (BitSize == 0)
    return TTI::TCC_Free;

  // ImmIdx records which operand index may carry a foldable immediate for
  // this opcode; ~0U means no operand position is special-cased.
  unsigned ImmIdx = ~0U;
  switch (Opcode) {
  default:
    return TTI::TCC_Free;
  case Instruction::GetElementPtr:
    // Always hoist the base address of a GetElementPtr.
    if (Idx == 0)
      return 2 * TTI::TCC_Basic;
    return TTI::TCC_Free;
  case Instruction::Store:
    // The stored value (operand 0) may be an immediate.
    ImmIdx = 0;
    break;
  // For these binary ops / compares, only the second operand (index 1) can
  // typically be encoded as an immediate.
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::UDiv:
  case Instruction::SDiv:
  case Instruction::URem:
  case Instruction::SRem:
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::ICmp:
    ImmIdx = 1;
    break;
  // Always return TCC_Free for the shift value of a shift instruction.
  case Instruction::Shl:
  case Instruction::LShr:
  case Instruction::AShr:
    if (Idx == 1)
      return TTI::TCC_Free;
    break;
  // These fall through to the generic materialization-cost computation below.
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
  case Instruction::IntToPtr:
  case Instruction::PtrToInt:
  case Instruction::BitCast:
  case Instruction::PHI:
  case Instruction::Call:
  case Instruction::Select:
  case Instruction::Ret:
  case Instruction::Load:
    break;
  }

  if (Idx == ImmIdx) {
    // If the immediate is cheap to materialize (no more than one basic-cost
    // instruction per 64-bit chunk), report it as free so it is not hoisted.
    int NumConstants = (BitSize + 63) / 64;
    InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
    return (Cost <= NumConstants * TTI::TCC_Basic)
               ? static_cast<int>(TTI::TCC_Free)
               : Cost;
  }
  return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
}
163 
/// Return the cost of an integer immediate when used as operand \p Idx of
/// intrinsic \p IID. TCC_Free means constant hoisting should leave the
/// immediate where it is.
InstructionCost
AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
                                    const APInt &Imm, Type *Ty,
                                    TTI::TargetCostKind CostKind) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  // There is no cost model for constants with a bit size of 0. Return TCC_Free
  // here, so that constant hoisting will ignore this constant.
  if (BitSize == 0)
    return TTI::TCC_Free;

  // Most (all?) AArch64 intrinsics do not support folding immediates into the
  // selected instruction, so we compute the materialization cost for the
  // immediate directly.
  if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
    return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);

  switch (IID) {
  default:
    return TTI::TCC_Free;
  case Intrinsic::sadd_with_overflow:
  case Intrinsic::uadd_with_overflow:
  case Intrinsic::ssub_with_overflow:
  case Intrinsic::usub_with_overflow:
  case Intrinsic::smul_with_overflow:
  case Intrinsic::umul_with_overflow:
    // Cheap immediates (<= one basic-cost instruction per 64-bit chunk) on
    // the RHS are reported free so they stay folded into the intrinsic.
    if (Idx == 1) {
      int NumConstants = (BitSize + 63) / 64;
      InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
      return (Cost <= NumConstants * TTI::TCC_Basic)
                 ? static_cast<int>(TTI::TCC_Free)
                 : Cost;
    }
    break;
  // Stackmap/patchpoint/statepoint leading operands are metadata, and any
  // immediate that fits in 64 bits is encoded into the record for free.
  case Intrinsic::experimental_stackmap:
    if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  case Intrinsic::experimental_patchpoint_void:
  case Intrinsic::experimental_patchpoint_i64:
    if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  case Intrinsic::experimental_gc_statepoint:
    if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  }
  return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
}
215 
216 TargetTransformInfo::PopcntSupportKind
217 AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
218   assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
219   if (TyWidth == 32 || TyWidth == 64)
220     return TTI::PSK_FastHardware;
221   // TODO: AArch64TargetLowering::LowerCTPOP() supports 128bit popcount.
222   return TTI::PSK_Software;
223 }
224 
/// Return the cost of the given intrinsic call, consulting AArch64-specific
/// tables for intrinsics with known cheap lowerings and falling back to the
/// generic BasicTTI model otherwise.
InstructionCost
AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                      TTI::TargetCostKind CostKind) {
  auto *RetTy = ICA.getReturnType();
  switch (ICA.getID()) {
  case Intrinsic::umin:
  case Intrinsic::umax:
  case Intrinsic::smin:
  case Intrinsic::smax: {
    // Vector types with a single-instruction min/max lowering.
    static const auto ValidMinMaxTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
                                        MVT::v8i16, MVT::v2i32, MVT::v4i32};
    auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
    // v2i64 types get converted to cmp+bif hence the cost of 2
    if (LT.second == MVT::v2i64)
      return LT.first * 2;
    if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
      return LT.first;
    break;
  }
  case Intrinsic::sadd_sat:
  case Intrinsic::ssub_sat:
  case Intrinsic::uadd_sat:
  case Intrinsic::usub_sat: {
    // Vector types with a native saturating add/sub instruction.
    static const auto ValidSatTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
                                     MVT::v8i16, MVT::v2i32, MVT::v4i32,
                                     MVT::v2i64};
    auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
    // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
    // need to extend the type, as it uses shr(qadd(shl, shl)).
    unsigned Instrs =
        LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
    if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
      return LT.first * Instrs;
    break;
  }
  case Intrinsic::abs: {
    // Vector types with a single-instruction abs lowering.
    static const auto ValidAbsTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
                                     MVT::v8i16, MVT::v2i32, MVT::v4i32,
                                     MVT::v2i64};
    auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
    if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
      return LT.first;
    break;
  }
  case Intrinsic::experimental_stepvector: {
    InstructionCost Cost = 1; // Cost of the `index' instruction
    auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
    // Legalisation of illegal vectors involves an `index' instruction plus
    // (LT.first - 1) vector adds.
    if (LT.first > 1) {
      Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
      InstructionCost AddCost =
          getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
      Cost += AddCost * (LT.first - 1);
    }
    return Cost;
  }
  case Intrinsic::bitreverse: {
    // Per-legal-type instruction counts for bit reversal.
    static const CostTblEntry BitreverseTbl[] = {
        {Intrinsic::bitreverse, MVT::i32, 1},
        {Intrinsic::bitreverse, MVT::i64, 1},
        {Intrinsic::bitreverse, MVT::v8i8, 1},
        {Intrinsic::bitreverse, MVT::v16i8, 1},
        {Intrinsic::bitreverse, MVT::v4i16, 2},
        {Intrinsic::bitreverse, MVT::v8i16, 2},
        {Intrinsic::bitreverse, MVT::v2i32, 2},
        {Intrinsic::bitreverse, MVT::v4i32, 2},
        {Intrinsic::bitreverse, MVT::v1i64, 2},
        {Intrinsic::bitreverse, MVT::v2i64, 2},
    };
    const auto LegalisationCost = TLI->getTypeLegalizationCost(DL, RetTy);
    const auto *Entry =
        CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
    if (Entry) {
      // Cost Model is using the legal type(i32) that i8 and i16 will be
      // converted to +1 so that we match the actual lowering cost
      if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
          TLI->getValueType(DL, RetTy, true) == MVT::i16)
        return LegalisationCost.first * Entry->Cost + 1;

      return LegalisationCost.first * Entry->Cost;
    }
    break;
  }
  case Intrinsic::ctpop: {
    // Per-legal-type instruction counts for population count.
    static const CostTblEntry CtpopCostTbl[] = {
        {ISD::CTPOP, MVT::v2i64, 4},
        {ISD::CTPOP, MVT::v4i32, 3},
        {ISD::CTPOP, MVT::v8i16, 2},
        {ISD::CTPOP, MVT::v16i8, 1},
        {ISD::CTPOP, MVT::i64,   4},
        {ISD::CTPOP, MVT::v2i32, 3},
        {ISD::CTPOP, MVT::v4i16, 2},
        {ISD::CTPOP, MVT::v8i8,  1},
        {ISD::CTPOP, MVT::i32,   5},
    };
    auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
    MVT MTy = LT.second;
    if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
      // Extra cost of +1 when illegal vector types are legalized by promoting
      // the integer type.
      int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
                                            RetTy->getScalarSizeInBits()
                          ? 1
                          : 0;
      return LT.first * Entry->Cost + ExtraCost;
    }
    break;
  }
  default:
    break;
  }
  // No AArch64-specific model applies; defer to the generic cost model.
  return BaseT::getIntrinsicInstrCost(ICA, CostKind);
}
339 
340 /// The function will remove redundant reinterprets casting in the presence
341 /// of the control flow
342 static Optional<Instruction *> processPhiNode(InstCombiner &IC,
343                                               IntrinsicInst &II) {
344   SmallVector<Instruction *, 32> Worklist;
345   auto RequiredType = II.getType();
346 
347   auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
348   assert(PN && "Expected Phi Node!");
349 
350   // Don't create a new Phi unless we can remove the old one.
351   if (!PN->hasOneUse())
352     return None;
353 
354   for (Value *IncValPhi : PN->incoming_values()) {
355     auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
356     if (!Reinterpret ||
357         Reinterpret->getIntrinsicID() !=
358             Intrinsic::aarch64_sve_convert_to_svbool ||
359         RequiredType != Reinterpret->getArgOperand(0)->getType())
360       return None;
361   }
362 
363   // Create the new Phi
364   LLVMContext &Ctx = PN->getContext();
365   IRBuilder<> Builder(Ctx);
366   Builder.SetInsertPoint(PN);
367   PHINode *NPN = Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
368   Worklist.push_back(PN);
369 
370   for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
371     auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
372     NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
373     Worklist.push_back(Reinterpret);
374   }
375 
376   // Cleanup Phi Node and reinterprets
377   return IC.replaceInstUsesWith(II, NPN);
378 }
379 
380 static Optional<Instruction *> instCombineConvertFromSVBool(InstCombiner &IC,
381                                                             IntrinsicInst &II) {
382   // If the reinterpret instruction operand is a PHI Node
383   if (isa<PHINode>(II.getArgOperand(0)))
384     return processPhiNode(IC, II);
385 
386   SmallVector<Instruction *, 32> CandidatesForRemoval;
387   Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;
388 
389   const auto *IVTy = cast<VectorType>(II.getType());
390 
391   // Walk the chain of conversions.
392   while (Cursor) {
393     // If the type of the cursor has fewer lanes than the final result, zeroing
394     // must take place, which breaks the equivalence chain.
395     const auto *CursorVTy = cast<VectorType>(Cursor->getType());
396     if (CursorVTy->getElementCount().getKnownMinValue() <
397         IVTy->getElementCount().getKnownMinValue())
398       break;
399 
400     // If the cursor has the same type as I, it is a viable replacement.
401     if (Cursor->getType() == IVTy)
402       EarliestReplacement = Cursor;
403 
404     auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);
405 
406     // If this is not an SVE conversion intrinsic, this is the end of the chain.
407     if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
408                                   Intrinsic::aarch64_sve_convert_to_svbool ||
409                               IntrinsicCursor->getIntrinsicID() ==
410                                   Intrinsic::aarch64_sve_convert_from_svbool))
411       break;
412 
413     CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
414     Cursor = IntrinsicCursor->getOperand(0);
415   }
416 
417   // If no viable replacement in the conversion chain was found, there is
418   // nothing to do.
419   if (!EarliestReplacement)
420     return None;
421 
422   return IC.replaceInstUsesWith(II, EarliestReplacement);
423 }
424 
425 static Optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
426                                                  IntrinsicInst &II) {
427   IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
428   if (!Pg)
429     return None;
430 
431   if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
432     return None;
433 
434   const auto PTruePattern =
435       cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
436   if (PTruePattern != AArch64SVEPredPattern::vl1)
437     return None;
438 
439   // The intrinsic is inserting into lane zero so use an insert instead.
440   auto *IdxTy = Type::getInt64Ty(II.getContext());
441   auto *Insert = InsertElementInst::Create(
442       II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
443   Insert->insertBefore(&II);
444   Insert->takeName(&II);
445 
446   return IC.replaceInstUsesWith(II, Insert);
447 }
448 
449 static Optional<Instruction *> instCombineSVEDupX(InstCombiner &IC,
450                                                   IntrinsicInst &II) {
451   // Replace DupX with a regular IR splat.
452   IRBuilder<> Builder(II.getContext());
453   Builder.SetInsertPoint(&II);
454   auto *RetTy = cast<ScalableVectorType>(II.getType());
455   Value *Splat =
456       Builder.CreateVectorSplat(RetTy->getElementCount(), II.getArgOperand(0));
457   Splat->takeName(&II);
458   return IC.replaceInstUsesWith(II, Splat);
459 }
460 
/// Fold an all-active cmpne of a dupq'd constant vector against zero into a
/// direct predicate constant (ptrue of the widest matching element type, or
/// pfalse when every element is zero).
static Optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
                                                   IntrinsicInst &II) {
  LLVMContext &Ctx = II.getContext();
  IRBuilder<> Builder(Ctx);
  Builder.SetInsertPoint(&II);

  // Check that the predicate is all active
  auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
  if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
    return None;

  const auto PTruePattern =
      cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
  if (PTruePattern != AArch64SVEPredPattern::all)
    return None;

  // Check that we have a compare of zero..
  auto *SplatValue =
      dyn_cast_or_null<ConstantInt>(getSplatValue(II.getArgOperand(2)));
  if (!SplatValue || !SplatValue->isZero())
    return None;

  // ..against a dupq
  auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
  if (!DupQLane ||
      DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
    return None;

  // Where the dupq is a lane 0 replicate of a vector insert
  if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero())
    return None;

  auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
  if (!VecIns ||
      VecIns->getIntrinsicID() != Intrinsic::experimental_vector_insert)
    return None;

  // Where the vector insert is a fixed constant vector insert into undef at
  // index zero
  if (!isa<UndefValue>(VecIns->getArgOperand(0)))
    return None;

  if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
    return None;

  auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
  if (!ConstVec)
    return None;

  // The fixed vector being replicated must have as many elements as the
  // scalable result's known minimum.
  auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
  auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
  if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
    return None;

  unsigned NumElts = VecTy->getNumElements();
  unsigned PredicateBits = 0;

  // Expand intrinsic operands to a 16-bit byte level predicate
  for (unsigned I = 0; I < NumElts; ++I) {
    auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
    if (!Arg)
      return None;
    if (!Arg->isZero())
      PredicateBits |= 1 << (I * (16 / NumElts));
  }

  // If all bits are zero bail early with an empty predicate
  if (PredicateBits == 0) {
    auto *PFalse = Constant::getNullValue(II.getType());
    PFalse->takeName(&II);
    return IC.replaceInstUsesWith(II, PFalse);
  }

  // Calculate largest predicate type used (where byte predicate is largest)
  unsigned Mask = 8;
  for (unsigned I = 0; I < 16; ++I)
    if ((PredicateBits & (1 << I)) != 0)
      Mask |= (I % 8);

  // PredSize is the lowest set bit of Mask, i.e. the element size in bytes.
  unsigned PredSize = Mask & -Mask;
  auto *PredType = ScalableVectorType::get(
      Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));

  // Ensure all relevant bits are set
  for (unsigned I = 0; I < 16; I += PredSize)
    if ((PredicateBits & (1 << I)) == 0)
      return None;

  // Build ptrue(all) of the deduced element width and cast it to the
  // result's predicate type via svbool.
  auto *PTruePat =
      ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
  auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
                                        {PredType}, {PTruePat});
  auto *ConvertToSVBool = Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
  auto *ConvertFromSVBool =
      Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
                              {II.getType()}, {ConvertToSVBool});

  ConvertFromSVBool->takeName(&II);
  return IC.replaceInstUsesWith(II, ConvertFromSVBool);
}
562 
/// Simplify lasta/lastb: fold splats, distribute over single-use binops with
/// a splat operand, and turn known-constant-index extractions into plain
/// extractelement instructions.
static Optional<Instruction *> instCombineSVELast(InstCombiner &IC,
                                                  IntrinsicInst &II) {
  IRBuilder<> Builder(II.getContext());
  Builder.SetInsertPoint(&II);
  Value *Pg = II.getArgOperand(0);
  Value *Vec = II.getArgOperand(1);
  auto IntrinsicID = II.getIntrinsicID();
  bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta;

  // lastX(splat(X)) --> X
  if (auto *SplatVal = getSplatValue(Vec))
    return IC.replaceInstUsesWith(II, SplatVal);

  // If x and/or y is a splat value then:
  // lastX (binop (x, y)) --> binop(lastX(x), lastX(y))
  Value *LHS, *RHS;
  if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) {
    if (isSplatValue(LHS) || isSplatValue(RHS)) {
      auto *OldBinOp = cast<BinaryOperator>(Vec);
      auto OpC = OldBinOp->getOpcode();
      auto *NewLHS =
          Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS});
      auto *NewRHS =
          Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS});
      // Preserve the original binop's flags (e.g. fast-math) on the rebuilt op.
      auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags(
          OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), &II);
      return IC.replaceInstUsesWith(II, NewBinOp);
    }
  }

  // With an all-false predicate, lasta yields element 0.
  auto *C = dyn_cast<Constant>(Pg);
  if (IsAfter && C && C->isNullValue()) {
    // The intrinsic is extracting lane 0 so use an extract instead.
    auto *IdxTy = Type::getInt64Ty(II.getContext());
    auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
    Extract->insertBefore(&II);
    Extract->takeName(&II);
    return IC.replaceInstUsesWith(II, Extract);
  }

  auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
  if (!IntrPG)
    return None;

  if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
    return None;

  const auto PTruePattern =
      cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();

  // Can the intrinsic's predicate be converted to a known constant index?
  unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
  if (!MinNumElts)
    return None;

  unsigned Idx = MinNumElts - 1;
  // Increment the index if extracting the element after the last active
  // predicate element.
  if (IsAfter)
    ++Idx;

  // Ignore extracts whose index is larger than the known minimum vector
  // length. NOTE: This is an artificial constraint where we prefer to
  // maintain what the user asked for until an alternative is proven faster.
  auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
  if (Idx >= PgVTy->getMinNumElements())
    return None;

  // The intrinsic is extracting a fixed lane so use an extract instead.
  auto *IdxTy = Type::getInt64Ty(II.getContext());
  auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
  Extract->insertBefore(&II);
  Extract->takeName(&II);
  return IC.replaceInstUsesWith(II, Extract);
}
638 
639 static Optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
640                                                 IntrinsicInst &II) {
641   LLVMContext &Ctx = II.getContext();
642   IRBuilder<> Builder(Ctx);
643   Builder.SetInsertPoint(&II);
644   // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
645   // can work with RDFFR_PP for ptest elimination.
646   auto *AllPat =
647       ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
648   auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
649                                         {II.getType()}, {AllPat});
650   auto *RDFFR =
651       Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue});
652   RDFFR->takeName(&II);
653   return IC.replaceInstUsesWith(II, RDFFR);
654 }
655 
656 static Optional<Instruction *>
657 instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
658   const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
659 
660   if (Pattern == AArch64SVEPredPattern::all) {
661     LLVMContext &Ctx = II.getContext();
662     IRBuilder<> Builder(Ctx);
663     Builder.SetInsertPoint(&II);
664 
665     Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
666     auto *VScale = Builder.CreateVScale(StepVal);
667     VScale->takeName(&II);
668     return IC.replaceInstUsesWith(II, VScale);
669   }
670 
671   unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
672 
673   return MinNumElts && NumElts >= MinNumElts
674              ? Optional<Instruction *>(IC.replaceInstUsesWith(
675                    II, ConstantInt::get(II.getType(), MinNumElts)))
676              : None;
677 }
678 
679 static Optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
680                                                    IntrinsicInst &II) {
681   IntrinsicInst *Op1 = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
682   IntrinsicInst *Op2 = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
683 
684   if (Op1 && Op2 &&
685       Op1->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
686       Op2->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
687       Op1->getArgOperand(0)->getType() == Op2->getArgOperand(0)->getType()) {
688 
689     IRBuilder<> Builder(II.getContext());
690     Builder.SetInsertPoint(&II);
691 
692     Value *Ops[] = {Op1->getArgOperand(0), Op2->getArgOperand(0)};
693     Type *Tys[] = {Op1->getArgOperand(0)->getType()};
694 
695     auto *PTest = Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
696 
697     PTest->takeName(&II);
698     return IC.replaceInstUsesWith(II, PTest);
699   }
700 
701   return None;
702 }
703 
704 static Optional<Instruction *> instCombineSVEVectorFMLA(InstCombiner &IC,
705                                                         IntrinsicInst &II) {
706   // fold (fadd p a (fmul p b c)) -> (fma p a b c)
707   Value *P = II.getOperand(0);
708   Value *A = II.getOperand(1);
709   auto FMul = II.getOperand(2);
710   Value *B, *C;
711   if (!match(FMul, m_Intrinsic<Intrinsic::aarch64_sve_fmul>(
712                        m_Specific(P), m_Value(B), m_Value(C))))
713     return None;
714 
715   if (!FMul->hasOneUse())
716     return None;
717 
718   llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
719   // Stop the combine when the flags on the inputs differ in case dropping flags
720   // would lead to us missing out on more beneficial optimizations.
721   if (FAddFlags != cast<CallInst>(FMul)->getFastMathFlags())
722     return None;
723   if (!FAddFlags.allowContract())
724     return None;
725 
726   IRBuilder<> Builder(II.getContext());
727   Builder.SetInsertPoint(&II);
728   auto FMLA = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
729                                       {II.getType()}, {P, A, B, C}, &II);
730   FMLA->setFastMathFlags(FAddFlags);
731   return IC.replaceInstUsesWith(II, FMLA);
732 }
733 
734 static bool isAllActivePredicate(Value *Pred) {
735   // Look through convert.from.svbool(convert.to.svbool(...) chain.
736   Value *UncastedPred;
737   if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>(
738                       m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
739                           m_Value(UncastedPred)))))
740     // If the predicate has the same or less lanes than the uncasted
741     // predicate then we know the casting has no effect.
742     if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <=
743         cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements())
744       Pred = UncastedPred;
745 
746   return match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
747                          m_ConstantInt<AArch64SVEPredPattern::all>()));
748 }
749 
750 static Optional<Instruction *>
751 instCombineSVELD1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
752   IRBuilder<> Builder(II.getContext());
753   Builder.SetInsertPoint(&II);
754 
755   Value *Pred = II.getOperand(0);
756   Value *PtrOp = II.getOperand(1);
757   Type *VecTy = II.getType();
758   Value *VecPtr = Builder.CreateBitCast(PtrOp, VecTy->getPointerTo());
759 
760   if (isAllActivePredicate(Pred)) {
761     LoadInst *Load = Builder.CreateLoad(VecTy, VecPtr);
762     return IC.replaceInstUsesWith(II, Load);
763   }
764 
765   CallInst *MaskedLoad =
766       Builder.CreateMaskedLoad(VecTy, VecPtr, PtrOp->getPointerAlignment(DL),
767                                Pred, ConstantAggregateZero::get(VecTy));
768   return IC.replaceInstUsesWith(II, MaskedLoad);
769 }
770 
771 static Optional<Instruction *>
772 instCombineSVEST1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
773   IRBuilder<> Builder(II.getContext());
774   Builder.SetInsertPoint(&II);
775 
776   Value *VecOp = II.getOperand(0);
777   Value *Pred = II.getOperand(1);
778   Value *PtrOp = II.getOperand(2);
779   Value *VecPtr =
780       Builder.CreateBitCast(PtrOp, VecOp->getType()->getPointerTo());
781 
782   if (isAllActivePredicate(Pred)) {
783     Builder.CreateStore(VecOp, VecPtr);
784     return IC.eraseInstFromFunction(II);
785   }
786 
787   Builder.CreateMaskedStore(VecOp, VecPtr, PtrOp->getPointerAlignment(DL),
788                             Pred);
789   return IC.eraseInstFromFunction(II);
790 }
791 
792 static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
793   switch (Intrinsic) {
794   case Intrinsic::aarch64_sve_fmul:
795     return Instruction::BinaryOps::FMul;
796   case Intrinsic::aarch64_sve_fadd:
797     return Instruction::BinaryOps::FAdd;
798   case Intrinsic::aarch64_sve_fsub:
799     return Instruction::BinaryOps::FSub;
800   default:
801     return Instruction::BinaryOpsEnd;
802   }
803 }
804 
805 static Optional<Instruction *> instCombineSVEVectorBinOp(InstCombiner &IC,
806                                                          IntrinsicInst &II) {
807   auto *OpPredicate = II.getOperand(0);
808   auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID());
809   if (BinOpCode == Instruction::BinaryOpsEnd ||
810       !match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
811                               m_ConstantInt<AArch64SVEPredPattern::all>())))
812     return None;
813   IRBuilder<> Builder(II.getContext());
814   Builder.SetInsertPoint(&II);
815   Builder.setFastMathFlags(II.getFastMathFlags());
816   auto BinOp =
817       Builder.CreateBinOp(BinOpCode, II.getOperand(1), II.getOperand(2));
818   return IC.replaceInstUsesWith(II, BinOp);
819 }
820 
821 static Optional<Instruction *> instCombineSVEVectorFAdd(InstCombiner &IC,
822                                                         IntrinsicInst &II) {
823   if (auto FMLA = instCombineSVEVectorFMLA(IC, II))
824     return FMLA;
825   return instCombineSVEVectorBinOp(IC, II);
826 }
827 
828 static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
829                                                        IntrinsicInst &II) {
830   auto *OpPredicate = II.getOperand(0);
831   auto *OpMultiplicand = II.getOperand(1);
832   auto *OpMultiplier = II.getOperand(2);
833 
834   IRBuilder<> Builder(II.getContext());
835   Builder.SetInsertPoint(&II);
836 
837   // Return true if a given instruction is a unit splat value, false otherwise.
838   auto IsUnitSplat = [](auto *I) {
839     auto *SplatValue = getSplatValue(I);
840     if (!SplatValue)
841       return false;
842     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
843   };
844 
845   // Return true if a given instruction is an aarch64_sve_dup intrinsic call
846   // with a unit splat value, false otherwise.
847   auto IsUnitDup = [](auto *I) {
848     auto *IntrI = dyn_cast<IntrinsicInst>(I);
849     if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
850       return false;
851 
852     auto *SplatValue = IntrI->getOperand(2);
853     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
854   };
855 
856   if (IsUnitSplat(OpMultiplier)) {
857     // [f]mul pg %n, (dupx 1) => %n
858     OpMultiplicand->takeName(&II);
859     return IC.replaceInstUsesWith(II, OpMultiplicand);
860   } else if (IsUnitDup(OpMultiplier)) {
861     // [f]mul pg %n, (dup pg 1) => %n
862     auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
863     auto *DupPg = DupInst->getOperand(1);
864     // TODO: this is naive. The optimization is still valid if DupPg
865     // 'encompasses' OpPredicate, not only if they're the same predicate.
866     if (OpPredicate == DupPg) {
867       OpMultiplicand->takeName(&II);
868       return IC.replaceInstUsesWith(II, OpMultiplicand);
869     }
870   }
871 
872   return instCombineSVEVectorBinOp(IC, II);
873 }
874 
875 static Optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
876                                                     IntrinsicInst &II) {
877   IRBuilder<> Builder(II.getContext());
878   Builder.SetInsertPoint(&II);
879   Value *UnpackArg = II.getArgOperand(0);
880   auto *RetTy = cast<ScalableVectorType>(II.getType());
881   bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi ||
882                   II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo;
883 
884   // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X))
885   // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X))
886   if (auto *ScalarArg = getSplatValue(UnpackArg)) {
887     ScalarArg =
888         Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned);
889     Value *NewVal =
890         Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg);
891     NewVal->takeName(&II);
892     return IC.replaceInstUsesWith(II, NewVal);
893   }
894 
895   return None;
896 }
897 static Optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
898                                                  IntrinsicInst &II) {
899   auto *OpVal = II.getOperand(0);
900   auto *OpIndices = II.getOperand(1);
901   VectorType *VTy = cast<VectorType>(II.getType());
902 
903   // Check whether OpIndices is a constant splat value < minimal element count
904   // of result.
905   auto *SplatValue = dyn_cast_or_null<ConstantInt>(getSplatValue(OpIndices));
906   if (!SplatValue ||
907       SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
908     return None;
909 
910   // Convert sve_tbl(OpVal sve_dup_x(SplatValue)) to
911   // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
912   IRBuilder<> Builder(II.getContext());
913   Builder.SetInsertPoint(&II);
914   auto *Extract = Builder.CreateExtractElement(OpVal, SplatValue);
915   auto *VectorSplat =
916       Builder.CreateVectorSplat(VTy->getElementCount(), Extract);
917 
918   VectorSplat->takeName(&II);
919   return IC.replaceInstUsesWith(II, VectorSplat);
920 }
921 
922 static Optional<Instruction *> instCombineSVETupleGet(InstCombiner &IC,
923                                                       IntrinsicInst &II) {
924   // Try to remove sequences of tuple get/set.
925   Value *SetTuple, *SetIndex, *SetValue;
926   auto *GetTuple = II.getArgOperand(0);
927   auto *GetIndex = II.getArgOperand(1);
928   // Check that we have tuple_get(GetTuple, GetIndex) where GetTuple is a
929   // call to tuple_set i.e. tuple_set(SetTuple, SetIndex, SetValue).
930   // Make sure that the types of the current intrinsic and SetValue match
931   // in order to safely remove the sequence.
932   if (!match(GetTuple,
933              m_Intrinsic<Intrinsic::aarch64_sve_tuple_set>(
934                  m_Value(SetTuple), m_Value(SetIndex), m_Value(SetValue))) ||
935       SetValue->getType() != II.getType())
936     return None;
937   // Case where we get the same index right after setting it.
938   // tuple_get(tuple_set(SetTuple, SetIndex, SetValue), GetIndex) --> SetValue
939   if (GetIndex == SetIndex)
940     return IC.replaceInstUsesWith(II, SetValue);
941   // If we are getting a different index than what was set in the tuple_set
942   // intrinsic. We can just set the input tuple to the one up in the chain.
943   // tuple_get(tuple_set(SetTuple, SetIndex, SetValue), GetIndex)
944   // --> tuple_get(SetTuple, GetIndex)
945   return IC.replaceOperand(II, 0, SetTuple);
946 }
947 
948 static Optional<Instruction *> instCombineSVEZip(InstCombiner &IC,
949                                                  IntrinsicInst &II) {
950   // zip1(uzp1(A, B), uzp2(A, B)) --> A
951   // zip2(uzp1(A, B), uzp2(A, B)) --> B
952   Value *A, *B;
953   if (match(II.getArgOperand(0),
954             m_Intrinsic<Intrinsic::aarch64_sve_uzp1>(m_Value(A), m_Value(B))) &&
955       match(II.getArgOperand(1), m_Intrinsic<Intrinsic::aarch64_sve_uzp2>(
956                                      m_Specific(A), m_Specific(B))))
957     return IC.replaceInstUsesWith(
958         II, (II.getIntrinsicID() == Intrinsic::aarch64_sve_zip1 ? A : B));
959 
960   return None;
961 }
962 
963 static Optional<Instruction *> instCombineLD1GatherIndex(InstCombiner &IC,
964                                                          IntrinsicInst &II) {
965   Value *Mask = II.getOperand(0);
966   Value *BasePtr = II.getOperand(1);
967   Value *Index = II.getOperand(2);
968   Type *Ty = II.getType();
969   Type *BasePtrTy = BasePtr->getType();
970   Value *PassThru = ConstantAggregateZero::get(Ty);
971 
972   // Contiguous gather => masked load.
973   // (sve.ld1.gather.index Mask BasePtr (sve.index IndexBase 1))
974   // => (masked.load (gep BasePtr IndexBase) Align Mask zeroinitializer)
975   Value *IndexBase;
976   if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
977                        m_Value(IndexBase), m_SpecificInt(1)))) {
978     IRBuilder<> Builder(II.getContext());
979     Builder.SetInsertPoint(&II);
980 
981     Align Alignment =
982         BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
983 
984     Type *VecPtrTy = PointerType::getUnqual(Ty);
985     Value *Ptr = Builder.CreateGEP(BasePtrTy->getPointerElementType(), BasePtr,
986                                    IndexBase);
987     Ptr = Builder.CreateBitCast(Ptr, VecPtrTy);
988     CallInst *MaskedLoad =
989         Builder.CreateMaskedLoad(Ty, Ptr, Alignment, Mask, PassThru);
990     MaskedLoad->takeName(&II);
991     return IC.replaceInstUsesWith(II, MaskedLoad);
992   }
993 
994   return None;
995 }
996 
997 static Optional<Instruction *> instCombineST1ScatterIndex(InstCombiner &IC,
998                                                           IntrinsicInst &II) {
999   Value *Val = II.getOperand(0);
1000   Value *Mask = II.getOperand(1);
1001   Value *BasePtr = II.getOperand(2);
1002   Value *Index = II.getOperand(3);
1003   Type *Ty = Val->getType();
1004   Type *BasePtrTy = BasePtr->getType();
1005 
1006   // Contiguous scatter => masked store.
1007   // (sve.ld1.scatter.index Value Mask BasePtr (sve.index IndexBase 1))
1008   // => (masked.store Value (gep BasePtr IndexBase) Align Mask)
1009   Value *IndexBase;
1010   if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
1011                        m_Value(IndexBase), m_SpecificInt(1)))) {
1012     IRBuilder<> Builder(II.getContext());
1013     Builder.SetInsertPoint(&II);
1014 
1015     Align Alignment =
1016         BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
1017 
1018     Value *Ptr = Builder.CreateGEP(BasePtrTy->getPointerElementType(), BasePtr,
1019                                    IndexBase);
1020     Type *VecPtrTy = PointerType::getUnqual(Ty);
1021     Ptr = Builder.CreateBitCast(Ptr, VecPtrTy);
1022 
1023     (void)Builder.CreateMaskedStore(Val, Ptr, Alignment, Mask);
1024 
1025     return IC.eraseInstFromFunction(II);
1026   }
1027 
1028   return None;
1029 }
1030 
1031 static Optional<Instruction *> instCombineSVESDIV(InstCombiner &IC,
1032                                                   IntrinsicInst &II) {
1033   IRBuilder<> Builder(II.getContext());
1034   Builder.SetInsertPoint(&II);
1035   Type *Int32Ty = Builder.getInt32Ty();
1036   Value *Pred = II.getOperand(0);
1037   Value *Vec = II.getOperand(1);
1038   Value *DivVec = II.getOperand(2);
1039 
1040   Value *SplatValue = getSplatValue(DivVec);
1041   ConstantInt *SplatConstantInt = dyn_cast_or_null<ConstantInt>(SplatValue);
1042   if (!SplatConstantInt)
1043     return None;
1044   APInt Divisor = SplatConstantInt->getValue();
1045 
1046   if (Divisor.isPowerOf2()) {
1047     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
1048     auto ASRD = Builder.CreateIntrinsic(
1049         Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
1050     return IC.replaceInstUsesWith(II, ASRD);
1051   }
1052   if (Divisor.isNegatedPowerOf2()) {
1053     Divisor.negate();
1054     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
1055     auto ASRD = Builder.CreateIntrinsic(
1056         Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
1057     auto NEG = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_neg,
1058                                        {ASRD->getType()}, {ASRD, Pred, ASRD});
1059     return IC.replaceInstUsesWith(II, NEG);
1060   }
1061 
1062   return None;
1063 }
1064 
/// Target hook invoked by InstCombine to simplify AArch64 intrinsics.
/// Dispatches each supported SVE intrinsic to its dedicated combiner above
/// and returns the replacement/erasure it produced, or None when no
/// target-specific fold applies.
Optional<Instruction *>
AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
                                     IntrinsicInst &II) const {
  Intrinsic::ID IID = II.getIntrinsicID();
  switch (IID) {
  default:
    break;
  case Intrinsic::aarch64_sve_convert_from_svbool:
    return instCombineConvertFromSVBool(IC, II);
  case Intrinsic::aarch64_sve_dup:
    return instCombineSVEDup(IC, II);
  case Intrinsic::aarch64_sve_dup_x:
    return instCombineSVEDupX(IC, II);
  case Intrinsic::aarch64_sve_cmpne:
  case Intrinsic::aarch64_sve_cmpne_wide:
    return instCombineSVECmpNE(IC, II);
  case Intrinsic::aarch64_sve_rdffr:
    return instCombineRDFFR(IC, II);
  case Intrinsic::aarch64_sve_lasta:
  case Intrinsic::aarch64_sve_lastb:
    return instCombineSVELast(IC, II);
  // Element-count intrinsics share one combiner, parameterised by the
  // element count per 128-bit granule (d=2, w=4, h=8, b=16).
  case Intrinsic::aarch64_sve_cntd:
    return instCombineSVECntElts(IC, II, 2);
  case Intrinsic::aarch64_sve_cntw:
    return instCombineSVECntElts(IC, II, 4);
  case Intrinsic::aarch64_sve_cnth:
    return instCombineSVECntElts(IC, II, 8);
  case Intrinsic::aarch64_sve_cntb:
    return instCombineSVECntElts(IC, II, 16);
  case Intrinsic::aarch64_sve_ptest_any:
  case Intrinsic::aarch64_sve_ptest_first:
  case Intrinsic::aarch64_sve_ptest_last:
    return instCombineSVEPTest(IC, II);
  // Multiplies get the unit-splat fold first; fadd tries FMLA contraction;
  // fsub goes straight to the generic unpredicated-binop rewrite.
  case Intrinsic::aarch64_sve_mul:
  case Intrinsic::aarch64_sve_fmul:
    return instCombineSVEVectorMul(IC, II);
  case Intrinsic::aarch64_sve_fadd:
    return instCombineSVEVectorFAdd(IC, II);
  case Intrinsic::aarch64_sve_fsub:
    return instCombineSVEVectorBinOp(IC, II);
  case Intrinsic::aarch64_sve_tbl:
    return instCombineSVETBL(IC, II);
  case Intrinsic::aarch64_sve_uunpkhi:
  case Intrinsic::aarch64_sve_uunpklo:
  case Intrinsic::aarch64_sve_sunpkhi:
  case Intrinsic::aarch64_sve_sunpklo:
    return instCombineSVEUnpack(IC, II);
  case Intrinsic::aarch64_sve_tuple_get:
    return instCombineSVETupleGet(IC, II);
  case Intrinsic::aarch64_sve_zip1:
  case Intrinsic::aarch64_sve_zip2:
    return instCombineSVEZip(IC, II);
  case Intrinsic::aarch64_sve_ld1_gather_index:
    return instCombineLD1GatherIndex(IC, II);
  case Intrinsic::aarch64_sve_st1_scatter_index:
    return instCombineST1ScatterIndex(IC, II);
  // The ld1/st1 combiners additionally need the DataLayout for alignment.
  case Intrinsic::aarch64_sve_ld1:
    return instCombineSVELD1(IC, II, DL);
  case Intrinsic::aarch64_sve_st1:
    return instCombineSVEST1(IC, II, DL);
  case Intrinsic::aarch64_sve_sdiv:
    return instCombineSVESDIV(IC, II);
  }

  return None;
}
1131 
/// Return true if an instruction with opcode \p Opcode, operands \p Args and
/// result type \p DstTy corresponds to one of AArch64's widening
/// instructions ("long" uaddl/ssubl-style or "wide" uaddw/ssubw-style
/// forms), so that the extend feeding its second operand would be folded
/// into the instruction and is effectively free.
bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
                                           ArrayRef<const Value *> Args) {

  // A helper that returns a vector type from the given type. The number of
  // elements in type Ty determine the vector width.
  auto toVectorTy = [&](Type *ArgTy) {
    return VectorType::get(ArgTy->getScalarType(),
                           cast<VectorType>(DstTy)->getElementCount());
  };

  // Exit early if DstTy is not a vector type whose elements are at least
  // 16-bits wide.
  if (!DstTy->isVectorTy() || DstTy->getScalarSizeInBits() < 16)
    return false;

  // Determine if the operation has a widening variant. We consider both the
  // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
  // instructions.
  //
  // TODO: Add additional widening operations (e.g., mul, shl, etc.) once we
  //       verify that their extending operands are eliminated during code
  //       generation.
  switch (Opcode) {
  case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
  case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
    break;
  default:
    return false;
  }

  // To be a widening instruction (either the "wide" or "long" versions), the
  // second operand must be a sign- or zero extend having a single user. We
  // only consider extends having a single user because they may otherwise not
  // be eliminated.
  if (Args.size() != 2 ||
      (!isa<SExtInst>(Args[1]) && !isa<ZExtInst>(Args[1])) ||
      !Args[1]->hasOneUse())
    return false;
  auto *Extend = cast<CastInst>(Args[1]);

  // Legalize the destination type and ensure it can be used in a widening
  // operation. The element size must survive legalization unchanged,
  // otherwise the legalized op no longer matches the widening form.
  auto DstTyL = TLI->getTypeLegalizationCost(DL, DstTy);
  unsigned DstElTySize = DstTyL.second.getScalarSizeInBits();
  if (!DstTyL.second.isVector() || DstElTySize != DstTy->getScalarSizeInBits())
    return false;

  // Legalize the source type and ensure it can be used in a widening
  // operation. The source is widened to the destination's lane count first
  // (via toVectorTy) so both legalize against the same element count.
  auto *SrcTy = toVectorTy(Extend->getSrcTy());
  auto SrcTyL = TLI->getTypeLegalizationCost(DL, SrcTy);
  unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
  if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
    return false;

  // Get the total number of vector elements in the legalized types.
  InstructionCost NumDstEls =
      DstTyL.first * DstTyL.second.getVectorMinNumElements();
  InstructionCost NumSrcEls =
      SrcTyL.first * SrcTyL.second.getVectorMinNumElements();

  // Return true if the legalized types have the same number of vector elements
  // and the destination element type size is twice that of the source type.
  return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize;
}
1197 
/// Compute the cost of a cast instruction (trunc/ext/int-fp conversions) on
/// AArch64. Casts consumed as the extending operand of a widening
/// instruction (uaddl/saddw, etc. — see isWideningInstruction) are free;
/// other conversions are priced from a hand-tuned NEON/SVE table, falling
/// back to the base implementation for type pairs without an entry.
InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
                                                 Type *Src,
                                                 TTI::CastContextHint CCH,
                                                 TTI::TargetCostKind CostKind,
                                                 const Instruction *I) {
  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  assert(ISD && "Invalid opcode");

  // If the cast is observable, and it is used by a widening instruction (e.g.,
  // uaddl, saddw, etc.), it may be free.
  if (I && I->hasOneUse()) {
    auto *SingleUser = cast<Instruction>(*I->user_begin());
    SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
    if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands)) {
      // If the cast is the second operand, it is free. We will generate either
      // a "wide" or "long" version of the widening instruction.
      if (I == SingleUser->getOperand(1))
        return 0;
      // If the cast is not the second operand, it will be free if it looks the
      // same as the second operand. In this case, we will generate a "long"
      // version of the widening instruction.
      if (auto *Cast = dyn_cast<CastInst>(SingleUser->getOperand(1)))
        if (I->getOpcode() == unsigned(Cast->getOpcode()) &&
            cast<CastInst>(I)->getSrcTy() == Cast->getSrcTy())
          return 0;
    }
  }

  // TODO: Allow non-throughput costs that aren't binary.
  auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
    if (CostKind != TTI::TCK_RecipThroughput)
      return Cost == 0 ? 0 : 1;
    return Cost;
  };

  EVT SrcTy = TLI->getValueType(DL, Src);
  EVT DstTy = TLI->getValueType(DL, Dst);

  // The table below is keyed on simple (single-MVT) types; anything else
  // goes to the base implementation.
  if (!SrcTy.isSimple() || !DstTy.isSimple())
    return AdjustCost(
        BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));

  // Each entry gives the approximate instruction count for converting the
  // source MVT to the destination MVT with the given ISD opcode.
  static const TypeConversionCostTblEntry
  ConversionTbl[] = {
    { ISD::TRUNCATE, MVT::v4i16, MVT::v4i32,  1 },
    { ISD::TRUNCATE, MVT::v4i32, MVT::v4i64,  0 },
    { ISD::TRUNCATE, MVT::v8i8,  MVT::v8i32,  3 },
    { ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 6 },

    // Truncations on nxvmiN
    { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 1 },
    { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 1 },
    { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 1 },
    { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 1 },
    { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 1 },
    { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 2 },
    { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 1 },
    { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 3 },
    { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 5 },
    { ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 1 },
    { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 1 },
    { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 1 },
    { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 1 },
    { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 2 },
    { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 3 },
    { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 6 },

    // The number of shll instructions for the extension.
    { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
    { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
    { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
    { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
    { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
    { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
    { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
    { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
    { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
    { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
    { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
    { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
    { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
    { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
    { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
    { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6 },

    // LowerVectorINT_TO_FP:
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },

    // Complex: to v2f32
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },

    // Complex: to v4f32
    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8,  4 },
    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8,  3 },
    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },

    // Complex: to v8f32
    { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
    { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
    { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
    { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },

    // Complex: to v16f32
    { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
    { ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },

    // Complex: to v2f64
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },


    // LowerVectorFP_TO_INT
    { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1 },
    { ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1 },
    { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1 },
    { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1 },
    { ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1 },
    { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1 },

    // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext).
    { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2 },
    { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1 },
    { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f32, 1 },
    { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2 },
    { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1 },
    { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f32, 1 },

    // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2
    { ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2 },
    { ISD::FP_TO_SINT, MVT::v4i8,  MVT::v4f32, 2 },
    { ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2 },
    { ISD::FP_TO_UINT, MVT::v4i8,  MVT::v4f32, 2 },

    // Complex, from nxv2f32.
    { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },

    // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
    { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2 },
    { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2 },
    { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f64, 2 },
    { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2 },
    { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2 },
    { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f64, 2 },

    // Complex, from nxv2f64.
    { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },

    // Complex, from nxv4f32.
    { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
    { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
    { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },

    // Complex, from nxv8f64. Illegal -> illegal conversions not required.
    { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
    { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
    { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
    { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },

    // Complex, from nxv4f64. Illegal -> illegal conversions not required.
    { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
    { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
    { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
    { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
    { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
    { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },

    // Complex, from nxv8f32. Illegal -> illegal conversions not required.
    { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
    { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
    { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
    { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },

    // Complex, from nxv8f16.
    { ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
    { ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
    { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
    { ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
    { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },

    // Complex, from nxv4f16.
    { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
    { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
    { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },

    // Complex, from nxv2f16.
    { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },

    // Truncate from nxvmf32 to nxvmf16.
    { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1 },
    { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1 },
    { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3 },

    // Truncate from nxvmf64 to nxvmf16.
    { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1 },
    { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3 },
    { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7 },

    // Truncate from nxvmf64 to nxvmf32.
    { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 },
    { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3 },
    { ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6 },

    // Extend from nxvmf16 to nxvmf32.
    { ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
    { ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
    { ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},

    // Extend from nxvmf16 to nxvmf64.
    { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
    { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
    { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},

    // Extend from nxvmf32 to nxvmf64.
    { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
    { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
    { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},

  };

  if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
                                                 DstTy.getSimpleVT(),
                                                 SrcTy.getSimpleVT()))
    return AdjustCost(Entry->Cost);

  // No table entry: defer to the target-independent cost model.
  return AdjustCost(
      BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
}
1474 
1475 InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode,
1476                                                          Type *Dst,
1477                                                          VectorType *VecTy,
1478                                                          unsigned Index) {
1479 
1480   // Make sure we were given a valid extend opcode.
1481   assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
1482          "Invalid opcode");
1483 
1484   // We are extending an element we extract from a vector, so the source type
1485   // of the extend is the element type of the vector.
1486   auto *Src = VecTy->getElementType();
1487 
1488   // Sign- and zero-extends are for integer types only.
1489   assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
1490 
1491   // Get the cost for the extract. We compute the cost (if any) for the extend
1492   // below.
1493   InstructionCost Cost =
1494       getVectorInstrCost(Instruction::ExtractElement, VecTy, Index);
1495 
1496   // Legalize the types.
1497   auto VecLT = TLI->getTypeLegalizationCost(DL, VecTy);
1498   auto DstVT = TLI->getValueType(DL, Dst);
1499   auto SrcVT = TLI->getValueType(DL, Src);
1500   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1501 
1502   // If the resulting type is still a vector and the destination type is legal,
1503   // we may get the extension for free. If not, get the default cost for the
1504   // extend.
1505   if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
1506     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1507                                    CostKind);
1508 
1509   // The destination type should be larger than the element type. If not, get
1510   // the default cost for the extend.
1511   if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits())
1512     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1513                                    CostKind);
1514 
1515   switch (Opcode) {
1516   default:
1517     llvm_unreachable("Opcode should be either SExt or ZExt");
1518 
1519   // For sign-extends, we only need a smov, which performs the extension
1520   // automatically.
1521   case Instruction::SExt:
1522     return Cost;
1523 
1524   // For zero-extends, the extend is performed automatically by a umov unless
1525   // the destination type is i64 and the element type is i8 or i16.
1526   case Instruction::ZExt:
1527     if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
1528       return Cost;
1529   }
1530 
1531   // If we are unable to perform the extend for free, get the default cost.
1532   return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1533                                  CostKind);
1534 }
1535 
1536 InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
1537                                                TTI::TargetCostKind CostKind,
1538                                                const Instruction *I) {
1539   if (CostKind != TTI::TCK_RecipThroughput)
1540     return Opcode == Instruction::PHI ? 0 : 1;
1541   assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
1542   // Branches are assumed to be predicted.
1543   return 0;
1544 }
1545 
1546 InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
1547                                                    unsigned Index) {
1548   assert(Val->isVectorTy() && "This must be a vector type");
1549 
1550   if (Index != -1U) {
1551     // Legalize the type.
1552     std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Val);
1553 
1554     // This type is legalized to a scalar type.
1555     if (!LT.second.isVector())
1556       return 0;
1557 
1558     // The type may be split. Normalize the index to the new type.
1559     unsigned Width = LT.second.getVectorNumElements();
1560     Index = Index % Width;
1561 
1562     // The element at index zero is already inside the vector.
1563     if (Index == 0)
1564       return 0;
1565   }
1566 
1567   // All other insert/extracts cost this much.
1568   return ST->getVectorInsertExtractBaseCost();
1569 }
1570 
/// Cost of an arithmetic instruction. Only reciprocal throughput gets
/// AArch64-specific modeling; other cost kinds defer to the base
/// implementation. Special-cases divisions by constants, scalarized vector
/// divides, v2i64 multiplies, and ops that are 'custom' but effectively legal.
InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
    unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
    TTI::OperandValueKind Opd1Info, TTI::OperandValueKind Opd2Info,
    TTI::OperandValueProperties Opd1PropInfo,
    TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
    const Instruction *CxtI) {
  // TODO: Handle more cost kinds.
  if (CostKind != TTI::TCK_RecipThroughput)
    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
                                         Opd2Info, Opd1PropInfo,
                                         Opd2PropInfo, Args, CxtI);

  // Legalize the type.
  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);

  // If the instruction is a widening instruction (e.g., uaddl, saddw, etc.),
  // add in the widening overhead specified by the sub-target. Since the
  // extends feeding widening instructions are performed automatically, they
  // aren't present in the generated code and have a zero cost. By adding a
  // widening overhead here, we attach the total cost of the combined operation
  // to the widening instruction.
  InstructionCost Cost = 0;
  if (isWideningInstruction(Ty, Opcode, Args))
    Cost += ST->getWideningBaseCost();

  int ISD = TLI->InstructionOpcodeToISD(Opcode);

  switch (ISD) {
  default:
    // No special handling: widening overhead (if any) plus the generic cost.
    return Cost + BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
                                                Opd2Info,
                                                Opd1PropInfo, Opd2PropInfo);
  case ISD::SDIV:
    if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue &&
        Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) {
      // On AArch64, scalar signed division by constants power-of-two are
      // normally expanded to the sequence ADD + CMP + SELECT + SRA.
      // The OperandValue properties many not be same as that of previous
      // operation; conservatively assume OP_None.
      Cost += getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
                                     Opd1Info, Opd2Info,
                                     TargetTransformInfo::OP_None,
                                     TargetTransformInfo::OP_None);
      Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
                                     Opd1Info, Opd2Info,
                                     TargetTransformInfo::OP_None,
                                     TargetTransformInfo::OP_None);
      Cost += getArithmeticInstrCost(Instruction::Select, Ty, CostKind,
                                     Opd1Info, Opd2Info,
                                     TargetTransformInfo::OP_None,
                                     TargetTransformInfo::OP_None);
      Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
                                     Opd1Info, Opd2Info,
                                     TargetTransformInfo::OP_None,
                                     TargetTransformInfo::OP_None);
      return Cost;
    }
    // Signed division by a non-power-of-2 constant is costed like the
    // unsigned case below.
    LLVM_FALLTHROUGH;
  case ISD::UDIV:
    if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue) {
      auto VT = TLI->getValueType(DL, Ty);
      if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
        // Vector signed division by constant are expanded to the
        // sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division
        // to MULHS + SUB + SRL + ADD + SRL.
        InstructionCost MulCost = getArithmeticInstrCost(
            Instruction::Mul, Ty, CostKind, Opd1Info, Opd2Info,
            TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
        InstructionCost AddCost = getArithmeticInstrCost(
            Instruction::Add, Ty, CostKind, Opd1Info, Opd2Info,
            TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
        InstructionCost ShrCost = getArithmeticInstrCost(
            Instruction::AShr, Ty, CostKind, Opd1Info, Opd2Info,
            TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
        // Two multiplies, two adds/subs, two shifts, plus one extra
        // instruction, matching the expansions listed above.
        return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
      }
    }

    // No cheap constant-division expansion: start from the generic cost.
    Cost += BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
                                          Opd2Info,
                                          Opd1PropInfo, Opd2PropInfo);
    if (Ty->isVectorTy()) {
      // On AArch64, vector divisions are not supported natively and are
      // expanded into scalar divisions of each pair of elements.
      Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty, CostKind,
                                     Opd1Info, Opd2Info, Opd1PropInfo,
                                     Opd2PropInfo);
      Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind,
                                     Opd1Info, Opd2Info, Opd1PropInfo,
                                     Opd2PropInfo);
      // TODO: if one of the arguments is scalar, then it's not necessary to
      // double the cost of handling the vector elements.
      Cost += Cost;
    }
    return Cost;

  case ISD::MUL:
    if (LT.second != MVT::v2i64)
      return (Cost + 1) * LT.first;
    // Since we do not have a MUL.2d instruction, a mul <2 x i64> is expensive
    // as elements are extracted from the vectors and the muls scalarized.
    // As getScalarizationOverhead is a bit too pessimistic, we estimate the
    // cost for a i64 vector directly here, which is:
    // - four i64 extracts,
    // - two i64 inserts, and
    // - two muls.
    // So, for a v2i64 with LT.First = 1 the cost is 8, and for a v4i64 with
    // LT.first = 2 the cost is 16.
    return LT.first * 8;
  case ISD::ADD:
  case ISD::XOR:
  case ISD::OR:
  case ISD::AND:
    // These nodes are marked as 'custom' for combining purposes only.
    // We know that they are legal. See LowerAdd in ISelLowering.
    return (Cost + 1) * LT.first;

  case ISD::FADD:
  case ISD::FSUB:
  case ISD::FMUL:
  case ISD::FDIV:
  case ISD::FNEG:
    // These nodes are marked as 'custom' just to lower them to SVE.
    // We know said lowering will incur no additional cost.
    if (!Ty->getScalarType()->isFP128Ty())
      return (Cost + 2) * LT.first;

    // fp128 is not covered by the lowering above; use the generic cost.
    return Cost + BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
                                                Opd2Info,
                                                Opd1PropInfo, Opd2PropInfo);
  }
}
1703 
1704 InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
1705                                                           ScalarEvolution *SE,
1706                                                           const SCEV *Ptr) {
1707   // Address computations in vectorized code with non-consecutive addresses will
1708   // likely result in more instructions compared to scalar code where the
1709   // computation can more often be merged into the index mode. The resulting
1710   // extra micro-ops can significantly decrease throughput.
1711   unsigned NumVectorInstToHideOverhead = 10;
1712   int MaxMergeDistance = 64;
1713 
1714   if (Ty->isVectorTy() && SE &&
1715       !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
1716     return NumVectorInstToHideOverhead;
1717 
1718   // In many cases the address computation is not merged into the instruction
1719   // addressing mode.
1720   return 1;
1721 }
1722 
/// Cost of a compare or select instruction. Only reciprocal throughput gets
/// AArch64-specific modeling; other cost kinds use the base implementation.
InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
                                                   Type *CondTy,
                                                   CmpInst::Predicate VecPred,
                                                   TTI::TargetCostKind CostKind,
                                                   const Instruction *I) {
  // TODO: Handle other cost kinds.
  if (CostKind != TTI::TCK_RecipThroughput)
    return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
                                     I);

  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  // We don't lower some vector selects well that are wider than the register
  // width.
  if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
    // We would need this many instructions to hide the scalarization happening.
    const int AmortizationCost = 20;

    // If VecPred is not set, check if we can get a predicate from the context
    // instruction, if its type matches the requested ValTy.
    if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
      CmpInst::Predicate CurrentPred;
      if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
                            m_Value())))
        VecPred = CurrentPred;
    }
    // Check if we have a compare/select chain that can be lowered using CMxx &
    // BFI pair.
    if (CmpInst::isIntPredicate(VecPred)) {
      static const auto ValidMinMaxTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
                                          MVT::v8i16, MVT::v2i32, MVT::v4i32,
                                          MVT::v2i64};
      auto LT = TLI->getTypeLegalizationCost(DL, ValTy);
      // One instruction per legalized register's worth of elements.
      if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
        return LT.first;
    }

    // Hard-coded costs for selects keyed on (condition type, value type);
    // wide i64 selects are heavily penalized via AmortizationCost.
    static const TypeConversionCostTblEntry
    VectorSelectTbl[] = {
      { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
      { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
      { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
      { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
      { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
      { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
    };

    EVT SelCondTy = TLI->getValueType(DL, CondTy);
    EVT SelValTy = TLI->getValueType(DL, ValTy);
    if (SelCondTy.isSimple() && SelValTy.isSimple()) {
      if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
                                                     SelCondTy.getSimpleVT(),
                                                     SelValTy.getSimpleVT()))
        return Entry->Cost;
    }
  }
  // The base case handles scalable vectors fine for now, since it treats the
  // cost as 1 * legalization cost.
  return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
}
1782 
1783 AArch64TTIImpl::TTI::MemCmpExpansionOptions
1784 AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
1785   TTI::MemCmpExpansionOptions Options;
1786   if (ST->requiresStrictAlign()) {
1787     // TODO: Add cost modeling for strict align. Misaligned loads expand to
1788     // a bunch of instructions when strict align is enabled.
1789     return Options;
1790   }
1791   Options.AllowOverlappingLoads = true;
1792   Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
1793   Options.NumLoadsPerBlock = Options.MaxNumLoads;
1794   // TODO: Though vector loads usually perform well on AArch64, in some targets
1795   // they may wake up the FP unit, which raises the power consumption.  Perhaps
1796   // they could be used with no holds barred (-O3).
1797   Options.LoadSizes = {8, 4, 2, 1};
1798   return Options;
1799 }
1800 
1801 InstructionCost
1802 AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
1803                                       Align Alignment, unsigned AddressSpace,
1804                                       TTI::TargetCostKind CostKind) {
1805   if (useNeonVector(Src))
1806     return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
1807                                         CostKind);
1808   auto LT = TLI->getTypeLegalizationCost(DL, Src);
1809   if (!LT.first.isValid())
1810     return InstructionCost::getInvalid();
1811 
1812   // The code-generator is currently not able to handle scalable vectors
1813   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
1814   // it. This change will be removed when code-generation for these types is
1815   // sufficiently reliable.
1816   if (cast<VectorType>(Src)->getElementCount() == ElementCount::getScalable(1))
1817     return InstructionCost::getInvalid();
1818 
1819   return LT.first * 2;
1820 }
1821 
1822 static unsigned getSVEGatherScatterOverhead(unsigned Opcode) {
1823   return Opcode == Instruction::Load ? SVEGatherOverhead : SVEScatterOverhead;
1824 }
1825 
1826 InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
1827     unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
1828     Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
1829   if (useNeonVector(DataTy))
1830     return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
1831                                          Alignment, CostKind, I);
1832   auto *VT = cast<VectorType>(DataTy);
1833   auto LT = TLI->getTypeLegalizationCost(DL, DataTy);
1834   if (!LT.first.isValid())
1835     return InstructionCost::getInvalid();
1836 
1837   // The code-generator is currently not able to handle scalable vectors
1838   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
1839   // it. This change will be removed when code-generation for these types is
1840   // sufficiently reliable.
1841   if (cast<VectorType>(DataTy)->getElementCount() ==
1842       ElementCount::getScalable(1))
1843     return InstructionCost::getInvalid();
1844 
1845   ElementCount LegalVF = LT.second.getVectorElementCount();
1846   InstructionCost MemOpCost =
1847       getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind, I);
1848   // Add on an overhead cost for using gathers/scatters.
1849   // TODO: At the moment this is applied unilaterally for all CPUs, but at some
1850   // point we may want a per-CPU overhead.
1851   MemOpCost *= getSVEGatherScatterOverhead(Opcode);
1852   return LT.first * MemOpCost * getMaxNumElements(LegalVF);
1853 }
1854 
1855 bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
1856   return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
1857 }
1858 
/// Cost of a load or store of type \p Ty. Models slow misaligned 128-bit
/// stores and scalarized truncating-store/extending-load NEON accesses on
/// top of the legalized type cost.
InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
                                                MaybeAlign Alignment,
                                                unsigned AddressSpace,
                                                TTI::TargetCostKind CostKind,
                                                const Instruction *I) {
  EVT VT = TLI->getValueType(DL, Ty, true);
  // Type legalization can't handle structs
  if (VT == MVT::Other)
    return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
                                  CostKind);

  auto LT = TLI->getTypeLegalizationCost(DL, Ty);
  if (!LT.first.isValid())
    return InstructionCost::getInvalid();

  // The code-generator is currently not able to handle scalable vectors
  // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
  // it. This change will be removed when code-generation for these types is
  // sufficiently reliable.
  if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
    if (VTy->getElementCount() == ElementCount::getScalable(1))
      return InstructionCost::getInvalid();

  // TODO: consider latency as well for TCK_SizeAndLatency.
  if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
    return LT.first;

  // Any remaining non-throughput cost kind is charged a flat 1.
  if (CostKind != TTI::TCK_RecipThroughput)
    return 1;

  if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
      LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) {
    // Unaligned stores are extremely inefficient. We don't split all
    // unaligned 128-bit stores because the negative impact that has shown in
    // practice on inlined block copy code.
    // We make such stores expensive so that we will only vectorize if there
    // are 6 other instructions getting vectorized.
    const int AmortizationCost = 6;

    return LT.first * 2 * AmortizationCost;
  }

  // Check truncating stores and extending loads.
  if (useNeonVector(Ty) &&
      Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
    // v4i8 types are lowered to scalar a load/store and sshll/xtn.
    if (VT == MVT::v4i8)
      return 2;
    // Otherwise we need to scalarize.
    return cast<FixedVectorType>(Ty)->getNumElements() * 2;
  }

  return LT.first;
}
1913 
1914 InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
1915     unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
1916     Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
1917     bool UseMaskForCond, bool UseMaskForGaps) {
1918   assert(Factor >= 2 && "Invalid interleave factor");
1919   auto *VecVTy = cast<FixedVectorType>(VecTy);
1920 
1921   if (!UseMaskForCond && !UseMaskForGaps &&
1922       Factor <= TLI->getMaxSupportedInterleaveFactor()) {
1923     unsigned NumElts = VecVTy->getNumElements();
1924     auto *SubVecTy =
1925         FixedVectorType::get(VecTy->getScalarType(), NumElts / Factor);
1926 
1927     // ldN/stN only support legal vector types of size 64 or 128 in bits.
1928     // Accesses having vector types that are a multiple of 128 bits can be
1929     // matched to more than one ldN/stN instruction.
1930     bool UseScalable;
1931     if (NumElts % Factor == 0 &&
1932         TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable))
1933       return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable);
1934   }
1935 
1936   return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
1937                                            Alignment, AddressSpace, CostKind,
1938                                            UseMaskForCond, UseMaskForGaps);
1939 }
1940 
/// Cost of keeping the given values live across a call: each 128-bit fixed
/// vector value is charged one store plus one load (save and restore).
InstructionCost
AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
  InstructionCost Cost = 0;
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  for (auto *I : Tys) {
    if (!I->isVectorTy())
      continue;
    // Only exactly-128-bit fixed vectors are charged; other vector sizes add
    // nothing here.
    if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() ==
        128)
      // NOTE(review): Align(128) is an alignment of 128 *bytes*; 16 bytes
      // (128 bits) may have been intended — confirm before relying on it.
      Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) +
              getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind);
  }
  return Cost;
}
1955 
/// Return the subtarget's maximum interleave factor; the requested
/// vectorization factor \p VF does not affect the result.
unsigned AArch64TTIImpl::getMaxInterleaveFactor(unsigned VF) {
  return ST->getMaxInterleaveFactor();
}
1959 
1960 // For Falkor, we want to avoid having too many strided loads in a loop since
1961 // that can exhaust the HW prefetcher resources.  We adjust the unroller
1962 // MaxCount preference below to attempt to ensure unrolling doesn't create too
1963 // many strided loads.
1964 static void
1965 getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
1966                               TargetTransformInfo::UnrollingPreferences &UP) {
1967   enum { MaxStridedLoads = 7 };
1968   auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) {
1969     int StridedLoads = 0;
1970     // FIXME? We could make this more precise by looking at the CFG and
1971     // e.g. not counting loads in each side of an if-then-else diamond.
1972     for (const auto BB : L->blocks()) {
1973       for (auto &I : *BB) {
1974         LoadInst *LMemI = dyn_cast<LoadInst>(&I);
1975         if (!LMemI)
1976           continue;
1977 
1978         Value *PtrValue = LMemI->getPointerOperand();
1979         if (L->isLoopInvariant(PtrValue))
1980           continue;
1981 
1982         const SCEV *LSCEV = SE.getSCEV(PtrValue);
1983         const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
1984         if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
1985           continue;
1986 
1987         // FIXME? We could take pairing of unrolled load copies into account
1988         // by looking at the AddRec, but we would probably have to limit this
1989         // to loops with no stores or other memory optimization barriers.
1990         ++StridedLoads;
1991         // We've seen enough strided loads that seeing more won't make a
1992         // difference.
1993         if (StridedLoads > MaxStridedLoads / 2)
1994           return StridedLoads;
1995       }
1996     }
1997     return StridedLoads;
1998   };
1999 
2000   int StridedLoads = countStridedLoads(L, SE);
2001   LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads
2002                     << " strided loads\n");
2003   // Pick the largest power of 2 unroll count that won't result in too many
2004   // strided loads.
2005   if (StridedLoads) {
2006     UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads);
2007     LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to "
2008                       << UP.MaxCount << '\n');
2009   }
2010 }
2011 
2012 void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
2013                                              TTI::UnrollingPreferences &UP,
2014                                              OptimizationRemarkEmitter *ORE) {
2015   // Enable partial unrolling and runtime unrolling.
2016   BaseT::getUnrollingPreferences(L, SE, UP, ORE);
2017 
2018   UP.UpperBound = true;
2019 
2020   // For inner loop, it is more likely to be a hot one, and the runtime check
2021   // can be promoted out from LICM pass, so the overhead is less, let's try
2022   // a larger threshold to unroll more loops.
2023   if (L->getLoopDepth() > 1)
2024     UP.PartialThreshold *= 2;
2025 
2026   // Disable partial & runtime unrolling on -Os.
2027   UP.PartialOptSizeThreshold = 0;
2028 
2029   if (ST->getProcFamily() == AArch64Subtarget::Falkor &&
2030       EnableFalkorHWPFUnrollFix)
2031     getFalkorUnrollingPreferences(L, SE, UP);
2032 
2033   // Scan the loop: don't unroll loops with calls as this could prevent
2034   // inlining. Don't unroll vector loops either, as they don't benefit much from
2035   // unrolling.
2036   for (auto *BB : L->getBlocks()) {
2037     for (auto &I : *BB) {
2038       // Don't unroll vectorised loop.
2039       if (I.getType()->isVectorTy())
2040         return;
2041 
2042       if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
2043         if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
2044           if (!isLoweredToCall(F))
2045             continue;
2046         }
2047         return;
2048       }
2049     }
2050   }
2051 
2052   // Enable runtime unrolling for in-order models
2053   // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so by
2054   // checking for that case, we can ensure that the default behaviour is
2055   // unchanged
2056   if (ST->getProcFamily() != AArch64Subtarget::Others &&
2057       !ST->getSchedModel().isOutOfOrder()) {
2058     UP.Runtime = true;
2059     UP.Partial = true;
2060     UP.UnrollRemainder = true;
2061     UP.DefaultUnrollRuntimeCount = 4;
2062 
2063     UP.UnrollAndJam = true;
2064     UP.UnrollAndJamInnerLoopThreshold = 60;
2065   }
2066 }
2067 
/// Use the common (target-independent) loop-peeling preferences unmodified.
void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
                                           TTI::PeelingPreferences &PP) {
  BaseT::getPeelingPreferences(L, SE, PP);
}
2072 
2073 Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
2074                                                          Type *ExpectedType) {
2075   switch (Inst->getIntrinsicID()) {
2076   default:
2077     return nullptr;
2078   case Intrinsic::aarch64_neon_st2:
2079   case Intrinsic::aarch64_neon_st3:
2080   case Intrinsic::aarch64_neon_st4: {
2081     // Create a struct type
2082     StructType *ST = dyn_cast<StructType>(ExpectedType);
2083     if (!ST)
2084       return nullptr;
2085     unsigned NumElts = Inst->arg_size() - 1;
2086     if (ST->getNumElements() != NumElts)
2087       return nullptr;
2088     for (unsigned i = 0, e = NumElts; i != e; ++i) {
2089       if (Inst->getArgOperand(i)->getType() != ST->getElementType(i))
2090         return nullptr;
2091     }
2092     Value *Res = UndefValue::get(ExpectedType);
2093     IRBuilder<> Builder(Inst);
2094     for (unsigned i = 0, e = NumElts; i != e; ++i) {
2095       Value *L = Inst->getArgOperand(i);
2096       Res = Builder.CreateInsertValue(Res, L, i);
2097     }
2098     return Res;
2099   }
2100   case Intrinsic::aarch64_neon_ld2:
2101   case Intrinsic::aarch64_neon_ld3:
2102   case Intrinsic::aarch64_neon_ld4:
2103     if (Inst->getType() == ExpectedType)
2104       return Inst;
2105     return nullptr;
2106   }
2107 }
2108 
2109 bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
2110                                         MemIntrinsicInfo &Info) {
2111   switch (Inst->getIntrinsicID()) {
2112   default:
2113     break;
2114   case Intrinsic::aarch64_neon_ld2:
2115   case Intrinsic::aarch64_neon_ld3:
2116   case Intrinsic::aarch64_neon_ld4:
2117     Info.ReadMem = true;
2118     Info.WriteMem = false;
2119     Info.PtrVal = Inst->getArgOperand(0);
2120     break;
2121   case Intrinsic::aarch64_neon_st2:
2122   case Intrinsic::aarch64_neon_st3:
2123   case Intrinsic::aarch64_neon_st4:
2124     Info.ReadMem = false;
2125     Info.WriteMem = true;
2126     Info.PtrVal = Inst->getArgOperand(Inst->arg_size() - 1);
2127     break;
2128   }
2129 
2130   switch (Inst->getIntrinsicID()) {
2131   default:
2132     return false;
2133   case Intrinsic::aarch64_neon_ld2:
2134   case Intrinsic::aarch64_neon_st2:
2135     Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
2136     break;
2137   case Intrinsic::aarch64_neon_ld3:
2138   case Intrinsic::aarch64_neon_st3:
2139     Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
2140     break;
2141   case Intrinsic::aarch64_neon_ld4:
2142   case Intrinsic::aarch64_neon_st4:
2143     Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
2144     break;
2145   }
2146   return true;
2147 }
2148 
2149 /// See if \p I should be considered for address type promotion. We check if \p
2150 /// I is a sext with right type and used in memory accesses. If it used in a
2151 /// "complex" getelementptr, we allow it to be promoted without finding other
2152 /// sext instructions that sign extended the same initial value. A getelementptr
2153 /// is considered as "complex" if it has more than 2 operands.
2154 bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
2155     const Instruction &I, bool &AllowPromotionWithoutCommonHeader) {
2156   bool Considerable = false;
2157   AllowPromotionWithoutCommonHeader = false;
2158   if (!isa<SExtInst>(&I))
2159     return false;
2160   Type *ConsideredSExtType =
2161       Type::getInt64Ty(I.getParent()->getParent()->getContext());
2162   if (I.getType() != ConsideredSExtType)
2163     return false;
2164   // See if the sext is the one with the right type and used in at least one
2165   // GetElementPtrInst.
2166   for (const User *U : I.users()) {
2167     if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
2168       Considerable = true;
2169       // A getelementptr is considered as "complex" if it has more than 2
2170       // operands. We will promote a SExt used in such complex GEP as we
2171       // expect some computation to be merged if they are done on 64 bits.
2172       if (GEPInst->getNumOperands() > 2) {
2173         AllowPromotionWithoutCommonHeader = true;
2174         break;
2175       }
2176     }
2177   }
2178   return Considerable;
2179 }
2180 
2181 bool AArch64TTIImpl::isLegalToVectorizeReduction(
2182     const RecurrenceDescriptor &RdxDesc, ElementCount VF) const {
2183   if (!VF.isScalable())
2184     return true;
2185 
2186   Type *Ty = RdxDesc.getRecurrenceType();
2187   if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty))
2188     return false;
2189 
2190   switch (RdxDesc.getRecurrenceKind()) {
2191   case RecurKind::Add:
2192   case RecurKind::FAdd:
2193   case RecurKind::And:
2194   case RecurKind::Or:
2195   case RecurKind::Xor:
2196   case RecurKind::SMin:
2197   case RecurKind::SMax:
2198   case RecurKind::UMin:
2199   case RecurKind::UMax:
2200   case RecurKind::FMin:
2201   case RecurKind::FMax:
2202   case RecurKind::SelectICmp:
2203   case RecurKind::SelectFCmp:
2204   case RecurKind::FMulAdd:
2205     return true;
2206   default:
2207     return false;
2208   }
2209 }
2210 
2211 InstructionCost
2212 AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
2213                                        bool IsUnsigned,
2214                                        TTI::TargetCostKind CostKind) {
2215   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
2216 
2217   if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
2218     return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
2219 
2220   assert((isa<ScalableVectorType>(Ty) == isa<ScalableVectorType>(CondTy)) &&
2221          "Both vector needs to be equally scalable");
2222 
2223   InstructionCost LegalizationCost = 0;
2224   if (LT.first > 1) {
2225     Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
2226     unsigned MinMaxOpcode =
2227         Ty->isFPOrFPVectorTy()
2228             ? Intrinsic::maxnum
2229             : (IsUnsigned ? Intrinsic::umin : Intrinsic::smin);
2230     IntrinsicCostAttributes Attrs(MinMaxOpcode, LegalVTy, {LegalVTy, LegalVTy});
2231     LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1);
2232   }
2233 
2234   return LegalizationCost + /*Cost of horizontal reduction*/ 2;
2235 }
2236 
2237 InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
2238     unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) {
2239   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
2240   InstructionCost LegalizationCost = 0;
2241   if (LT.first > 1) {
2242     Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
2243     LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
2244     LegalizationCost *= LT.first - 1;
2245   }
2246 
2247   int ISD = TLI->InstructionOpcodeToISD(Opcode);
2248   assert(ISD && "Invalid opcode");
2249   // Add the final reduction cost for the legal horizontal reduction
2250   switch (ISD) {
2251   case ISD::ADD:
2252   case ISD::AND:
2253   case ISD::OR:
2254   case ISD::XOR:
2255   case ISD::FADD:
2256     return LegalizationCost + 2;
2257   default:
2258     return InstructionCost::getInvalid();
2259   }
2260 }
2261 
/// Cost model for vector arithmetic reductions. Ordered (strictly sequenced)
/// reductions, unordered scalable-vector reductions and fixed-width NEON
/// reductions each take a separate path below.
InstructionCost
AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
                                           Optional<FastMathFlags> FMF,
                                           TTI::TargetCostKind CostKind) {
  if (TTI::requiresOrderedReduction(FMF)) {
    if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) {
      InstructionCost BaseCost =
          BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
      // Add on extra cost to reflect the extra overhead on some CPUs. We still
      // end up vectorizing for more computationally intensive loops.
      return BaseCost + FixedVTy->getNumElements();
    }

    // Ordered scalable reductions are only supported for FAdd.
    if (Opcode != Instruction::FAdd)
      return InstructionCost::getInvalid();

    // Model the ordered scalable FAdd as one scalar add per element, using
    // the maximum element count the scalable type may hold.
    auto *VTy = cast<ScalableVectorType>(ValTy);
    InstructionCost Cost =
        getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind);
    Cost *= getMaxNumElements(VTy->getElementCount());
    return Cost;
  }

  // Unordered scalable reductions go through the SVE-specific helper.
  if (isa<ScalableVectorType>(ValTy))
    return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind);

  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
  MVT MTy = LT.second;
  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  assert(ISD && "Invalid opcode");

  // Horizontal adds can use the 'addv' instruction. We model the cost of these
  // instructions as twice a normal vector add, plus 1 for each legalization
  // step (LT.first). This is the only arithmetic vector reduction operation for
  // which we have an instruction.
  // OR, XOR and AND costs should match the codegen from:
  // OR: llvm/test/CodeGen/AArch64/reduce-or.ll
  // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll
  // AND: llvm/test/CodeGen/AArch64/reduce-and.ll
  static const CostTblEntry CostTblNoPairwise[]{
      {ISD::ADD, MVT::v8i8,   2},
      {ISD::ADD, MVT::v16i8,  2},
      {ISD::ADD, MVT::v4i16,  2},
      {ISD::ADD, MVT::v8i16,  2},
      {ISD::ADD, MVT::v4i32,  2},
      {ISD::OR,  MVT::v8i8,  15},
      {ISD::OR,  MVT::v16i8, 17},
      {ISD::OR,  MVT::v4i16,  7},
      {ISD::OR,  MVT::v8i16,  9},
      {ISD::OR,  MVT::v2i32,  3},
      {ISD::OR,  MVT::v4i32,  5},
      {ISD::OR,  MVT::v2i64,  3},
      {ISD::XOR, MVT::v8i8,  15},
      {ISD::XOR, MVT::v16i8, 17},
      {ISD::XOR, MVT::v4i16,  7},
      {ISD::XOR, MVT::v8i16,  9},
      {ISD::XOR, MVT::v2i32,  3},
      {ISD::XOR, MVT::v4i32,  5},
      {ISD::XOR, MVT::v2i64,  3},
      {ISD::AND, MVT::v8i8,  15},
      {ISD::AND, MVT::v16i8, 17},
      {ISD::AND, MVT::v4i16,  7},
      {ISD::AND, MVT::v8i16,  9},
      {ISD::AND, MVT::v2i32,  3},
      {ISD::AND, MVT::v4i32,  5},
      {ISD::AND, MVT::v2i64,  3},
  };
  switch (ISD) {
  default:
    break;
  case ISD::ADD:
    if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
      return (LT.first - 1) + Entry->Cost;
    break;
  case ISD::XOR:
  case ISD::AND:
  case ISD::OR:
    const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy);
    if (!Entry)
      break;
    // Only use the table cost for a non-i1 element type with a power-of-2
    // element count that legalization does not widen; everything else falls
    // through to the generic cost model.
    auto *ValVTy = cast<FixedVectorType>(ValTy);
    if (!ValVTy->getElementType()->isIntegerTy(1) &&
        MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
        isPowerOf2_32(ValVTy->getNumElements())) {
      InstructionCost ExtraCost = 0;
      if (LT.first != 1) {
        // Type needs to be split, so there is an extra cost of LT.first - 1
        // arithmetic ops.
        auto *Ty = FixedVectorType::get(ValTy->getElementType(),
                                        MTy.getVectorNumElements());
        ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
        ExtraCost *= LT.first - 1;
      }
      return Entry->Cost + ExtraCost;
    }
    break;
  }
  return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
}
2361 
/// Cost of a TTI::SK_Splice shuffle on scalable vector type \p Tp with splice
/// point \p Index. i1 vectors are costed on the promoted type they are
/// lowered to, and a negative index incurs extra compare/select work.
InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
  // Each legal SVE vector type below can be spliced for unit cost.
  static const CostTblEntry ShuffleTbl[] = {
      { TTI::SK_Splice, MVT::nxv16i8,  1 },
      { TTI::SK_Splice, MVT::nxv8i16,  1 },
      { TTI::SK_Splice, MVT::nxv4i32,  1 },
      { TTI::SK_Splice, MVT::nxv2i64,  1 },
      { TTI::SK_Splice, MVT::nxv2f16,  1 },
      { TTI::SK_Splice, MVT::nxv4f16,  1 },
      { TTI::SK_Splice, MVT::nxv8f16,  1 },
      { TTI::SK_Splice, MVT::nxv2bf16, 1 },
      { TTI::SK_Splice, MVT::nxv4bf16, 1 },
      { TTI::SK_Splice, MVT::nxv8bf16, 1 },
      { TTI::SK_Splice, MVT::nxv2f32,  1 },
      { TTI::SK_Splice, MVT::nxv4f32,  1 },
      { TTI::SK_Splice, MVT::nxv2f64,  1 },
  };

  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
  Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext());
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  // i1 splices are performed on a promoted integer type; all costs below are
  // computed on that promoted type.
  EVT PromotedVT = LT.second.getScalarType() == MVT::i1
                       ? TLI->getPromotedVTForPredicate(EVT(LT.second))
                       : LT.second;
  Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext());
  InstructionCost LegalizationCost = 0;
  // A negative index is charged an extra compare + select, which the lowering
  // presumably uses to pick the trailing elements — see AArch64ISelLowering.
  if (Index < 0) {
    LegalizationCost =
        getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy,
                           CmpInst::BAD_ICMP_PREDICATE, CostKind) +
        getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy,
                           CmpInst::BAD_ICMP_PREDICATE, CostKind);
  }

  // Predicated splice are promoted when lowering. See AArch64ISelLowering.cpp
  // Cost performed on a promoted type.
  // Charge the zext into and trunc out of the promoted type.
  if (LT.second.getScalarType() == MVT::i1) {
    LegalizationCost +=
        getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy,
                         TTI::CastContextHint::None, CostKind) +
        getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy,
                         TTI::CastContextHint::None, CostKind);
  }
  const auto *Entry =
      CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT());
  assert(Entry && "Illegal Type for Splice");
  LegalizationCost += Entry->Cost;
  // Scale by the number of legal parts the type splits into.
  return LegalizationCost * LT.first;
}
2410 
/// Cost of a vector shuffle of kind \p Kind on type \p Tp. The mask (when
/// provided) is first used to refine the shuffle kind; known-cheap shuffles
/// are then costed from a table, scalable splices go through getSpliceCost,
/// and everything else falls back to the generic implementation.
InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
                                               VectorType *Tp,
                                               ArrayRef<int> Mask, int Index,
                                               VectorType *SubTp) {
  Kind = improveShuffleKindFromMask(Kind, Mask);
  if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose ||
      Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc ||
      Kind == TTI::SK_Reverse) {
    static const CostTblEntry ShuffleTbl[] = {
      // Broadcast shuffle kinds can be performed with 'dup'.
      { TTI::SK_Broadcast, MVT::v8i8,  1 },
      { TTI::SK_Broadcast, MVT::v16i8, 1 },
      { TTI::SK_Broadcast, MVT::v4i16, 1 },
      { TTI::SK_Broadcast, MVT::v8i16, 1 },
      { TTI::SK_Broadcast, MVT::v2i32, 1 },
      { TTI::SK_Broadcast, MVT::v4i32, 1 },
      { TTI::SK_Broadcast, MVT::v2i64, 1 },
      { TTI::SK_Broadcast, MVT::v2f32, 1 },
      { TTI::SK_Broadcast, MVT::v4f32, 1 },
      { TTI::SK_Broadcast, MVT::v2f64, 1 },
      // Transpose shuffle kinds can be performed with 'trn1/trn2' and
      // 'zip1/zip2' instructions.
      { TTI::SK_Transpose, MVT::v8i8,  1 },
      { TTI::SK_Transpose, MVT::v16i8, 1 },
      { TTI::SK_Transpose, MVT::v4i16, 1 },
      { TTI::SK_Transpose, MVT::v8i16, 1 },
      { TTI::SK_Transpose, MVT::v2i32, 1 },
      { TTI::SK_Transpose, MVT::v4i32, 1 },
      { TTI::SK_Transpose, MVT::v2i64, 1 },
      { TTI::SK_Transpose, MVT::v2f32, 1 },
      { TTI::SK_Transpose, MVT::v4f32, 1 },
      { TTI::SK_Transpose, MVT::v2f64, 1 },
      // Select shuffle kinds.
      // TODO: handle vXi8/vXi16.
      { TTI::SK_Select, MVT::v2i32, 1 }, // mov.
      { TTI::SK_Select, MVT::v4i32, 2 }, // rev+trn (or similar).
      { TTI::SK_Select, MVT::v2i64, 1 }, // mov.
      { TTI::SK_Select, MVT::v2f32, 1 }, // mov.
      { TTI::SK_Select, MVT::v4f32, 2 }, // rev+trn (or similar).
      { TTI::SK_Select, MVT::v2f64, 1 }, // mov.
      // PermuteSingleSrc shuffle kinds.
      { TTI::SK_PermuteSingleSrc, MVT::v2i32, 1 }, // mov.
      { TTI::SK_PermuteSingleSrc, MVT::v4i32, 3 }, // perfectshuffle worst case.
      { TTI::SK_PermuteSingleSrc, MVT::v2i64, 1 }, // mov.
      { TTI::SK_PermuteSingleSrc, MVT::v2f32, 1 }, // mov.
      { TTI::SK_PermuteSingleSrc, MVT::v4f32, 3 }, // perfectshuffle worst case.
      { TTI::SK_PermuteSingleSrc, MVT::v2f64, 1 }, // mov.
      { TTI::SK_PermuteSingleSrc, MVT::v4i16, 3 }, // perfectshuffle worst case.
      { TTI::SK_PermuteSingleSrc, MVT::v4f16, 3 }, // perfectshuffle worst case.
      { TTI::SK_PermuteSingleSrc, MVT::v4bf16, 3 }, // perfectshuffle worst case.
      { TTI::SK_PermuteSingleSrc, MVT::v8i16, 8 }, // constpool + load + tbl
      { TTI::SK_PermuteSingleSrc, MVT::v8f16, 8 }, // constpool + load + tbl
      { TTI::SK_PermuteSingleSrc, MVT::v8bf16, 8 }, // constpool + load + tbl
      { TTI::SK_PermuteSingleSrc, MVT::v8i8, 8 }, // constpool + load + tbl
      { TTI::SK_PermuteSingleSrc, MVT::v16i8, 8 }, // constpool + load + tbl
      // Reverse can be lowered with `rev`.
      { TTI::SK_Reverse, MVT::v2i32, 1 }, // mov.
      { TTI::SK_Reverse, MVT::v4i32, 2 }, // REV64; EXT
      { TTI::SK_Reverse, MVT::v2i64, 1 }, // mov.
      { TTI::SK_Reverse, MVT::v2f32, 1 }, // mov.
      { TTI::SK_Reverse, MVT::v4f32, 2 }, // REV64; EXT
      { TTI::SK_Reverse, MVT::v2f64, 1 }, // mov.
      // Broadcast shuffle kinds for scalable vectors
      { TTI::SK_Broadcast, MVT::nxv16i8,  1 },
      { TTI::SK_Broadcast, MVT::nxv8i16,  1 },
      { TTI::SK_Broadcast, MVT::nxv4i32,  1 },
      { TTI::SK_Broadcast, MVT::nxv2i64,  1 },
      { TTI::SK_Broadcast, MVT::nxv2f16,  1 },
      { TTI::SK_Broadcast, MVT::nxv4f16,  1 },
      { TTI::SK_Broadcast, MVT::nxv8f16,  1 },
      { TTI::SK_Broadcast, MVT::nxv2bf16, 1 },
      { TTI::SK_Broadcast, MVT::nxv4bf16, 1 },
      { TTI::SK_Broadcast, MVT::nxv8bf16, 1 },
      { TTI::SK_Broadcast, MVT::nxv2f32,  1 },
      { TTI::SK_Broadcast, MVT::nxv4f32,  1 },
      { TTI::SK_Broadcast, MVT::nxv2f64,  1 },
      { TTI::SK_Broadcast, MVT::nxv16i1,  1 },
      { TTI::SK_Broadcast, MVT::nxv8i1,   1 },
      { TTI::SK_Broadcast, MVT::nxv4i1,   1 },
      { TTI::SK_Broadcast, MVT::nxv2i1,   1 },
      // Handle the cases for vector.reverse with scalable vectors
      { TTI::SK_Reverse, MVT::nxv16i8,  1 },
      { TTI::SK_Reverse, MVT::nxv8i16,  1 },
      { TTI::SK_Reverse, MVT::nxv4i32,  1 },
      { TTI::SK_Reverse, MVT::nxv2i64,  1 },
      { TTI::SK_Reverse, MVT::nxv2f16,  1 },
      { TTI::SK_Reverse, MVT::nxv4f16,  1 },
      { TTI::SK_Reverse, MVT::nxv8f16,  1 },
      { TTI::SK_Reverse, MVT::nxv2bf16, 1 },
      { TTI::SK_Reverse, MVT::nxv4bf16, 1 },
      { TTI::SK_Reverse, MVT::nxv8bf16, 1 },
      { TTI::SK_Reverse, MVT::nxv2f32,  1 },
      { TTI::SK_Reverse, MVT::nxv4f32,  1 },
      { TTI::SK_Reverse, MVT::nxv2f64,  1 },
      { TTI::SK_Reverse, MVT::nxv16i1,  1 },
      { TTI::SK_Reverse, MVT::nxv8i1,   1 },
      { TTI::SK_Reverse, MVT::nxv4i1,   1 },
      { TTI::SK_Reverse, MVT::nxv2i1,   1 },
    };
    // Look up the cost of the legalized type and scale by the number of
    // legalization steps; types not in the table fall through to the base.
    std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
    if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second))
      return LT.first * Entry->Cost;
  }
  // Scalable splices have their own dedicated cost model.
  if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(Tp))
    return getSpliceCost(Tp, Index);
  return BaseT::getShuffleCost(Kind, Tp, Mask, Index, SubTp);
}
2518