1 //===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "AArch64TargetTransformInfo.h"
10 #include "AArch64ExpandImm.h"
11 #include "MCTargetDesc/AArch64AddressingModes.h"
12 #include "llvm/Analysis/IVDescriptors.h"
13 #include "llvm/Analysis/LoopInfo.h"
14 #include "llvm/Analysis/TargetTransformInfo.h"
15 #include "llvm/CodeGen/BasicTTIImpl.h"
16 #include "llvm/CodeGen/CostTable.h"
17 #include "llvm/CodeGen/TargetLowering.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/IntrinsicInst.h"
20 #include "llvm/IR/IntrinsicsAArch64.h"
21 #include "llvm/IR/PatternMatch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Transforms/InstCombine/InstCombiner.h"
24 #include <algorithm>
25 using namespace llvm;
26 using namespace llvm::PatternMatch;
27 
28 #define DEBUG_TYPE "aarch64tti"
29 
30 static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix",
31                                                cl::init(true), cl::Hidden);
32 
33 static cl::opt<unsigned> SVEGatherOverhead("sve-gather-overhead", cl::init(10),
34                                            cl::Hidden);
35 
36 static cl::opt<unsigned> SVEScatterOverhead("sve-scatter-overhead",
37                                             cl::init(10), cl::Hidden);
38 
39 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
40                                          const Function *Callee) const {
41   const TargetMachine &TM = getTLI()->getTargetMachine();
42 
43   const FeatureBitset &CallerBits =
44       TM.getSubtargetImpl(*Caller)->getFeatureBits();
45   const FeatureBitset &CalleeBits =
46       TM.getSubtargetImpl(*Callee)->getFeatureBits();
47 
48   // Inline a callee if its target-features are a subset of the callers
49   // target-features.
50   return (CallerBits & CalleeBits) == CalleeBits;
51 }
52 
53 /// Calculate the cost of materializing a 64-bit value. This helper
54 /// method might only calculate a fraction of a larger immediate. Therefore it
55 /// is valid to return a cost of ZERO.
56 InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) {
57   // Check if the immediate can be encoded within an instruction.
58   if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
59     return 0;
60 
61   if (Val < 0)
62     Val = ~Val;
63 
64   // Calculate how many moves we will need to materialize this constant.
65   SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
66   AArch64_IMM::expandMOVImm(Val, 64, Insn);
67   return Insn.size();
68 }
69 
70 /// Calculate the cost of materializing the given constant.
71 InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
72                                               TTI::TargetCostKind CostKind) {
73   assert(Ty->isIntegerTy());
74 
75   unsigned BitSize = Ty->getPrimitiveSizeInBits();
76   if (BitSize == 0)
77     return ~0U;
78 
79   // Sign-extend all constants to a multiple of 64-bit.
80   APInt ImmVal = Imm;
81   if (BitSize & 0x3f)
82     ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);
83 
84   // Split the constant into 64-bit chunks and calculate the cost for each
85   // chunk.
86   InstructionCost Cost = 0;
87   for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
88     APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
89     int64_t Val = Tmp.getSExtValue();
90     Cost += getIntImmCost(Val);
91   }
92   // We need at least one instruction to materialze the constant.
93   return std::max<InstructionCost>(1, Cost);
94 }
95 
96 InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
97                                                   const APInt &Imm, Type *Ty,
98                                                   TTI::TargetCostKind CostKind,
99                                                   Instruction *Inst) {
100   assert(Ty->isIntegerTy());
101 
102   unsigned BitSize = Ty->getPrimitiveSizeInBits();
103   // There is no cost model for constants with a bit size of 0. Return TCC_Free
104   // here, so that constant hoisting will ignore this constant.
105   if (BitSize == 0)
106     return TTI::TCC_Free;
107 
108   unsigned ImmIdx = ~0U;
109   switch (Opcode) {
110   default:
111     return TTI::TCC_Free;
112   case Instruction::GetElementPtr:
113     // Always hoist the base address of a GetElementPtr.
114     if (Idx == 0)
115       return 2 * TTI::TCC_Basic;
116     return TTI::TCC_Free;
117   case Instruction::Store:
118     ImmIdx = 0;
119     break;
120   case Instruction::Add:
121   case Instruction::Sub:
122   case Instruction::Mul:
123   case Instruction::UDiv:
124   case Instruction::SDiv:
125   case Instruction::URem:
126   case Instruction::SRem:
127   case Instruction::And:
128   case Instruction::Or:
129   case Instruction::Xor:
130   case Instruction::ICmp:
131     ImmIdx = 1;
132     break;
133   // Always return TCC_Free for the shift value of a shift instruction.
134   case Instruction::Shl:
135   case Instruction::LShr:
136   case Instruction::AShr:
137     if (Idx == 1)
138       return TTI::TCC_Free;
139     break;
140   case Instruction::Trunc:
141   case Instruction::ZExt:
142   case Instruction::SExt:
143   case Instruction::IntToPtr:
144   case Instruction::PtrToInt:
145   case Instruction::BitCast:
146   case Instruction::PHI:
147   case Instruction::Call:
148   case Instruction::Select:
149   case Instruction::Ret:
150   case Instruction::Load:
151     break;
152   }
153 
154   if (Idx == ImmIdx) {
155     int NumConstants = (BitSize + 63) / 64;
156     InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
157     return (Cost <= NumConstants * TTI::TCC_Basic)
158                ? static_cast<int>(TTI::TCC_Free)
159                : Cost;
160   }
161   return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
162 }
163 
164 InstructionCost
165 AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
166                                     const APInt &Imm, Type *Ty,
167                                     TTI::TargetCostKind CostKind) {
168   assert(Ty->isIntegerTy());
169 
170   unsigned BitSize = Ty->getPrimitiveSizeInBits();
171   // There is no cost model for constants with a bit size of 0. Return TCC_Free
172   // here, so that constant hoisting will ignore this constant.
173   if (BitSize == 0)
174     return TTI::TCC_Free;
175 
176   // Most (all?) AArch64 intrinsics do not support folding immediates into the
177   // selected instruction, so we compute the materialization cost for the
178   // immediate directly.
179   if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
180     return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
181 
182   switch (IID) {
183   default:
184     return TTI::TCC_Free;
185   case Intrinsic::sadd_with_overflow:
186   case Intrinsic::uadd_with_overflow:
187   case Intrinsic::ssub_with_overflow:
188   case Intrinsic::usub_with_overflow:
189   case Intrinsic::smul_with_overflow:
190   case Intrinsic::umul_with_overflow:
191     if (Idx == 1) {
192       int NumConstants = (BitSize + 63) / 64;
193       InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
194       return (Cost <= NumConstants * TTI::TCC_Basic)
195                  ? static_cast<int>(TTI::TCC_Free)
196                  : Cost;
197     }
198     break;
199   case Intrinsic::experimental_stackmap:
200     if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
201       return TTI::TCC_Free;
202     break;
203   case Intrinsic::experimental_patchpoint_void:
204   case Intrinsic::experimental_patchpoint_i64:
205     if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
206       return TTI::TCC_Free;
207     break;
208   case Intrinsic::experimental_gc_statepoint:
209     if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
210       return TTI::TCC_Free;
211     break;
212   }
213   return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
214 }
215 
216 TargetTransformInfo::PopcntSupportKind
217 AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
218   assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
219   if (TyWidth == 32 || TyWidth == 64)
220     return TTI::PSK_FastHardware;
221   // TODO: AArch64TargetLowering::LowerCTPOP() supports 128bit popcount.
222   return TTI::PSK_Software;
223 }
224 
225 InstructionCost
226 AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
227                                       TTI::TargetCostKind CostKind) {
228   auto *RetTy = ICA.getReturnType();
229   switch (ICA.getID()) {
230   case Intrinsic::umin:
231   case Intrinsic::umax:
232   case Intrinsic::smin:
233   case Intrinsic::smax: {
234     static const auto ValidMinMaxTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
235                                         MVT::v8i16, MVT::v2i32, MVT::v4i32};
236     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
237     // v2i64 types get converted to cmp+bif hence the cost of 2
238     if (LT.second == MVT::v2i64)
239       return LT.first * 2;
240     if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
241       return LT.first;
242     break;
243   }
244   case Intrinsic::sadd_sat:
245   case Intrinsic::ssub_sat:
246   case Intrinsic::uadd_sat:
247   case Intrinsic::usub_sat: {
248     static const auto ValidSatTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
249                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
250                                      MVT::v2i64};
251     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
252     // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
253     // need to extend the type, as it uses shr(qadd(shl, shl)).
254     unsigned Instrs =
255         LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
256     if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
257       return LT.first * Instrs;
258     break;
259   }
260   case Intrinsic::abs: {
261     static const auto ValidAbsTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
262                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
263                                      MVT::v2i64};
264     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
265     if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
266       return LT.first;
267     break;
268   }
269   case Intrinsic::experimental_stepvector: {
270     InstructionCost Cost = 1; // Cost of the `index' instruction
271     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
272     // Legalisation of illegal vectors involves an `index' instruction plus
273     // (LT.first - 1) vector adds.
274     if (LT.first > 1) {
275       Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
276       InstructionCost AddCost =
277           getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
278       Cost += AddCost * (LT.first - 1);
279     }
280     return Cost;
281   }
282   case Intrinsic::bitreverse: {
283     static const CostTblEntry BitreverseTbl[] = {
284         {Intrinsic::bitreverse, MVT::i32, 1},
285         {Intrinsic::bitreverse, MVT::i64, 1},
286         {Intrinsic::bitreverse, MVT::v8i8, 1},
287         {Intrinsic::bitreverse, MVT::v16i8, 1},
288         {Intrinsic::bitreverse, MVT::v4i16, 2},
289         {Intrinsic::bitreverse, MVT::v8i16, 2},
290         {Intrinsic::bitreverse, MVT::v2i32, 2},
291         {Intrinsic::bitreverse, MVT::v4i32, 2},
292         {Intrinsic::bitreverse, MVT::v1i64, 2},
293         {Intrinsic::bitreverse, MVT::v2i64, 2},
294     };
295     const auto LegalisationCost = TLI->getTypeLegalizationCost(DL, RetTy);
296     const auto *Entry =
297         CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
298     if (Entry) {
299       // Cost Model is using the legal type(i32) that i8 and i16 will be
300       // converted to +1 so that we match the actual lowering cost
301       if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
302           TLI->getValueType(DL, RetTy, true) == MVT::i16)
303         return LegalisationCost.first * Entry->Cost + 1;
304 
305       return LegalisationCost.first * Entry->Cost;
306     }
307     break;
308   }
309   case Intrinsic::ctpop: {
310     static const CostTblEntry CtpopCostTbl[] = {
311         {ISD::CTPOP, MVT::v2i64, 4},
312         {ISD::CTPOP, MVT::v4i32, 3},
313         {ISD::CTPOP, MVT::v8i16, 2},
314         {ISD::CTPOP, MVT::v16i8, 1},
315         {ISD::CTPOP, MVT::i64,   4},
316         {ISD::CTPOP, MVT::v2i32, 3},
317         {ISD::CTPOP, MVT::v4i16, 2},
318         {ISD::CTPOP, MVT::v8i8,  1},
319         {ISD::CTPOP, MVT::i32,   5},
320     };
321     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
322     MVT MTy = LT.second;
323     if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
324       // Extra cost of +1 when illegal vector types are legalized by promoting
325       // the integer type.
326       int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
327                                             RetTy->getScalarSizeInBits()
328                           ? 1
329                           : 0;
330       return LT.first * Entry->Cost + ExtraCost;
331     }
332     break;
333   }
334   case Intrinsic::smul_with_overflow:
335   case Intrinsic::umul_with_overflow: {
336     static const CostTblEntry WithOverflowCostTbl[] = {
337         {Intrinsic::smul_with_overflow, MVT::i8, 5},
338         {Intrinsic::umul_with_overflow, MVT::i8, 4},
339         {Intrinsic::smul_with_overflow, MVT::i16, 5},
340         {Intrinsic::umul_with_overflow, MVT::i16, 4},
341         {Intrinsic::smul_with_overflow, MVT::i32, 2}, // eg umull;tst
342         {Intrinsic::umul_with_overflow, MVT::i32, 2}, // eg umull;cmp sxtw
343         {Intrinsic::smul_with_overflow, MVT::i64, 3}, // eg mul;smulh;cmp
344         {Intrinsic::umul_with_overflow, MVT::i64, 3}, // eg mul;umulh;cmp asr
345     };
346     EVT MTy = TLI->getValueType(DL, RetTy->getContainedType(0), true);
347     if (MTy.isSimple())
348       if (const auto *Entry = CostTableLookup(WithOverflowCostTbl, ICA.getID(),
349                                               MTy.getSimpleVT()))
350         return Entry->Cost;
351     break;
352   }
353   default:
354     break;
355   }
356   return BaseT::getIntrinsicInstrCost(ICA, CostKind);
357 }
358 
359 /// The function will remove redundant reinterprets casting in the presence
360 /// of the control flow
361 static Optional<Instruction *> processPhiNode(InstCombiner &IC,
362                                               IntrinsicInst &II) {
363   SmallVector<Instruction *, 32> Worklist;
364   auto RequiredType = II.getType();
365 
366   auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
367   assert(PN && "Expected Phi Node!");
368 
369   // Don't create a new Phi unless we can remove the old one.
370   if (!PN->hasOneUse())
371     return None;
372 
373   for (Value *IncValPhi : PN->incoming_values()) {
374     auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
375     if (!Reinterpret ||
376         Reinterpret->getIntrinsicID() !=
377             Intrinsic::aarch64_sve_convert_to_svbool ||
378         RequiredType != Reinterpret->getArgOperand(0)->getType())
379       return None;
380   }
381 
382   // Create the new Phi
383   LLVMContext &Ctx = PN->getContext();
384   IRBuilder<> Builder(Ctx);
385   Builder.SetInsertPoint(PN);
386   PHINode *NPN = Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
387   Worklist.push_back(PN);
388 
389   for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
390     auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
391     NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
392     Worklist.push_back(Reinterpret);
393   }
394 
395   // Cleanup Phi Node and reinterprets
396   return IC.replaceInstUsesWith(II, NPN);
397 }
398 
399 static Optional<Instruction *> instCombineConvertFromSVBool(InstCombiner &IC,
400                                                             IntrinsicInst &II) {
401   // If the reinterpret instruction operand is a PHI Node
402   if (isa<PHINode>(II.getArgOperand(0)))
403     return processPhiNode(IC, II);
404 
405   SmallVector<Instruction *, 32> CandidatesForRemoval;
406   Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;
407 
408   const auto *IVTy = cast<VectorType>(II.getType());
409 
410   // Walk the chain of conversions.
411   while (Cursor) {
412     // If the type of the cursor has fewer lanes than the final result, zeroing
413     // must take place, which breaks the equivalence chain.
414     const auto *CursorVTy = cast<VectorType>(Cursor->getType());
415     if (CursorVTy->getElementCount().getKnownMinValue() <
416         IVTy->getElementCount().getKnownMinValue())
417       break;
418 
419     // If the cursor has the same type as I, it is a viable replacement.
420     if (Cursor->getType() == IVTy)
421       EarliestReplacement = Cursor;
422 
423     auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);
424 
425     // If this is not an SVE conversion intrinsic, this is the end of the chain.
426     if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
427                                   Intrinsic::aarch64_sve_convert_to_svbool ||
428                               IntrinsicCursor->getIntrinsicID() ==
429                                   Intrinsic::aarch64_sve_convert_from_svbool))
430       break;
431 
432     CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
433     Cursor = IntrinsicCursor->getOperand(0);
434   }
435 
436   // If no viable replacement in the conversion chain was found, there is
437   // nothing to do.
438   if (!EarliestReplacement)
439     return None;
440 
441   return IC.replaceInstUsesWith(II, EarliestReplacement);
442 }
443 
444 static Optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
445                                                  IntrinsicInst &II) {
446   IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
447   if (!Pg)
448     return None;
449 
450   if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
451     return None;
452 
453   const auto PTruePattern =
454       cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
455   if (PTruePattern != AArch64SVEPredPattern::vl1)
456     return None;
457 
458   // The intrinsic is inserting into lane zero so use an insert instead.
459   auto *IdxTy = Type::getInt64Ty(II.getContext());
460   auto *Insert = InsertElementInst::Create(
461       II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
462   Insert->insertBefore(&II);
463   Insert->takeName(&II);
464 
465   return IC.replaceInstUsesWith(II, Insert);
466 }
467 
468 static Optional<Instruction *> instCombineSVEDupX(InstCombiner &IC,
469                                                   IntrinsicInst &II) {
470   // Replace DupX with a regular IR splat.
471   IRBuilder<> Builder(II.getContext());
472   Builder.SetInsertPoint(&II);
473   auto *RetTy = cast<ScalableVectorType>(II.getType());
474   Value *Splat =
475       Builder.CreateVectorSplat(RetTy->getElementCount(), II.getArgOperand(0));
476   Splat->takeName(&II);
477   return IC.replaceInstUsesWith(II, Splat);
478 }
479 
480 static Optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
481                                                    IntrinsicInst &II) {
482   LLVMContext &Ctx = II.getContext();
483   IRBuilder<> Builder(Ctx);
484   Builder.SetInsertPoint(&II);
485 
486   // Check that the predicate is all active
487   auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
488   if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
489     return None;
490 
491   const auto PTruePattern =
492       cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
493   if (PTruePattern != AArch64SVEPredPattern::all)
494     return None;
495 
496   // Check that we have a compare of zero..
497   auto *SplatValue =
498       dyn_cast_or_null<ConstantInt>(getSplatValue(II.getArgOperand(2)));
499   if (!SplatValue || !SplatValue->isZero())
500     return None;
501 
502   // ..against a dupq
503   auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
504   if (!DupQLane ||
505       DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
506     return None;
507 
508   // Where the dupq is a lane 0 replicate of a vector insert
509   if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero())
510     return None;
511 
512   auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
513   if (!VecIns ||
514       VecIns->getIntrinsicID() != Intrinsic::experimental_vector_insert)
515     return None;
516 
517   // Where the vector insert is a fixed constant vector insert into undef at
518   // index zero
519   if (!isa<UndefValue>(VecIns->getArgOperand(0)))
520     return None;
521 
522   if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
523     return None;
524 
525   auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
526   if (!ConstVec)
527     return None;
528 
529   auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
530   auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
531   if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
532     return None;
533 
534   unsigned NumElts = VecTy->getNumElements();
535   unsigned PredicateBits = 0;
536 
537   // Expand intrinsic operands to a 16-bit byte level predicate
538   for (unsigned I = 0; I < NumElts; ++I) {
539     auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
540     if (!Arg)
541       return None;
542     if (!Arg->isZero())
543       PredicateBits |= 1 << (I * (16 / NumElts));
544   }
545 
546   // If all bits are zero bail early with an empty predicate
547   if (PredicateBits == 0) {
548     auto *PFalse = Constant::getNullValue(II.getType());
549     PFalse->takeName(&II);
550     return IC.replaceInstUsesWith(II, PFalse);
551   }
552 
553   // Calculate largest predicate type used (where byte predicate is largest)
554   unsigned Mask = 8;
555   for (unsigned I = 0; I < 16; ++I)
556     if ((PredicateBits & (1 << I)) != 0)
557       Mask |= (I % 8);
558 
559   unsigned PredSize = Mask & -Mask;
560   auto *PredType = ScalableVectorType::get(
561       Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));
562 
563   // Ensure all relevant bits are set
564   for (unsigned I = 0; I < 16; I += PredSize)
565     if ((PredicateBits & (1 << I)) == 0)
566       return None;
567 
568   auto *PTruePat =
569       ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
570   auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
571                                         {PredType}, {PTruePat});
572   auto *ConvertToSVBool = Builder.CreateIntrinsic(
573       Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
574   auto *ConvertFromSVBool =
575       Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
576                               {II.getType()}, {ConvertToSVBool});
577 
578   ConvertFromSVBool->takeName(&II);
579   return IC.replaceInstUsesWith(II, ConvertFromSVBool);
580 }
581 
582 static Optional<Instruction *> instCombineSVELast(InstCombiner &IC,
583                                                   IntrinsicInst &II) {
584   IRBuilder<> Builder(II.getContext());
585   Builder.SetInsertPoint(&II);
586   Value *Pg = II.getArgOperand(0);
587   Value *Vec = II.getArgOperand(1);
588   auto IntrinsicID = II.getIntrinsicID();
589   bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta;
590 
591   // lastX(splat(X)) --> X
592   if (auto *SplatVal = getSplatValue(Vec))
593     return IC.replaceInstUsesWith(II, SplatVal);
594 
595   // If x and/or y is a splat value then:
596   // lastX (binop (x, y)) --> binop(lastX(x), lastX(y))
597   Value *LHS, *RHS;
598   if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) {
599     if (isSplatValue(LHS) || isSplatValue(RHS)) {
600       auto *OldBinOp = cast<BinaryOperator>(Vec);
601       auto OpC = OldBinOp->getOpcode();
602       auto *NewLHS =
603           Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS});
604       auto *NewRHS =
605           Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS});
606       auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags(
607           OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), &II);
608       return IC.replaceInstUsesWith(II, NewBinOp);
609     }
610   }
611 
612   auto *C = dyn_cast<Constant>(Pg);
613   if (IsAfter && C && C->isNullValue()) {
614     // The intrinsic is extracting lane 0 so use an extract instead.
615     auto *IdxTy = Type::getInt64Ty(II.getContext());
616     auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
617     Extract->insertBefore(&II);
618     Extract->takeName(&II);
619     return IC.replaceInstUsesWith(II, Extract);
620   }
621 
622   auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
623   if (!IntrPG)
624     return None;
625 
626   if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
627     return None;
628 
629   const auto PTruePattern =
630       cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();
631 
632   // Can the intrinsic's predicate be converted to a known constant index?
633   unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
634   if (!MinNumElts)
635     return None;
636 
637   unsigned Idx = MinNumElts - 1;
638   // Increment the index if extracting the element after the last active
639   // predicate element.
640   if (IsAfter)
641     ++Idx;
642 
643   // Ignore extracts whose index is larger than the known minimum vector
644   // length. NOTE: This is an artificial constraint where we prefer to
645   // maintain what the user asked for until an alternative is proven faster.
646   auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
647   if (Idx >= PgVTy->getMinNumElements())
648     return None;
649 
650   // The intrinsic is extracting a fixed lane so use an extract instead.
651   auto *IdxTy = Type::getInt64Ty(II.getContext());
652   auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
653   Extract->insertBefore(&II);
654   Extract->takeName(&II);
655   return IC.replaceInstUsesWith(II, Extract);
656 }
657 
658 static Optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
659                                                 IntrinsicInst &II) {
660   LLVMContext &Ctx = II.getContext();
661   IRBuilder<> Builder(Ctx);
662   Builder.SetInsertPoint(&II);
663   // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
664   // can work with RDFFR_PP for ptest elimination.
665   auto *AllPat =
666       ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
667   auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
668                                         {II.getType()}, {AllPat});
669   auto *RDFFR =
670       Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue});
671   RDFFR->takeName(&II);
672   return IC.replaceInstUsesWith(II, RDFFR);
673 }
674 
675 static Optional<Instruction *>
676 instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
677   const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
678 
679   if (Pattern == AArch64SVEPredPattern::all) {
680     LLVMContext &Ctx = II.getContext();
681     IRBuilder<> Builder(Ctx);
682     Builder.SetInsertPoint(&II);
683 
684     Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
685     auto *VScale = Builder.CreateVScale(StepVal);
686     VScale->takeName(&II);
687     return IC.replaceInstUsesWith(II, VScale);
688   }
689 
690   unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
691 
692   return MinNumElts && NumElts >= MinNumElts
693              ? Optional<Instruction *>(IC.replaceInstUsesWith(
694                    II, ConstantInt::get(II.getType(), MinNumElts)))
695              : None;
696 }
697 
698 static Optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
699                                                    IntrinsicInst &II) {
700   IntrinsicInst *Op1 = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
701   IntrinsicInst *Op2 = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
702 
703   if (Op1 && Op2 &&
704       Op1->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
705       Op2->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
706       Op1->getArgOperand(0)->getType() == Op2->getArgOperand(0)->getType()) {
707 
708     IRBuilder<> Builder(II.getContext());
709     Builder.SetInsertPoint(&II);
710 
711     Value *Ops[] = {Op1->getArgOperand(0), Op2->getArgOperand(0)};
712     Type *Tys[] = {Op1->getArgOperand(0)->getType()};
713 
714     auto *PTest = Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
715 
716     PTest->takeName(&II);
717     return IC.replaceInstUsesWith(II, PTest);
718   }
719 
720   return None;
721 }
722 
723 static Optional<Instruction *> instCombineSVEVectorFMLA(InstCombiner &IC,
724                                                         IntrinsicInst &II) {
725   // fold (fadd p a (fmul p b c)) -> (fma p a b c)
726   Value *P = II.getOperand(0);
727   Value *A = II.getOperand(1);
728   auto FMul = II.getOperand(2);
729   Value *B, *C;
730   if (!match(FMul, m_Intrinsic<Intrinsic::aarch64_sve_fmul>(
731                        m_Specific(P), m_Value(B), m_Value(C))))
732     return None;
733 
734   if (!FMul->hasOneUse())
735     return None;
736 
737   llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
738   // Stop the combine when the flags on the inputs differ in case dropping flags
739   // would lead to us missing out on more beneficial optimizations.
740   if (FAddFlags != cast<CallInst>(FMul)->getFastMathFlags())
741     return None;
742   if (!FAddFlags.allowContract())
743     return None;
744 
745   IRBuilder<> Builder(II.getContext());
746   Builder.SetInsertPoint(&II);
747   auto FMLA = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_fmla,
748                                       {II.getType()}, {P, A, B, C}, &II);
749   FMLA->setFastMathFlags(FAddFlags);
750   return IC.replaceInstUsesWith(II, FMLA);
751 }
752 
753 static bool isAllActivePredicate(Value *Pred) {
754   // Look through convert.from.svbool(convert.to.svbool(...) chain.
755   Value *UncastedPred;
756   if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>(
757                       m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
758                           m_Value(UncastedPred)))))
759     // If the predicate has the same or less lanes than the uncasted
760     // predicate then we know the casting has no effect.
761     if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <=
762         cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements())
763       Pred = UncastedPred;
764 
765   return match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
766                          m_ConstantInt<AArch64SVEPredPattern::all>()));
767 }
768 
769 static Optional<Instruction *>
770 instCombineSVELD1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
771   IRBuilder<> Builder(II.getContext());
772   Builder.SetInsertPoint(&II);
773 
774   Value *Pred = II.getOperand(0);
775   Value *PtrOp = II.getOperand(1);
776   Type *VecTy = II.getType();
777   Value *VecPtr = Builder.CreateBitCast(PtrOp, VecTy->getPointerTo());
778 
779   if (isAllActivePredicate(Pred)) {
780     LoadInst *Load = Builder.CreateLoad(VecTy, VecPtr);
781     return IC.replaceInstUsesWith(II, Load);
782   }
783 
784   CallInst *MaskedLoad =
785       Builder.CreateMaskedLoad(VecTy, VecPtr, PtrOp->getPointerAlignment(DL),
786                                Pred, ConstantAggregateZero::get(VecTy));
787   return IC.replaceInstUsesWith(II, MaskedLoad);
788 }
789 
790 static Optional<Instruction *>
791 instCombineSVEST1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
792   IRBuilder<> Builder(II.getContext());
793   Builder.SetInsertPoint(&II);
794 
795   Value *VecOp = II.getOperand(0);
796   Value *Pred = II.getOperand(1);
797   Value *PtrOp = II.getOperand(2);
798   Value *VecPtr =
799       Builder.CreateBitCast(PtrOp, VecOp->getType()->getPointerTo());
800 
801   if (isAllActivePredicate(Pred)) {
802     Builder.CreateStore(VecOp, VecPtr);
803     return IC.eraseInstFromFunction(II);
804   }
805 
806   Builder.CreateMaskedStore(VecOp, VecPtr, PtrOp->getPointerAlignment(DL),
807                             Pred);
808   return IC.eraseInstFromFunction(II);
809 }
810 
811 static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
812   switch (Intrinsic) {
813   case Intrinsic::aarch64_sve_fmul:
814     return Instruction::BinaryOps::FMul;
815   case Intrinsic::aarch64_sve_fadd:
816     return Instruction::BinaryOps::FAdd;
817   case Intrinsic::aarch64_sve_fsub:
818     return Instruction::BinaryOps::FSub;
819   default:
820     return Instruction::BinaryOpsEnd;
821   }
822 }
823 
824 static Optional<Instruction *> instCombineSVEVectorBinOp(InstCombiner &IC,
825                                                          IntrinsicInst &II) {
826   auto *OpPredicate = II.getOperand(0);
827   auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID());
828   if (BinOpCode == Instruction::BinaryOpsEnd ||
829       !match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
830                               m_ConstantInt<AArch64SVEPredPattern::all>())))
831     return None;
832   IRBuilder<> Builder(II.getContext());
833   Builder.SetInsertPoint(&II);
834   Builder.setFastMathFlags(II.getFastMathFlags());
835   auto BinOp =
836       Builder.CreateBinOp(BinOpCode, II.getOperand(1), II.getOperand(2));
837   return IC.replaceInstUsesWith(II, BinOp);
838 }
839 
840 static Optional<Instruction *> instCombineSVEVectorFAdd(InstCombiner &IC,
841                                                         IntrinsicInst &II) {
842   if (auto FMLA = instCombineSVEVectorFMLA(IC, II))
843     return FMLA;
844   return instCombineSVEVectorBinOp(IC, II);
845 }
846 
847 static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
848                                                        IntrinsicInst &II) {
849   auto *OpPredicate = II.getOperand(0);
850   auto *OpMultiplicand = II.getOperand(1);
851   auto *OpMultiplier = II.getOperand(2);
852 
853   IRBuilder<> Builder(II.getContext());
854   Builder.SetInsertPoint(&II);
855 
856   // Return true if a given instruction is a unit splat value, false otherwise.
857   auto IsUnitSplat = [](auto *I) {
858     auto *SplatValue = getSplatValue(I);
859     if (!SplatValue)
860       return false;
861     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
862   };
863 
864   // Return true if a given instruction is an aarch64_sve_dup intrinsic call
865   // with a unit splat value, false otherwise.
866   auto IsUnitDup = [](auto *I) {
867     auto *IntrI = dyn_cast<IntrinsicInst>(I);
868     if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
869       return false;
870 
871     auto *SplatValue = IntrI->getOperand(2);
872     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
873   };
874 
875   if (IsUnitSplat(OpMultiplier)) {
876     // [f]mul pg %n, (dupx 1) => %n
877     OpMultiplicand->takeName(&II);
878     return IC.replaceInstUsesWith(II, OpMultiplicand);
879   } else if (IsUnitDup(OpMultiplier)) {
880     // [f]mul pg %n, (dup pg 1) => %n
881     auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
882     auto *DupPg = DupInst->getOperand(1);
883     // TODO: this is naive. The optimization is still valid if DupPg
884     // 'encompasses' OpPredicate, not only if they're the same predicate.
885     if (OpPredicate == DupPg) {
886       OpMultiplicand->takeName(&II);
887       return IC.replaceInstUsesWith(II, OpMultiplicand);
888     }
889   }
890 
891   return instCombineSVEVectorBinOp(IC, II);
892 }
893 
894 static Optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
895                                                     IntrinsicInst &II) {
896   IRBuilder<> Builder(II.getContext());
897   Builder.SetInsertPoint(&II);
898   Value *UnpackArg = II.getArgOperand(0);
899   auto *RetTy = cast<ScalableVectorType>(II.getType());
900   bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi ||
901                   II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo;
902 
903   // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X))
904   // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X))
905   if (auto *ScalarArg = getSplatValue(UnpackArg)) {
906     ScalarArg =
907         Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned);
908     Value *NewVal =
909         Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg);
910     NewVal->takeName(&II);
911     return IC.replaceInstUsesWith(II, NewVal);
912   }
913 
914   return None;
915 }
916 static Optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
917                                                  IntrinsicInst &II) {
918   auto *OpVal = II.getOperand(0);
919   auto *OpIndices = II.getOperand(1);
920   VectorType *VTy = cast<VectorType>(II.getType());
921 
922   // Check whether OpIndices is a constant splat value < minimal element count
923   // of result.
924   auto *SplatValue = dyn_cast_or_null<ConstantInt>(getSplatValue(OpIndices));
925   if (!SplatValue ||
926       SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
927     return None;
928 
929   // Convert sve_tbl(OpVal sve_dup_x(SplatValue)) to
930   // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
931   IRBuilder<> Builder(II.getContext());
932   Builder.SetInsertPoint(&II);
933   auto *Extract = Builder.CreateExtractElement(OpVal, SplatValue);
934   auto *VectorSplat =
935       Builder.CreateVectorSplat(VTy->getElementCount(), Extract);
936 
937   VectorSplat->takeName(&II);
938   return IC.replaceInstUsesWith(II, VectorSplat);
939 }
940 
941 static Optional<Instruction *> instCombineSVETupleGet(InstCombiner &IC,
942                                                       IntrinsicInst &II) {
943   // Try to remove sequences of tuple get/set.
944   Value *SetTuple, *SetIndex, *SetValue;
945   auto *GetTuple = II.getArgOperand(0);
946   auto *GetIndex = II.getArgOperand(1);
947   // Check that we have tuple_get(GetTuple, GetIndex) where GetTuple is a
948   // call to tuple_set i.e. tuple_set(SetTuple, SetIndex, SetValue).
949   // Make sure that the types of the current intrinsic and SetValue match
950   // in order to safely remove the sequence.
951   if (!match(GetTuple,
952              m_Intrinsic<Intrinsic::aarch64_sve_tuple_set>(
953                  m_Value(SetTuple), m_Value(SetIndex), m_Value(SetValue))) ||
954       SetValue->getType() != II.getType())
955     return None;
956   // Case where we get the same index right after setting it.
957   // tuple_get(tuple_set(SetTuple, SetIndex, SetValue), GetIndex) --> SetValue
958   if (GetIndex == SetIndex)
959     return IC.replaceInstUsesWith(II, SetValue);
960   // If we are getting a different index than what was set in the tuple_set
961   // intrinsic. We can just set the input tuple to the one up in the chain.
962   // tuple_get(tuple_set(SetTuple, SetIndex, SetValue), GetIndex)
963   // --> tuple_get(SetTuple, GetIndex)
964   return IC.replaceOperand(II, 0, SetTuple);
965 }
966 
967 static Optional<Instruction *> instCombineSVEZip(InstCombiner &IC,
968                                                  IntrinsicInst &II) {
969   // zip1(uzp1(A, B), uzp2(A, B)) --> A
970   // zip2(uzp1(A, B), uzp2(A, B)) --> B
971   Value *A, *B;
972   if (match(II.getArgOperand(0),
973             m_Intrinsic<Intrinsic::aarch64_sve_uzp1>(m_Value(A), m_Value(B))) &&
974       match(II.getArgOperand(1), m_Intrinsic<Intrinsic::aarch64_sve_uzp2>(
975                                      m_Specific(A), m_Specific(B))))
976     return IC.replaceInstUsesWith(
977         II, (II.getIntrinsicID() == Intrinsic::aarch64_sve_zip1 ? A : B));
978 
979   return None;
980 }
981 
982 static Optional<Instruction *> instCombineLD1GatherIndex(InstCombiner &IC,
983                                                          IntrinsicInst &II) {
984   Value *Mask = II.getOperand(0);
985   Value *BasePtr = II.getOperand(1);
986   Value *Index = II.getOperand(2);
987   Type *Ty = II.getType();
988   Type *BasePtrTy = BasePtr->getType();
989   Value *PassThru = ConstantAggregateZero::get(Ty);
990 
991   // Contiguous gather => masked load.
992   // (sve.ld1.gather.index Mask BasePtr (sve.index IndexBase 1))
993   // => (masked.load (gep BasePtr IndexBase) Align Mask zeroinitializer)
994   Value *IndexBase;
995   if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
996                        m_Value(IndexBase), m_SpecificInt(1)))) {
997     IRBuilder<> Builder(II.getContext());
998     Builder.SetInsertPoint(&II);
999 
1000     Align Alignment =
1001         BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
1002 
1003     Type *VecPtrTy = PointerType::getUnqual(Ty);
1004     Value *Ptr = Builder.CreateGEP(BasePtrTy->getPointerElementType(), BasePtr,
1005                                    IndexBase);
1006     Ptr = Builder.CreateBitCast(Ptr, VecPtrTy);
1007     CallInst *MaskedLoad =
1008         Builder.CreateMaskedLoad(Ty, Ptr, Alignment, Mask, PassThru);
1009     MaskedLoad->takeName(&II);
1010     return IC.replaceInstUsesWith(II, MaskedLoad);
1011   }
1012 
1013   return None;
1014 }
1015 
1016 static Optional<Instruction *> instCombineST1ScatterIndex(InstCombiner &IC,
1017                                                           IntrinsicInst &II) {
1018   Value *Val = II.getOperand(0);
1019   Value *Mask = II.getOperand(1);
1020   Value *BasePtr = II.getOperand(2);
1021   Value *Index = II.getOperand(3);
1022   Type *Ty = Val->getType();
1023   Type *BasePtrTy = BasePtr->getType();
1024 
1025   // Contiguous scatter => masked store.
1026   // (sve.ld1.scatter.index Value Mask BasePtr (sve.index IndexBase 1))
1027   // => (masked.store Value (gep BasePtr IndexBase) Align Mask)
1028   Value *IndexBase;
1029   if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
1030                        m_Value(IndexBase), m_SpecificInt(1)))) {
1031     IRBuilder<> Builder(II.getContext());
1032     Builder.SetInsertPoint(&II);
1033 
1034     Align Alignment =
1035         BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
1036 
1037     Value *Ptr = Builder.CreateGEP(BasePtrTy->getPointerElementType(), BasePtr,
1038                                    IndexBase);
1039     Type *VecPtrTy = PointerType::getUnqual(Ty);
1040     Ptr = Builder.CreateBitCast(Ptr, VecPtrTy);
1041 
1042     (void)Builder.CreateMaskedStore(Val, Ptr, Alignment, Mask);
1043 
1044     return IC.eraseInstFromFunction(II);
1045   }
1046 
1047   return None;
1048 }
1049 
1050 static Optional<Instruction *> instCombineSVESDIV(InstCombiner &IC,
1051                                                   IntrinsicInst &II) {
1052   IRBuilder<> Builder(II.getContext());
1053   Builder.SetInsertPoint(&II);
1054   Type *Int32Ty = Builder.getInt32Ty();
1055   Value *Pred = II.getOperand(0);
1056   Value *Vec = II.getOperand(1);
1057   Value *DivVec = II.getOperand(2);
1058 
1059   Value *SplatValue = getSplatValue(DivVec);
1060   ConstantInt *SplatConstantInt = dyn_cast_or_null<ConstantInt>(SplatValue);
1061   if (!SplatConstantInt)
1062     return None;
1063   APInt Divisor = SplatConstantInt->getValue();
1064 
1065   if (Divisor.isPowerOf2()) {
1066     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
1067     auto ASRD = Builder.CreateIntrinsic(
1068         Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
1069     return IC.replaceInstUsesWith(II, ASRD);
1070   }
1071   if (Divisor.isNegatedPowerOf2()) {
1072     Divisor.negate();
1073     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
1074     auto ASRD = Builder.CreateIntrinsic(
1075         Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
1076     auto NEG = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_neg,
1077                                        {ASRD->getType()}, {ASRD, Pred, ASRD});
1078     return IC.replaceInstUsesWith(II, NEG);
1079   }
1080 
1081   return None;
1082 }
1083 
1084 Optional<Instruction *>
1085 AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
1086                                      IntrinsicInst &II) const {
1087   Intrinsic::ID IID = II.getIntrinsicID();
1088   switch (IID) {
1089   default:
1090     break;
1091   case Intrinsic::aarch64_sve_convert_from_svbool:
1092     return instCombineConvertFromSVBool(IC, II);
1093   case Intrinsic::aarch64_sve_dup:
1094     return instCombineSVEDup(IC, II);
1095   case Intrinsic::aarch64_sve_dup_x:
1096     return instCombineSVEDupX(IC, II);
1097   case Intrinsic::aarch64_sve_cmpne:
1098   case Intrinsic::aarch64_sve_cmpne_wide:
1099     return instCombineSVECmpNE(IC, II);
1100   case Intrinsic::aarch64_sve_rdffr:
1101     return instCombineRDFFR(IC, II);
1102   case Intrinsic::aarch64_sve_lasta:
1103   case Intrinsic::aarch64_sve_lastb:
1104     return instCombineSVELast(IC, II);
1105   case Intrinsic::aarch64_sve_cntd:
1106     return instCombineSVECntElts(IC, II, 2);
1107   case Intrinsic::aarch64_sve_cntw:
1108     return instCombineSVECntElts(IC, II, 4);
1109   case Intrinsic::aarch64_sve_cnth:
1110     return instCombineSVECntElts(IC, II, 8);
1111   case Intrinsic::aarch64_sve_cntb:
1112     return instCombineSVECntElts(IC, II, 16);
1113   case Intrinsic::aarch64_sve_ptest_any:
1114   case Intrinsic::aarch64_sve_ptest_first:
1115   case Intrinsic::aarch64_sve_ptest_last:
1116     return instCombineSVEPTest(IC, II);
1117   case Intrinsic::aarch64_sve_mul:
1118   case Intrinsic::aarch64_sve_fmul:
1119     return instCombineSVEVectorMul(IC, II);
1120   case Intrinsic::aarch64_sve_fadd:
1121     return instCombineSVEVectorFAdd(IC, II);
1122   case Intrinsic::aarch64_sve_fsub:
1123     return instCombineSVEVectorBinOp(IC, II);
1124   case Intrinsic::aarch64_sve_tbl:
1125     return instCombineSVETBL(IC, II);
1126   case Intrinsic::aarch64_sve_uunpkhi:
1127   case Intrinsic::aarch64_sve_uunpklo:
1128   case Intrinsic::aarch64_sve_sunpkhi:
1129   case Intrinsic::aarch64_sve_sunpklo:
1130     return instCombineSVEUnpack(IC, II);
1131   case Intrinsic::aarch64_sve_tuple_get:
1132     return instCombineSVETupleGet(IC, II);
1133   case Intrinsic::aarch64_sve_zip1:
1134   case Intrinsic::aarch64_sve_zip2:
1135     return instCombineSVEZip(IC, II);
1136   case Intrinsic::aarch64_sve_ld1_gather_index:
1137     return instCombineLD1GatherIndex(IC, II);
1138   case Intrinsic::aarch64_sve_st1_scatter_index:
1139     return instCombineST1ScatterIndex(IC, II);
1140   case Intrinsic::aarch64_sve_ld1:
1141     return instCombineSVELD1(IC, II, DL);
1142   case Intrinsic::aarch64_sve_st1:
1143     return instCombineSVEST1(IC, II, DL);
1144   case Intrinsic::aarch64_sve_sdiv:
1145     return instCombineSVESDIV(IC, II);
1146   }
1147 
1148   return None;
1149 }
1150 
1151 bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
1152                                            ArrayRef<const Value *> Args) {
1153 
1154   // A helper that returns a vector type from the given type. The number of
1155   // elements in type Ty determine the vector width.
1156   auto toVectorTy = [&](Type *ArgTy) {
1157     return VectorType::get(ArgTy->getScalarType(),
1158                            cast<VectorType>(DstTy)->getElementCount());
1159   };
1160 
1161   // Exit early if DstTy is not a vector type whose elements are at least
1162   // 16-bits wide.
1163   if (!DstTy->isVectorTy() || DstTy->getScalarSizeInBits() < 16)
1164     return false;
1165 
1166   // Determine if the operation has a widening variant. We consider both the
1167   // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
1168   // instructions.
1169   //
1170   // TODO: Add additional widening operations (e.g., mul, shl, etc.) once we
1171   //       verify that their extending operands are eliminated during code
1172   //       generation.
1173   switch (Opcode) {
1174   case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
1175   case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
1176     break;
1177   default:
1178     return false;
1179   }
1180 
1181   // To be a widening instruction (either the "wide" or "long" versions), the
1182   // second operand must be a sign- or zero extend having a single user. We
1183   // only consider extends having a single user because they may otherwise not
1184   // be eliminated.
1185   if (Args.size() != 2 ||
1186       (!isa<SExtInst>(Args[1]) && !isa<ZExtInst>(Args[1])) ||
1187       !Args[1]->hasOneUse())
1188     return false;
1189   auto *Extend = cast<CastInst>(Args[1]);
1190 
1191   // Legalize the destination type and ensure it can be used in a widening
1192   // operation.
1193   auto DstTyL = TLI->getTypeLegalizationCost(DL, DstTy);
1194   unsigned DstElTySize = DstTyL.second.getScalarSizeInBits();
1195   if (!DstTyL.second.isVector() || DstElTySize != DstTy->getScalarSizeInBits())
1196     return false;
1197 
1198   // Legalize the source type and ensure it can be used in a widening
1199   // operation.
1200   auto *SrcTy = toVectorTy(Extend->getSrcTy());
1201   auto SrcTyL = TLI->getTypeLegalizationCost(DL, SrcTy);
1202   unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
1203   if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
1204     return false;
1205 
1206   // Get the total number of vector elements in the legalized types.
1207   InstructionCost NumDstEls =
1208       DstTyL.first * DstTyL.second.getVectorMinNumElements();
1209   InstructionCost NumSrcEls =
1210       SrcTyL.first * SrcTyL.second.getVectorMinNumElements();
1211 
1212   // Return true if the legalized types have the same number of vector elements
1213   // and the destination element type size is twice that of the source type.
1214   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize;
1215 }
1216 
/// Returns the cost of a cast with opcode \p Opcode from \p Src to \p Dst.
///
/// A cast \p I whose only user is a widening instruction (as decided by
/// isWideningInstruction) costs 0 when it is that user's second operand, or
/// when it has the same opcode and source type as the cast feeding the second
/// operand. Otherwise, if both types map to simple value types and
/// ConversionTbl has an entry for (ISD opcode, destination, source), that
/// entry's cost is used; in every remaining case the query is forwarded to
/// BaseT::getCastInstrCost. Table and base-class results pass through
/// AdjustCost, which collapses any non-zero cost to 1 unless \p CostKind is
/// TTI::TCK_RecipThroughput.
InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
                                                 Type *Src,
                                                 TTI::CastContextHint CCH,
                                                 TTI::TargetCostKind CostKind,
                                                 const Instruction *I) {
  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  assert(ISD && "Invalid opcode");

  // If the cast is observable, and it is used by a widening instruction (e.g.,
  // uaddl, saddw, etc.), it may be free.
  if (I && I->hasOneUse()) {
    auto *SingleUser = cast<Instruction>(*I->user_begin());
    SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
    if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands)) {
      // If the cast is the second operand, it is free. We will generate either
      // a "wide" or "long" version of the widening instruction.
      if (I == SingleUser->getOperand(1))
        return 0;
      // If the cast is not the second operand, it will be free if it looks the
      // same as the second operand. In this case, we will generate a "long"
      // version of the widening instruction.
      if (auto *Cast = dyn_cast<CastInst>(SingleUser->getOperand(1)))
        if (I->getOpcode() == unsigned(Cast->getOpcode()) &&
            cast<CastInst>(I)->getSrcTy() == Cast->getSrcTy())
          return 0;
    }
  }

  // For cost kinds other than reciprocal throughput only "free" (0) versus
  // "not free" (1) is reported.
  // TODO: Allow non-throughput costs that aren't binary.
  auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
    if (CostKind != TTI::TCK_RecipThroughput)
      return Cost == 0 ? 0 : 1;
    return Cost;
  };

  EVT SrcTy = TLI->getValueType(DL, Src);
  EVT DstTy = TLI->getValueType(DL, Dst);

  // The table below is keyed on simple value types; anything else goes
  // straight to the base implementation.
  if (!SrcTy.isSimple() || !DstTy.isSimple())
    return AdjustCost(
        BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));

  // Each entry is { ISD opcode, destination type, source type, cost }.
  static const TypeConversionCostTblEntry
  ConversionTbl[] = {
    { ISD::TRUNCATE, MVT::v4i16, MVT::v4i32,  1 },
    { ISD::TRUNCATE, MVT::v4i32, MVT::v4i64,  0 },
    { ISD::TRUNCATE, MVT::v8i8,  MVT::v8i32,  3 },
    { ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 6 },

    // Truncations on nxvmiN
    { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 1 },
    { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 1 },
    { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 1 },
    { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 1 },
    { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 1 },
    { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 2 },
    { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 1 },
    { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 3 },
    { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 5 },
    { ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 1 },
    { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 1 },
    { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 1 },
    { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 1 },
    { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 2 },
    { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 3 },
    { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 6 },

    // The number of shll instructions for the extension.
    { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
    { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
    { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
    { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
    { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
    { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
    { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
    { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
    { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
    { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
    { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
    { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
    { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
    { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
    { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
    { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6 },

    // LowerVectorINT_TO_FP:
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },

    // Complex: to v2f32
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },

    // Complex: to v4f32
    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8,  4 },
    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8,  3 },
    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },

    // Complex: to v8f32
    { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
    { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
    { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
    { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },

    // Complex: to v16f32
    { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
    { ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },

    // Complex: to v2f64
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },


    // LowerVectorFP_TO_INT
    { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1 },
    { ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1 },
    { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1 },
    { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1 },
    { ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1 },
    { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1 },

    // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext).
    { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2 },
    { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1 },
    { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f32, 1 },
    { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2 },
    { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1 },
    { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f32, 1 },

    // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2
    { ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2 },
    { ISD::FP_TO_SINT, MVT::v4i8,  MVT::v4f32, 2 },
    { ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2 },
    { ISD::FP_TO_UINT, MVT::v4i8,  MVT::v4f32, 2 },

    // Complex, from nxv2f32.
    { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },

    // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
    { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2 },
    { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2 },
    { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f64, 2 },
    { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2 },
    { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2 },
    { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f64, 2 },

    // Complex, from nxv2f64.
    { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },

    // Complex, from nxv4f32.
    { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
    { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
    { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },

    // Complex, from nxv8f64. Illegal -> illegal conversions not required.
    { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
    { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
    { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
    { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },

    // Complex, from nxv4f64. Illegal -> illegal conversions not required.
    { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
    { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
    { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
    { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
    { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
    { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },

    // Complex, from nxv8f32. Illegal -> illegal conversions not required.
    { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
    { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
    { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
    { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },

    // Complex, from nxv8f16.
    { ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
    { ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
    { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
    { ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
    { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },

    // Complex, from nxv4f16.
    { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
    { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
    { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },

    // Complex, from nxv2f16.
    { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },

    // Truncate from nxvmf32 to nxvmf16.
    { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1 },
    { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1 },
    { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3 },

    // Truncate from nxvmf64 to nxvmf16.
    { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1 },
    { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3 },
    { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7 },

    // Truncate from nxvmf64 to nxvmf32.
    { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 },
    { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3 },
    { ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6 },

    // Extend from nxvmf16 to nxvmf32.
    { ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
    { ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
    { ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},

    // Extend from nxvmf16 to nxvmf64.
    { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
    { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
    { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},

    // Extend from nxvmf32 to nxvmf64.
    { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
    { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
    { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},

  };

  // Prefer a table entry for this (opcode, destination, source) combination.
  if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
                                                 DstTy.getSimpleVT(),
                                                 SrcTy.getSimpleVT()))
    return AdjustCost(Entry->Cost);

  // No table entry: defer to the base implementation.
  return AdjustCost(
      BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
}
1493 
1494 InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode,
1495                                                          Type *Dst,
1496                                                          VectorType *VecTy,
1497                                                          unsigned Index) {
1498 
1499   // Make sure we were given a valid extend opcode.
1500   assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
1501          "Invalid opcode");
1502 
1503   // We are extending an element we extract from a vector, so the source type
1504   // of the extend is the element type of the vector.
1505   auto *Src = VecTy->getElementType();
1506 
1507   // Sign- and zero-extends are for integer types only.
1508   assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
1509 
1510   // Get the cost for the extract. We compute the cost (if any) for the extend
1511   // below.
1512   InstructionCost Cost =
1513       getVectorInstrCost(Instruction::ExtractElement, VecTy, Index);
1514 
1515   // Legalize the types.
1516   auto VecLT = TLI->getTypeLegalizationCost(DL, VecTy);
1517   auto DstVT = TLI->getValueType(DL, Dst);
1518   auto SrcVT = TLI->getValueType(DL, Src);
1519   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1520 
1521   // If the resulting type is still a vector and the destination type is legal,
1522   // we may get the extension for free. If not, get the default cost for the
1523   // extend.
1524   if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
1525     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1526                                    CostKind);
1527 
1528   // The destination type should be larger than the element type. If not, get
1529   // the default cost for the extend.
1530   if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits())
1531     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1532                                    CostKind);
1533 
1534   switch (Opcode) {
1535   default:
1536     llvm_unreachable("Opcode should be either SExt or ZExt");
1537 
1538   // For sign-extends, we only need a smov, which performs the extension
1539   // automatically.
1540   case Instruction::SExt:
1541     return Cost;
1542 
1543   // For zero-extends, the extend is performed automatically by a umov unless
1544   // the destination type is i64 and the element type is i8 or i16.
1545   case Instruction::ZExt:
1546     if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
1547       return Cost;
1548   }
1549 
1550   // If we are unable to perform the extend for free, get the default cost.
1551   return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1552                                  CostKind);
1553 }
1554 
1555 InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
1556                                                TTI::TargetCostKind CostKind,
1557                                                const Instruction *I) {
1558   if (CostKind != TTI::TCK_RecipThroughput)
1559     return Opcode == Instruction::PHI ? 0 : 1;
1560   assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
1561   // Branches are assumed to be predicted.
1562   return 0;
1563 }
1564 
1565 InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
1566                                                    unsigned Index) {
1567   assert(Val->isVectorTy() && "This must be a vector type");
1568 
1569   if (Index != -1U) {
1570     // Legalize the type.
1571     std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Val);
1572 
1573     // This type is legalized to a scalar type.
1574     if (!LT.second.isVector())
1575       return 0;
1576 
1577     // The type may be split. Normalize the index to the new type.
1578     unsigned Width = LT.second.getVectorNumElements();
1579     Index = Index % Width;
1580 
1581     // The element at index zero is already inside the vector.
1582     if (Index == 0)
1583       return 0;
1584   }
1585 
1586   // All other insert/extracts cost this much.
1587   return ST->getVectorInsertExtractBaseCost();
1588 }
1589 
1590 InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
1591     unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
1592     TTI::OperandValueKind Opd1Info, TTI::OperandValueKind Opd2Info,
1593     TTI::OperandValueProperties Opd1PropInfo,
1594     TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
1595     const Instruction *CxtI) {
1596   // TODO: Handle more cost kinds.
1597   if (CostKind != TTI::TCK_RecipThroughput)
1598     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
1599                                          Opd2Info, Opd1PropInfo,
1600                                          Opd2PropInfo, Args, CxtI);
1601 
1602   // Legalize the type.
1603   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
1604 
1605   // If the instruction is a widening instruction (e.g., uaddl, saddw, etc.),
1606   // add in the widening overhead specified by the sub-target. Since the
1607   // extends feeding widening instructions are performed automatically, they
1608   // aren't present in the generated code and have a zero cost. By adding a
1609   // widening overhead here, we attach the total cost of the combined operation
1610   // to the widening instruction.
1611   InstructionCost Cost = 0;
1612   if (isWideningInstruction(Ty, Opcode, Args))
1613     Cost += ST->getWideningBaseCost();
1614 
1615   int ISD = TLI->InstructionOpcodeToISD(Opcode);
1616 
1617   switch (ISD) {
1618   default:
1619     return Cost + BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
1620                                                 Opd2Info,
1621                                                 Opd1PropInfo, Opd2PropInfo);
1622   case ISD::SDIV:
1623     if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue &&
1624         Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) {
1625       // On AArch64, scalar signed division by constants power-of-two are
1626       // normally expanded to the sequence ADD + CMP + SELECT + SRA.
1627       // The OperandValue properties many not be same as that of previous
1628       // operation; conservatively assume OP_None.
1629       Cost += getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
1630                                      Opd1Info, Opd2Info,
1631                                      TargetTransformInfo::OP_None,
1632                                      TargetTransformInfo::OP_None);
1633       Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
1634                                      Opd1Info, Opd2Info,
1635                                      TargetTransformInfo::OP_None,
1636                                      TargetTransformInfo::OP_None);
1637       Cost += getArithmeticInstrCost(Instruction::Select, Ty, CostKind,
1638                                      Opd1Info, Opd2Info,
1639                                      TargetTransformInfo::OP_None,
1640                                      TargetTransformInfo::OP_None);
1641       Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
1642                                      Opd1Info, Opd2Info,
1643                                      TargetTransformInfo::OP_None,
1644                                      TargetTransformInfo::OP_None);
1645       return Cost;
1646     }
1647     LLVM_FALLTHROUGH;
1648   case ISD::UDIV:
1649     if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue) {
1650       auto VT = TLI->getValueType(DL, Ty);
1651       if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
1652         // Vector signed division by constant are expanded to the
1653         // sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division
1654         // to MULHS + SUB + SRL + ADD + SRL.
1655         InstructionCost MulCost = getArithmeticInstrCost(
1656             Instruction::Mul, Ty, CostKind, Opd1Info, Opd2Info,
1657             TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
1658         InstructionCost AddCost = getArithmeticInstrCost(
1659             Instruction::Add, Ty, CostKind, Opd1Info, Opd2Info,
1660             TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
1661         InstructionCost ShrCost = getArithmeticInstrCost(
1662             Instruction::AShr, Ty, CostKind, Opd1Info, Opd2Info,
1663             TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
1664         return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
1665       }
1666     }
1667 
1668     Cost += BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
1669                                           Opd2Info,
1670                                           Opd1PropInfo, Opd2PropInfo);
1671     if (Ty->isVectorTy()) {
1672       // On AArch64, vector divisions are not supported natively and are
1673       // expanded into scalar divisions of each pair of elements.
1674       Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty, CostKind,
1675                                      Opd1Info, Opd2Info, Opd1PropInfo,
1676                                      Opd2PropInfo);
1677       Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind,
1678                                      Opd1Info, Opd2Info, Opd1PropInfo,
1679                                      Opd2PropInfo);
1680       // TODO: if one of the arguments is scalar, then it's not necessary to
1681       // double the cost of handling the vector elements.
1682       Cost += Cost;
1683     }
1684     return Cost;
1685 
1686   case ISD::MUL:
1687     if (LT.second != MVT::v2i64)
1688       return (Cost + 1) * LT.first;
1689     // Since we do not have a MUL.2d instruction, a mul <2 x i64> is expensive
1690     // as elements are extracted from the vectors and the muls scalarized.
1691     // As getScalarizationOverhead is a bit too pessimistic, we estimate the
1692     // cost for a i64 vector directly here, which is:
1693     // - four i64 extracts,
1694     // - two i64 inserts, and
1695     // - two muls.
1696     // So, for a v2i64 with LT.First = 1 the cost is 8, and for a v4i64 with
1697     // LT.first = 2 the cost is 16.
1698     return LT.first * 8;
1699   case ISD::ADD:
1700   case ISD::XOR:
1701   case ISD::OR:
1702   case ISD::AND:
1703     // These nodes are marked as 'custom' for combining purposes only.
1704     // We know that they are legal. See LowerAdd in ISelLowering.
1705     return (Cost + 1) * LT.first;
1706 
1707   case ISD::FADD:
1708   case ISD::FSUB:
1709   case ISD::FMUL:
1710   case ISD::FDIV:
1711   case ISD::FNEG:
1712     // These nodes are marked as 'custom' just to lower them to SVE.
1713     // We know said lowering will incur no additional cost.
1714     if (!Ty->getScalarType()->isFP128Ty())
1715       return (Cost + 2) * LT.first;
1716 
1717     return Cost + BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
1718                                                 Opd2Info,
1719                                                 Opd1PropInfo, Opd2PropInfo);
1720   }
1721 }
1722 
1723 InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
1724                                                           ScalarEvolution *SE,
1725                                                           const SCEV *Ptr) {
1726   // Address computations in vectorized code with non-consecutive addresses will
1727   // likely result in more instructions compared to scalar code where the
1728   // computation can more often be merged into the index mode. The resulting
1729   // extra micro-ops can significantly decrease throughput.
1730   unsigned NumVectorInstToHideOverhead = 10;
1731   int MaxMergeDistance = 64;
1732 
1733   if (Ty->isVectorTy() && SE &&
1734       !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
1735     return NumVectorInstToHideOverhead;
1736 
1737   // In many cases the address computation is not merged into the instruction
1738   // addressing mode.
1739   return 1;
1740 }
1741 
1742 InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
1743                                                    Type *CondTy,
1744                                                    CmpInst::Predicate VecPred,
1745                                                    TTI::TargetCostKind CostKind,
1746                                                    const Instruction *I) {
1747   // TODO: Handle other cost kinds.
1748   if (CostKind != TTI::TCK_RecipThroughput)
1749     return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
1750                                      I);
1751 
1752   int ISD = TLI->InstructionOpcodeToISD(Opcode);
1753   // We don't lower some vector selects well that are wider than the register
1754   // width.
1755   if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
1756     // We would need this many instructions to hide the scalarization happening.
1757     const int AmortizationCost = 20;
1758 
1759     // If VecPred is not set, check if we can get a predicate from the context
1760     // instruction, if its type matches the requested ValTy.
1761     if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
1762       CmpInst::Predicate CurrentPred;
1763       if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
1764                             m_Value())))
1765         VecPred = CurrentPred;
1766     }
1767     // Check if we have a compare/select chain that can be lowered using CMxx &
1768     // BFI pair.
1769     if (CmpInst::isIntPredicate(VecPred)) {
1770       static const auto ValidMinMaxTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
1771                                           MVT::v8i16, MVT::v2i32, MVT::v4i32,
1772                                           MVT::v2i64};
1773       auto LT = TLI->getTypeLegalizationCost(DL, ValTy);
1774       if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
1775         return LT.first;
1776     }
1777 
1778     static const TypeConversionCostTblEntry
1779     VectorSelectTbl[] = {
1780       { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
1781       { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
1782       { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
1783       { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
1784       { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
1785       { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
1786     };
1787 
1788     EVT SelCondTy = TLI->getValueType(DL, CondTy);
1789     EVT SelValTy = TLI->getValueType(DL, ValTy);
1790     if (SelCondTy.isSimple() && SelValTy.isSimple()) {
1791       if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
1792                                                      SelCondTy.getSimpleVT(),
1793                                                      SelValTy.getSimpleVT()))
1794         return Entry->Cost;
1795     }
1796   }
1797   // The base case handles scalable vectors fine for now, since it treats the
1798   // cost as 1 * legalization cost.
1799   return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
1800 }
1801 
1802 AArch64TTIImpl::TTI::MemCmpExpansionOptions
1803 AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
1804   TTI::MemCmpExpansionOptions Options;
1805   if (ST->requiresStrictAlign()) {
1806     // TODO: Add cost modeling for strict align. Misaligned loads expand to
1807     // a bunch of instructions when strict align is enabled.
1808     return Options;
1809   }
1810   Options.AllowOverlappingLoads = true;
1811   Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
1812   Options.NumLoadsPerBlock = Options.MaxNumLoads;
1813   // TODO: Though vector loads usually perform well on AArch64, in some targets
1814   // they may wake up the FP unit, which raises the power consumption.  Perhaps
1815   // they could be used with no holds barred (-O3).
1816   Options.LoadSizes = {8, 4, 2, 1};
1817   return Options;
1818 }
1819 
1820 InstructionCost
1821 AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
1822                                       Align Alignment, unsigned AddressSpace,
1823                                       TTI::TargetCostKind CostKind) {
1824   if (useNeonVector(Src))
1825     return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
1826                                         CostKind);
1827   auto LT = TLI->getTypeLegalizationCost(DL, Src);
1828   if (!LT.first.isValid())
1829     return InstructionCost::getInvalid();
1830 
1831   // The code-generator is currently not able to handle scalable vectors
1832   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
1833   // it. This change will be removed when code-generation for these types is
1834   // sufficiently reliable.
1835   if (cast<VectorType>(Src)->getElementCount() == ElementCount::getScalable(1))
1836     return InstructionCost::getInvalid();
1837 
1838   return LT.first * 2;
1839 }
1840 
1841 static unsigned getSVEGatherScatterOverhead(unsigned Opcode) {
1842   return Opcode == Instruction::Load ? SVEGatherOverhead : SVEScatterOverhead;
1843 }
1844 
1845 InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
1846     unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
1847     Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
1848   if (useNeonVector(DataTy))
1849     return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
1850                                          Alignment, CostKind, I);
1851   auto *VT = cast<VectorType>(DataTy);
1852   auto LT = TLI->getTypeLegalizationCost(DL, DataTy);
1853   if (!LT.first.isValid())
1854     return InstructionCost::getInvalid();
1855 
1856   // The code-generator is currently not able to handle scalable vectors
1857   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
1858   // it. This change will be removed when code-generation for these types is
1859   // sufficiently reliable.
1860   if (cast<VectorType>(DataTy)->getElementCount() ==
1861       ElementCount::getScalable(1))
1862     return InstructionCost::getInvalid();
1863 
1864   ElementCount LegalVF = LT.second.getVectorElementCount();
1865   InstructionCost MemOpCost =
1866       getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind, I);
1867   // Add on an overhead cost for using gathers/scatters.
1868   // TODO: At the moment this is applied unilaterally for all CPUs, but at some
1869   // point we may want a per-CPU overhead.
1870   MemOpCost *= getSVEGatherScatterOverhead(Opcode);
1871   return LT.first * MemOpCost * getMaxNumElements(LegalVF);
1872 }
1873 
1874 bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
1875   return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
1876 }
1877 
1878 InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
1879                                                 MaybeAlign Alignment,
1880                                                 unsigned AddressSpace,
1881                                                 TTI::TargetCostKind CostKind,
1882                                                 const Instruction *I) {
1883   EVT VT = TLI->getValueType(DL, Ty, true);
1884   // Type legalization can't handle structs
1885   if (VT == MVT::Other)
1886     return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
1887                                   CostKind);
1888 
1889   auto LT = TLI->getTypeLegalizationCost(DL, Ty);
1890   if (!LT.first.isValid())
1891     return InstructionCost::getInvalid();
1892 
1893   // The code-generator is currently not able to handle scalable vectors
1894   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
1895   // it. This change will be removed when code-generation for these types is
1896   // sufficiently reliable.
1897   if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
1898     if (VTy->getElementCount() == ElementCount::getScalable(1))
1899       return InstructionCost::getInvalid();
1900 
1901   // TODO: consider latency as well for TCK_SizeAndLatency.
1902   if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
1903     return LT.first;
1904 
1905   if (CostKind != TTI::TCK_RecipThroughput)
1906     return 1;
1907 
1908   if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
1909       LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) {
1910     // Unaligned stores are extremely inefficient. We don't split all
1911     // unaligned 128-bit stores because the negative impact that has shown in
1912     // practice on inlined block copy code.
1913     // We make such stores expensive so that we will only vectorize if there
1914     // are 6 other instructions getting vectorized.
1915     const int AmortizationCost = 6;
1916 
1917     return LT.first * 2 * AmortizationCost;
1918   }
1919 
1920   // Check truncating stores and extending loads.
1921   if (useNeonVector(Ty) &&
1922       Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
1923     // v4i8 types are lowered to scalar a load/store and sshll/xtn.
1924     if (VT == MVT::v4i8)
1925       return 2;
1926     // Otherwise we need to scalarize.
1927     return cast<FixedVectorType>(Ty)->getNumElements() * 2;
1928   }
1929 
1930   return LT.first;
1931 }
1932 
1933 InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
1934     unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
1935     Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
1936     bool UseMaskForCond, bool UseMaskForGaps) {
1937   assert(Factor >= 2 && "Invalid interleave factor");
1938   auto *VecVTy = cast<FixedVectorType>(VecTy);
1939 
1940   if (!UseMaskForCond && !UseMaskForGaps &&
1941       Factor <= TLI->getMaxSupportedInterleaveFactor()) {
1942     unsigned NumElts = VecVTy->getNumElements();
1943     auto *SubVecTy =
1944         FixedVectorType::get(VecTy->getScalarType(), NumElts / Factor);
1945 
1946     // ldN/stN only support legal vector types of size 64 or 128 in bits.
1947     // Accesses having vector types that are a multiple of 128 bits can be
1948     // matched to more than one ldN/stN instruction.
1949     bool UseScalable;
1950     if (NumElts % Factor == 0 &&
1951         TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable))
1952       return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable);
1953   }
1954 
1955   return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
1956                                            Alignment, AddressSpace, CostKind,
1957                                            UseMaskForCond, UseMaskForGaps);
1958 }
1959 
1960 InstructionCost
1961 AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
1962   InstructionCost Cost = 0;
1963   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1964   for (auto *I : Tys) {
1965     if (!I->isVectorTy())
1966       continue;
1967     if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() ==
1968         128)
1969       Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) +
1970               getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind);
1971   }
1972   return Cost;
1973 }
1974 
1975 unsigned AArch64TTIImpl::getMaxInterleaveFactor(unsigned VF) {
1976   return ST->getMaxInterleaveFactor();
1977 }
1978 
1979 // For Falkor, we want to avoid having too many strided loads in a loop since
1980 // that can exhaust the HW prefetcher resources.  We adjust the unroller
1981 // MaxCount preference below to attempt to ensure unrolling doesn't create too
1982 // many strided loads.
1983 static void
1984 getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
1985                               TargetTransformInfo::UnrollingPreferences &UP) {
1986   enum { MaxStridedLoads = 7 };
1987   auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) {
1988     int StridedLoads = 0;
1989     // FIXME? We could make this more precise by looking at the CFG and
1990     // e.g. not counting loads in each side of an if-then-else diamond.
1991     for (const auto BB : L->blocks()) {
1992       for (auto &I : *BB) {
1993         LoadInst *LMemI = dyn_cast<LoadInst>(&I);
1994         if (!LMemI)
1995           continue;
1996 
1997         Value *PtrValue = LMemI->getPointerOperand();
1998         if (L->isLoopInvariant(PtrValue))
1999           continue;
2000 
2001         const SCEV *LSCEV = SE.getSCEV(PtrValue);
2002         const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
2003         if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
2004           continue;
2005 
2006         // FIXME? We could take pairing of unrolled load copies into account
2007         // by looking at the AddRec, but we would probably have to limit this
2008         // to loops with no stores or other memory optimization barriers.
2009         ++StridedLoads;
2010         // We've seen enough strided loads that seeing more won't make a
2011         // difference.
2012         if (StridedLoads > MaxStridedLoads / 2)
2013           return StridedLoads;
2014       }
2015     }
2016     return StridedLoads;
2017   };
2018 
2019   int StridedLoads = countStridedLoads(L, SE);
2020   LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads
2021                     << " strided loads\n");
2022   // Pick the largest power of 2 unroll count that won't result in too many
2023   // strided loads.
2024   if (StridedLoads) {
2025     UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads);
2026     LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to "
2027                       << UP.MaxCount << '\n');
2028   }
2029 }
2030 
2031 void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
2032                                              TTI::UnrollingPreferences &UP,
2033                                              OptimizationRemarkEmitter *ORE) {
2034   // Enable partial unrolling and runtime unrolling.
2035   BaseT::getUnrollingPreferences(L, SE, UP, ORE);
2036 
2037   UP.UpperBound = true;
2038 
2039   // For inner loop, it is more likely to be a hot one, and the runtime check
2040   // can be promoted out from LICM pass, so the overhead is less, let's try
2041   // a larger threshold to unroll more loops.
2042   if (L->getLoopDepth() > 1)
2043     UP.PartialThreshold *= 2;
2044 
2045   // Disable partial & runtime unrolling on -Os.
2046   UP.PartialOptSizeThreshold = 0;
2047 
2048   if (ST->getProcFamily() == AArch64Subtarget::Falkor &&
2049       EnableFalkorHWPFUnrollFix)
2050     getFalkorUnrollingPreferences(L, SE, UP);
2051 
2052   // Scan the loop: don't unroll loops with calls as this could prevent
2053   // inlining. Don't unroll vector loops either, as they don't benefit much from
2054   // unrolling.
2055   for (auto *BB : L->getBlocks()) {
2056     for (auto &I : *BB) {
2057       // Don't unroll vectorised loop.
2058       if (I.getType()->isVectorTy())
2059         return;
2060 
2061       if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
2062         if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
2063           if (!isLoweredToCall(F))
2064             continue;
2065         }
2066         return;
2067       }
2068     }
2069   }
2070 
2071   // Enable runtime unrolling for in-order models
2072   // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so by
2073   // checking for that case, we can ensure that the default behaviour is
2074   // unchanged
2075   if (ST->getProcFamily() != AArch64Subtarget::Others &&
2076       !ST->getSchedModel().isOutOfOrder()) {
2077     UP.Runtime = true;
2078     UP.Partial = true;
2079     UP.UnrollRemainder = true;
2080     UP.DefaultUnrollRuntimeCount = 4;
2081 
2082     UP.UnrollAndJam = true;
2083     UP.UnrollAndJamInnerLoopThreshold = 60;
2084   }
2085 }
2086 
2087 void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
2088                                            TTI::PeelingPreferences &PP) {
2089   BaseT::getPeelingPreferences(L, SE, PP);
2090 }
2091 
2092 Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
2093                                                          Type *ExpectedType) {
2094   switch (Inst->getIntrinsicID()) {
2095   default:
2096     return nullptr;
2097   case Intrinsic::aarch64_neon_st2:
2098   case Intrinsic::aarch64_neon_st3:
2099   case Intrinsic::aarch64_neon_st4: {
2100     // Create a struct type
2101     StructType *ST = dyn_cast<StructType>(ExpectedType);
2102     if (!ST)
2103       return nullptr;
2104     unsigned NumElts = Inst->arg_size() - 1;
2105     if (ST->getNumElements() != NumElts)
2106       return nullptr;
2107     for (unsigned i = 0, e = NumElts; i != e; ++i) {
2108       if (Inst->getArgOperand(i)->getType() != ST->getElementType(i))
2109         return nullptr;
2110     }
2111     Value *Res = UndefValue::get(ExpectedType);
2112     IRBuilder<> Builder(Inst);
2113     for (unsigned i = 0, e = NumElts; i != e; ++i) {
2114       Value *L = Inst->getArgOperand(i);
2115       Res = Builder.CreateInsertValue(Res, L, i);
2116     }
2117     return Res;
2118   }
2119   case Intrinsic::aarch64_neon_ld2:
2120   case Intrinsic::aarch64_neon_ld3:
2121   case Intrinsic::aarch64_neon_ld4:
2122     if (Inst->getType() == ExpectedType)
2123       return Inst;
2124     return nullptr;
2125   }
2126 }
2127 
2128 bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
2129                                         MemIntrinsicInfo &Info) {
2130   switch (Inst->getIntrinsicID()) {
2131   default:
2132     break;
2133   case Intrinsic::aarch64_neon_ld2:
2134   case Intrinsic::aarch64_neon_ld3:
2135   case Intrinsic::aarch64_neon_ld4:
2136     Info.ReadMem = true;
2137     Info.WriteMem = false;
2138     Info.PtrVal = Inst->getArgOperand(0);
2139     break;
2140   case Intrinsic::aarch64_neon_st2:
2141   case Intrinsic::aarch64_neon_st3:
2142   case Intrinsic::aarch64_neon_st4:
2143     Info.ReadMem = false;
2144     Info.WriteMem = true;
2145     Info.PtrVal = Inst->getArgOperand(Inst->arg_size() - 1);
2146     break;
2147   }
2148 
2149   switch (Inst->getIntrinsicID()) {
2150   default:
2151     return false;
2152   case Intrinsic::aarch64_neon_ld2:
2153   case Intrinsic::aarch64_neon_st2:
2154     Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
2155     break;
2156   case Intrinsic::aarch64_neon_ld3:
2157   case Intrinsic::aarch64_neon_st3:
2158     Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
2159     break;
2160   case Intrinsic::aarch64_neon_ld4:
2161   case Intrinsic::aarch64_neon_st4:
2162     Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
2163     break;
2164   }
2165   return true;
2166 }
2167 
2168 /// See if \p I should be considered for address type promotion. We check if \p
2169 /// I is a sext with right type and used in memory accesses. If it used in a
2170 /// "complex" getelementptr, we allow it to be promoted without finding other
2171 /// sext instructions that sign extended the same initial value. A getelementptr
2172 /// is considered as "complex" if it has more than 2 operands.
2173 bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
2174     const Instruction &I, bool &AllowPromotionWithoutCommonHeader) {
2175   bool Considerable = false;
2176   AllowPromotionWithoutCommonHeader = false;
2177   if (!isa<SExtInst>(&I))
2178     return false;
2179   Type *ConsideredSExtType =
2180       Type::getInt64Ty(I.getParent()->getParent()->getContext());
2181   if (I.getType() != ConsideredSExtType)
2182     return false;
2183   // See if the sext is the one with the right type and used in at least one
2184   // GetElementPtrInst.
2185   for (const User *U : I.users()) {
2186     if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
2187       Considerable = true;
2188       // A getelementptr is considered as "complex" if it has more than 2
2189       // operands. We will promote a SExt used in such complex GEP as we
2190       // expect some computation to be merged if they are done on 64 bits.
2191       if (GEPInst->getNumOperands() > 2) {
2192         AllowPromotionWithoutCommonHeader = true;
2193         break;
2194       }
2195     }
2196   }
2197   return Considerable;
2198 }
2199 
2200 bool AArch64TTIImpl::isLegalToVectorizeReduction(
2201     const RecurrenceDescriptor &RdxDesc, ElementCount VF) const {
2202   if (!VF.isScalable())
2203     return true;
2204 
2205   Type *Ty = RdxDesc.getRecurrenceType();
2206   if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty))
2207     return false;
2208 
2209   switch (RdxDesc.getRecurrenceKind()) {
2210   case RecurKind::Add:
2211   case RecurKind::FAdd:
2212   case RecurKind::And:
2213   case RecurKind::Or:
2214   case RecurKind::Xor:
2215   case RecurKind::SMin:
2216   case RecurKind::SMax:
2217   case RecurKind::UMin:
2218   case RecurKind::UMax:
2219   case RecurKind::FMin:
2220   case RecurKind::FMax:
2221   case RecurKind::SelectICmp:
2222   case RecurKind::SelectFCmp:
2223   case RecurKind::FMulAdd:
2224     return true;
2225   default:
2226     return false;
2227   }
2228 }
2229 
2230 InstructionCost
2231 AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
2232                                        bool IsUnsigned,
2233                                        TTI::TargetCostKind CostKind) {
2234   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
2235 
2236   if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
2237     return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
2238 
2239   assert((isa<ScalableVectorType>(Ty) == isa<ScalableVectorType>(CondTy)) &&
2240          "Both vector needs to be equally scalable");
2241 
2242   InstructionCost LegalizationCost = 0;
2243   if (LT.first > 1) {
2244     Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
2245     unsigned MinMaxOpcode =
2246         Ty->isFPOrFPVectorTy()
2247             ? Intrinsic::maxnum
2248             : (IsUnsigned ? Intrinsic::umin : Intrinsic::smin);
2249     IntrinsicCostAttributes Attrs(MinMaxOpcode, LegalVTy, {LegalVTy, LegalVTy});
2250     LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1);
2251   }
2252 
2253   return LegalizationCost + /*Cost of horizontal reduction*/ 2;
2254 }
2255 
2256 InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
2257     unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) {
2258   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
2259   InstructionCost LegalizationCost = 0;
2260   if (LT.first > 1) {
2261     Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
2262     LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
2263     LegalizationCost *= LT.first - 1;
2264   }
2265 
2266   int ISD = TLI->InstructionOpcodeToISD(Opcode);
2267   assert(ISD && "Invalid opcode");
2268   // Add the final reduction cost for the legal horizontal reduction
2269   switch (ISD) {
2270   case ISD::ADD:
2271   case ISD::AND:
2272   case ISD::OR:
2273   case ISD::XOR:
2274   case ISD::FADD:
2275     return LegalizationCost + 2;
2276   default:
2277     return InstructionCost::getInvalid();
2278   }
2279 }
2280 
2281 InstructionCost
2282 AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
2283                                            Optional<FastMathFlags> FMF,
2284                                            TTI::TargetCostKind CostKind) {
2285   if (TTI::requiresOrderedReduction(FMF)) {
2286     if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) {
2287       InstructionCost BaseCost =
2288           BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
2289       // Add on extra cost to reflect the extra overhead on some CPUs. We still
2290       // end up vectorizing for more computationally intensive loops.
2291       return BaseCost + FixedVTy->getNumElements();
2292     }
2293 
2294     if (Opcode != Instruction::FAdd)
2295       return InstructionCost::getInvalid();
2296 
2297     auto *VTy = cast<ScalableVectorType>(ValTy);
2298     InstructionCost Cost =
2299         getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind);
2300     Cost *= getMaxNumElements(VTy->getElementCount());
2301     return Cost;
2302   }
2303 
2304   if (isa<ScalableVectorType>(ValTy))
2305     return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind);
2306 
2307   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
2308   MVT MTy = LT.second;
2309   int ISD = TLI->InstructionOpcodeToISD(Opcode);
2310   assert(ISD && "Invalid opcode");
2311 
2312   // Horizontal adds can use the 'addv' instruction. We model the cost of these
2313   // instructions as twice a normal vector add, plus 1 for each legalization
2314   // step (LT.first). This is the only arithmetic vector reduction operation for
2315   // which we have an instruction.
2316   // OR, XOR and AND costs should match the codegen from:
2317   // OR: llvm/test/CodeGen/AArch64/reduce-or.ll
2318   // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll
2319   // AND: llvm/test/CodeGen/AArch64/reduce-and.ll
2320   static const CostTblEntry CostTblNoPairwise[]{
2321       {ISD::ADD, MVT::v8i8,   2},
2322       {ISD::ADD, MVT::v16i8,  2},
2323       {ISD::ADD, MVT::v4i16,  2},
2324       {ISD::ADD, MVT::v8i16,  2},
2325       {ISD::ADD, MVT::v4i32,  2},
2326       {ISD::OR,  MVT::v8i8,  15},
2327       {ISD::OR,  MVT::v16i8, 17},
2328       {ISD::OR,  MVT::v4i16,  7},
2329       {ISD::OR,  MVT::v8i16,  9},
2330       {ISD::OR,  MVT::v2i32,  3},
2331       {ISD::OR,  MVT::v4i32,  5},
2332       {ISD::OR,  MVT::v2i64,  3},
2333       {ISD::XOR, MVT::v8i8,  15},
2334       {ISD::XOR, MVT::v16i8, 17},
2335       {ISD::XOR, MVT::v4i16,  7},
2336       {ISD::XOR, MVT::v8i16,  9},
2337       {ISD::XOR, MVT::v2i32,  3},
2338       {ISD::XOR, MVT::v4i32,  5},
2339       {ISD::XOR, MVT::v2i64,  3},
2340       {ISD::AND, MVT::v8i8,  15},
2341       {ISD::AND, MVT::v16i8, 17},
2342       {ISD::AND, MVT::v4i16,  7},
2343       {ISD::AND, MVT::v8i16,  9},
2344       {ISD::AND, MVT::v2i32,  3},
2345       {ISD::AND, MVT::v4i32,  5},
2346       {ISD::AND, MVT::v2i64,  3},
2347   };
2348   switch (ISD) {
2349   default:
2350     break;
2351   case ISD::ADD:
2352     if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
2353       return (LT.first - 1) + Entry->Cost;
2354     break;
2355   case ISD::XOR:
2356   case ISD::AND:
2357   case ISD::OR:
2358     const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy);
2359     if (!Entry)
2360       break;
2361     auto *ValVTy = cast<FixedVectorType>(ValTy);
2362     if (!ValVTy->getElementType()->isIntegerTy(1) &&
2363         MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
2364         isPowerOf2_32(ValVTy->getNumElements())) {
2365       InstructionCost ExtraCost = 0;
2366       if (LT.first != 1) {
2367         // Type needs to be split, so there is an extra cost of LT.first - 1
2368         // arithmetic ops.
2369         auto *Ty = FixedVectorType::get(ValTy->getElementType(),
2370                                         MTy.getVectorNumElements());
2371         ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
2372         ExtraCost *= LT.first - 1;
2373       }
2374       return Entry->Cost + ExtraCost;
2375     }
2376     break;
2377   }
2378   return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
2379 }
2380 
2381 InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
2382   static const CostTblEntry ShuffleTbl[] = {
2383       { TTI::SK_Splice, MVT::nxv16i8,  1 },
2384       { TTI::SK_Splice, MVT::nxv8i16,  1 },
2385       { TTI::SK_Splice, MVT::nxv4i32,  1 },
2386       { TTI::SK_Splice, MVT::nxv2i64,  1 },
2387       { TTI::SK_Splice, MVT::nxv2f16,  1 },
2388       { TTI::SK_Splice, MVT::nxv4f16,  1 },
2389       { TTI::SK_Splice, MVT::nxv8f16,  1 },
2390       { TTI::SK_Splice, MVT::nxv2bf16, 1 },
2391       { TTI::SK_Splice, MVT::nxv4bf16, 1 },
2392       { TTI::SK_Splice, MVT::nxv8bf16, 1 },
2393       { TTI::SK_Splice, MVT::nxv2f32,  1 },
2394       { TTI::SK_Splice, MVT::nxv4f32,  1 },
2395       { TTI::SK_Splice, MVT::nxv2f64,  1 },
2396   };
2397 
2398   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
2399   Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext());
2400   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2401   EVT PromotedVT = LT.second.getScalarType() == MVT::i1
2402                        ? TLI->getPromotedVTForPredicate(EVT(LT.second))
2403                        : LT.second;
2404   Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext());
2405   InstructionCost LegalizationCost = 0;
2406   if (Index < 0) {
2407     LegalizationCost =
2408         getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy,
2409                            CmpInst::BAD_ICMP_PREDICATE, CostKind) +
2410         getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy,
2411                            CmpInst::BAD_ICMP_PREDICATE, CostKind);
2412   }
2413 
2414   // Predicated splice are promoted when lowering. See AArch64ISelLowering.cpp
2415   // Cost performed on a promoted type.
2416   if (LT.second.getScalarType() == MVT::i1) {
2417     LegalizationCost +=
2418         getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy,
2419                          TTI::CastContextHint::None, CostKind) +
2420         getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy,
2421                          TTI::CastContextHint::None, CostKind);
2422   }
2423   const auto *Entry =
2424       CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT());
2425   assert(Entry && "Illegal Type for Splice");
2426   LegalizationCost += Entry->Cost;
2427   return LegalizationCost * LT.first;
2428 }
2429 
2430 InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
2431                                                VectorType *Tp,
2432                                                ArrayRef<int> Mask, int Index,
2433                                                VectorType *SubTp) {
2434   Kind = improveShuffleKindFromMask(Kind, Mask);
2435   if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose ||
2436       Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc ||
2437       Kind == TTI::SK_Reverse) {
2438     static const CostTblEntry ShuffleTbl[] = {
2439       // Broadcast shuffle kinds can be performed with 'dup'.
2440       { TTI::SK_Broadcast, MVT::v8i8,  1 },
2441       { TTI::SK_Broadcast, MVT::v16i8, 1 },
2442       { TTI::SK_Broadcast, MVT::v4i16, 1 },
2443       { TTI::SK_Broadcast, MVT::v8i16, 1 },
2444       { TTI::SK_Broadcast, MVT::v2i32, 1 },
2445       { TTI::SK_Broadcast, MVT::v4i32, 1 },
2446       { TTI::SK_Broadcast, MVT::v2i64, 1 },
2447       { TTI::SK_Broadcast, MVT::v2f32, 1 },
2448       { TTI::SK_Broadcast, MVT::v4f32, 1 },
2449       { TTI::SK_Broadcast, MVT::v2f64, 1 },
2450       // Transpose shuffle kinds can be performed with 'trn1/trn2' and
2451       // 'zip1/zip2' instructions.
2452       { TTI::SK_Transpose, MVT::v8i8,  1 },
2453       { TTI::SK_Transpose, MVT::v16i8, 1 },
2454       { TTI::SK_Transpose, MVT::v4i16, 1 },
2455       { TTI::SK_Transpose, MVT::v8i16, 1 },
2456       { TTI::SK_Transpose, MVT::v2i32, 1 },
2457       { TTI::SK_Transpose, MVT::v4i32, 1 },
2458       { TTI::SK_Transpose, MVT::v2i64, 1 },
2459       { TTI::SK_Transpose, MVT::v2f32, 1 },
2460       { TTI::SK_Transpose, MVT::v4f32, 1 },
2461       { TTI::SK_Transpose, MVT::v2f64, 1 },
2462       // Select shuffle kinds.
2463       // TODO: handle vXi8/vXi16.
2464       { TTI::SK_Select, MVT::v2i32, 1 }, // mov.
2465       { TTI::SK_Select, MVT::v4i32, 2 }, // rev+trn (or similar).
2466       { TTI::SK_Select, MVT::v2i64, 1 }, // mov.
2467       { TTI::SK_Select, MVT::v2f32, 1 }, // mov.
2468       { TTI::SK_Select, MVT::v4f32, 2 }, // rev+trn (or similar).
2469       { TTI::SK_Select, MVT::v2f64, 1 }, // mov.
2470       // PermuteSingleSrc shuffle kinds.
2471       { TTI::SK_PermuteSingleSrc, MVT::v2i32, 1 }, // mov.
2472       { TTI::SK_PermuteSingleSrc, MVT::v4i32, 3 }, // perfectshuffle worst case.
2473       { TTI::SK_PermuteSingleSrc, MVT::v2i64, 1 }, // mov.
2474       { TTI::SK_PermuteSingleSrc, MVT::v2f32, 1 }, // mov.
2475       { TTI::SK_PermuteSingleSrc, MVT::v4f32, 3 }, // perfectshuffle worst case.
2476       { TTI::SK_PermuteSingleSrc, MVT::v2f64, 1 }, // mov.
2477       { TTI::SK_PermuteSingleSrc, MVT::v4i16, 3 }, // perfectshuffle worst case.
2478       { TTI::SK_PermuteSingleSrc, MVT::v4f16, 3 }, // perfectshuffle worst case.
2479       { TTI::SK_PermuteSingleSrc, MVT::v4bf16, 3 }, // perfectshuffle worst case.
2480       { TTI::SK_PermuteSingleSrc, MVT::v8i16, 8 }, // constpool + load + tbl
2481       { TTI::SK_PermuteSingleSrc, MVT::v8f16, 8 }, // constpool + load + tbl
2482       { TTI::SK_PermuteSingleSrc, MVT::v8bf16, 8 }, // constpool + load + tbl
2483       { TTI::SK_PermuteSingleSrc, MVT::v8i8, 8 }, // constpool + load + tbl
2484       { TTI::SK_PermuteSingleSrc, MVT::v16i8, 8 }, // constpool + load + tbl
2485       // Reverse can be lowered with `rev`.
2486       { TTI::SK_Reverse, MVT::v2i32, 1 }, // mov.
2487       { TTI::SK_Reverse, MVT::v4i32, 2 }, // REV64; EXT
2488       { TTI::SK_Reverse, MVT::v2i64, 1 }, // mov.
2489       { TTI::SK_Reverse, MVT::v2f32, 1 }, // mov.
2490       { TTI::SK_Reverse, MVT::v4f32, 2 }, // REV64; EXT
2491       { TTI::SK_Reverse, MVT::v2f64, 1 }, // mov.
2492       // Broadcast shuffle kinds for scalable vectors
2493       { TTI::SK_Broadcast, MVT::nxv16i8,  1 },
2494       { TTI::SK_Broadcast, MVT::nxv8i16,  1 },
2495       { TTI::SK_Broadcast, MVT::nxv4i32,  1 },
2496       { TTI::SK_Broadcast, MVT::nxv2i64,  1 },
2497       { TTI::SK_Broadcast, MVT::nxv2f16,  1 },
2498       { TTI::SK_Broadcast, MVT::nxv4f16,  1 },
2499       { TTI::SK_Broadcast, MVT::nxv8f16,  1 },
2500       { TTI::SK_Broadcast, MVT::nxv2bf16, 1 },
2501       { TTI::SK_Broadcast, MVT::nxv4bf16, 1 },
2502       { TTI::SK_Broadcast, MVT::nxv8bf16, 1 },
2503       { TTI::SK_Broadcast, MVT::nxv2f32,  1 },
2504       { TTI::SK_Broadcast, MVT::nxv4f32,  1 },
2505       { TTI::SK_Broadcast, MVT::nxv2f64,  1 },
2506       { TTI::SK_Broadcast, MVT::nxv16i1,  1 },
2507       { TTI::SK_Broadcast, MVT::nxv8i1,   1 },
2508       { TTI::SK_Broadcast, MVT::nxv4i1,   1 },
2509       { TTI::SK_Broadcast, MVT::nxv2i1,   1 },
2510       // Handle the cases for vector.reverse with scalable vectors
2511       { TTI::SK_Reverse, MVT::nxv16i8,  1 },
2512       { TTI::SK_Reverse, MVT::nxv8i16,  1 },
2513       { TTI::SK_Reverse, MVT::nxv4i32,  1 },
2514       { TTI::SK_Reverse, MVT::nxv2i64,  1 },
2515       { TTI::SK_Reverse, MVT::nxv2f16,  1 },
2516       { TTI::SK_Reverse, MVT::nxv4f16,  1 },
2517       { TTI::SK_Reverse, MVT::nxv8f16,  1 },
2518       { TTI::SK_Reverse, MVT::nxv2bf16, 1 },
2519       { TTI::SK_Reverse, MVT::nxv4bf16, 1 },
2520       { TTI::SK_Reverse, MVT::nxv8bf16, 1 },
2521       { TTI::SK_Reverse, MVT::nxv2f32,  1 },
2522       { TTI::SK_Reverse, MVT::nxv4f32,  1 },
2523       { TTI::SK_Reverse, MVT::nxv2f64,  1 },
2524       { TTI::SK_Reverse, MVT::nxv16i1,  1 },
2525       { TTI::SK_Reverse, MVT::nxv8i1,   1 },
2526       { TTI::SK_Reverse, MVT::nxv4i1,   1 },
2527       { TTI::SK_Reverse, MVT::nxv2i1,   1 },
2528     };
2529     std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
2530     if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second))
2531       return LT.first * Entry->Cost;
2532   }
2533   if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(Tp))
2534     return getSpliceCost(Tp, Index);
2535   return BaseT::getShuffleCost(Kind, Tp, Mask, Index, SubTp);
2536 }
2537