1 //===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "AArch64TargetTransformInfo.h"
10 #include "AArch64ExpandImm.h"
11 #include "MCTargetDesc/AArch64AddressingModes.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/TargetTransformInfo.h"
14 #include "llvm/CodeGen/BasicTTIImpl.h"
15 #include "llvm/CodeGen/CostTable.h"
16 #include "llvm/CodeGen/TargetLowering.h"
17 #include "llvm/IR/IntrinsicInst.h"
18 #include "llvm/IR/IntrinsicsAArch64.h"
19 #include "llvm/IR/PatternMatch.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Transforms/InstCombine/InstCombiner.h"
22 #include <algorithm>
23 using namespace llvm;
24 using namespace llvm::PatternMatch;
25 
26 #define DEBUG_TYPE "aarch64tti"
27 
// Hidden command-line gate (default: on) for the Falkor hardware-prefetcher
// unrolling workaround referenced elsewhere in this TTI implementation.
static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix",
                                               cl::init(true), cl::Hidden);
30 
31 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
32                                          const Function *Callee) const {
33   const TargetMachine &TM = getTLI()->getTargetMachine();
34 
35   const FeatureBitset &CallerBits =
36       TM.getSubtargetImpl(*Caller)->getFeatureBits();
37   const FeatureBitset &CalleeBits =
38       TM.getSubtargetImpl(*Callee)->getFeatureBits();
39 
40   // Inline a callee if its target-features are a subset of the callers
41   // target-features.
42   return (CallerBits & CalleeBits) == CalleeBits;
43 }
44 
45 /// Calculate the cost of materializing a 64-bit value. This helper
46 /// method might only calculate a fraction of a larger immediate. Therefore it
47 /// is valid to return a cost of ZERO.
48 InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) {
49   // Check if the immediate can be encoded within an instruction.
50   if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
51     return 0;
52 
53   if (Val < 0)
54     Val = ~Val;
55 
56   // Calculate how many moves we will need to materialize this constant.
57   SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
58   AArch64_IMM::expandMOVImm(Val, 64, Insn);
59   return Insn.size();
60 }
61 
/// Calculate the cost of materializing the given constant.
///
/// Splits \p Imm into 64-bit chunks and sums the per-chunk materialization
/// cost, with a floor of one instruction.
InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
                                              TTI::TargetCostKind CostKind) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  // Zero-width constants have no cost model; return a sentinel "max" cost.
  if (BitSize == 0)
    return ~0U;

  // Sign-extend all constants to a multiple of 64-bit.
  APInt ImmVal = Imm;
  if (BitSize & 0x3f)
    ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);

  // Split the constant into 64-bit chunks and calculate the cost for each
  // chunk. A chunk may legitimately cost zero (see the int64_t overload).
  InstructionCost Cost = 0;
  for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
    APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
    int64_t Val = Tmp.getSExtValue();
    Cost += getIntImmCost(Val);
  }
  // We need at least one instruction to materialize the constant.
  return std::max<InstructionCost>(1, Cost);
}
87 
/// Cost of using \p Imm as operand \p Idx of an instruction with the given
/// \p Opcode. Returns TCC_Free when the immediate can be folded into the
/// instruction, so constant hoisting leaves it in place.
InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
                                                  const APInt &Imm, Type *Ty,
                                                  TTI::TargetCostKind CostKind,
                                                  Instruction *Inst) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  // There is no cost model for constants with a bit size of 0. Return TCC_Free
  // here, so that constant hoisting will ignore this constant.
  if (BitSize == 0)
    return TTI::TCC_Free;

  // Index of the operand that may encode an immediate directly; ~0U means
  // no operand of this opcode takes an encodable immediate.
  unsigned ImmIdx = ~0U;
  switch (Opcode) {
  default:
    return TTI::TCC_Free;
  case Instruction::GetElementPtr:
    // Always hoist the base address of a GetElementPtr.
    if (Idx == 0)
      return 2 * TTI::TCC_Basic;
    return TTI::TCC_Free;
  case Instruction::Store:
    // The stored value (operand 0) may be an encodable immediate.
    ImmIdx = 0;
    break;
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::UDiv:
  case Instruction::SDiv:
  case Instruction::URem:
  case Instruction::SRem:
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::ICmp:
    // The second operand may be an encodable immediate.
    ImmIdx = 1;
    break;
  // Always return TCC_Free for the shift value of a shift instruction.
  case Instruction::Shl:
  case Instruction::LShr:
  case Instruction::AShr:
    if (Idx == 1)
      return TTI::TCC_Free;
    break;
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
  case Instruction::IntToPtr:
  case Instruction::PtrToInt:
  case Instruction::BitCast:
  case Instruction::PHI:
  case Instruction::Call:
  case Instruction::Select:
  case Instruction::Ret:
  case Instruction::Load:
    // Fall through to the generic materialization cost below.
    break;
  }

  if (Idx == ImmIdx) {
    // Cheap-to-materialize immediates in an encodable position are treated as
    // free so constant hoisting does not pull them out.
    int NumConstants = (BitSize + 63) / 64;
    InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
    return (Cost <= NumConstants * TTI::TCC_Basic)
               ? static_cast<int>(TTI::TCC_Free)
               : Cost;
  }
  return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
}
155 
/// Cost of using \p Imm as operand \p Idx of intrinsic \p IID. Mirrors
/// getIntImmCostInst but for intrinsic calls; TCC_Free means constant
/// hoisting should leave the immediate where it is.
InstructionCost
AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
                                    const APInt &Imm, Type *Ty,
                                    TTI::TargetCostKind CostKind) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  // There is no cost model for constants with a bit size of 0. Return TCC_Free
  // here, so that constant hoisting will ignore this constant.
  if (BitSize == 0)
    return TTI::TCC_Free;

  // Most (all?) AArch64 intrinsics do not support folding immediates into the
  // selected instruction, so we compute the materialization cost for the
  // immediate directly.
  if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
    return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);

  switch (IID) {
  default:
    return TTI::TCC_Free;
  case Intrinsic::sadd_with_overflow:
  case Intrinsic::uadd_with_overflow:
  case Intrinsic::ssub_with_overflow:
  case Intrinsic::usub_with_overflow:
  case Intrinsic::smul_with_overflow:
  case Intrinsic::umul_with_overflow:
    // The second operand may fold; treat cheap immediates as free.
    if (Idx == 1) {
      int NumConstants = (BitSize + 63) / 64;
      InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
      return (Cost <= NumConstants * TTI::TCC_Basic)
                 ? static_cast<int>(TTI::TCC_Free)
                 : Cost;
    }
    break;
  // Stackmap/patchpoint/statepoint immediates that fit in 64 bits are encoded
  // in the stack map itself, so they cost nothing.
  case Intrinsic::experimental_stackmap:
    if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  case Intrinsic::experimental_patchpoint_void:
  case Intrinsic::experimental_patchpoint_i64:
    if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  case Intrinsic::experimental_gc_statepoint:
    if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  }
  return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
}
207 
208 TargetTransformInfo::PopcntSupportKind
209 AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
210   assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
211   if (TyWidth == 32 || TyWidth == 64)
212     return TTI::PSK_FastHardware;
213   // TODO: AArch64TargetLowering::LowerCTPOP() supports 128bit popcount.
214   return TTI::PSK_Software;
215 }
216 
217 InstructionCost
218 AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
219                                       TTI::TargetCostKind CostKind) {
220   auto *RetTy = ICA.getReturnType();
221   switch (ICA.getID()) {
222   case Intrinsic::umin:
223   case Intrinsic::umax: {
224     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
225     // umin(x,y) -> sub(x,usubsat(x,y))
226     // umax(x,y) -> add(x,usubsat(y,x))
227     if (LT.second == MVT::v2i64)
228       return LT.first * 2;
229     LLVM_FALLTHROUGH;
230   }
231   case Intrinsic::smin:
232   case Intrinsic::smax: {
233     static const auto ValidMinMaxTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
234                                         MVT::v8i16, MVT::v2i32, MVT::v4i32};
235     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
236     if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
237       return LT.first;
238     break;
239   }
240   case Intrinsic::sadd_sat:
241   case Intrinsic::ssub_sat:
242   case Intrinsic::uadd_sat:
243   case Intrinsic::usub_sat: {
244     static const auto ValidSatTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
245                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
246                                      MVT::v2i64};
247     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
248     // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
249     // need to extend the type, as it uses shr(qadd(shl, shl)).
250     unsigned Instrs =
251         LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
252     if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
253       return LT.first * Instrs;
254     break;
255   }
256   case Intrinsic::abs: {
257     static const auto ValidAbsTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
258                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
259                                      MVT::v2i64};
260     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
261     if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
262       return LT.first;
263     break;
264   }
265   case Intrinsic::experimental_stepvector: {
266     InstructionCost Cost = 1; // Cost of the `index' instruction
267     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
268     // Legalisation of illegal vectors involves an `index' instruction plus
269     // (LT.first - 1) vector adds.
270     if (LT.first > 1) {
271       Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
272       InstructionCost AddCost =
273           getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
274       Cost += AddCost * (LT.first - 1);
275     }
276     return Cost;
277   }
278   case Intrinsic::bitreverse: {
279     static const CostTblEntry BitreverseTbl[] = {
280         {Intrinsic::bitreverse, MVT::i32, 1},
281         {Intrinsic::bitreverse, MVT::i64, 1},
282         {Intrinsic::bitreverse, MVT::v8i8, 1},
283         {Intrinsic::bitreverse, MVT::v16i8, 1},
284         {Intrinsic::bitreverse, MVT::v4i16, 2},
285         {Intrinsic::bitreverse, MVT::v8i16, 2},
286         {Intrinsic::bitreverse, MVT::v2i32, 2},
287         {Intrinsic::bitreverse, MVT::v4i32, 2},
288         {Intrinsic::bitreverse, MVT::v1i64, 2},
289         {Intrinsic::bitreverse, MVT::v2i64, 2},
290     };
291     const auto LegalisationCost = TLI->getTypeLegalizationCost(DL, RetTy);
292     const auto *Entry =
293         CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
294     // Cost Model is using the legal type(i32) that i8 and i16 will be converted
295     // to +1 so that we match the actual lowering cost
296     if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
297         TLI->getValueType(DL, RetTy, true) == MVT::i16)
298       return LegalisationCost.first * Entry->Cost + 1;
299     if (Entry)
300       return LegalisationCost.first * Entry->Cost;
301     break;
302   }
303   case Intrinsic::ctpop: {
304     static const CostTblEntry CtpopCostTbl[] = {
305         {ISD::CTPOP, MVT::v2i64, 4},
306         {ISD::CTPOP, MVT::v4i32, 3},
307         {ISD::CTPOP, MVT::v8i16, 2},
308         {ISD::CTPOP, MVT::v16i8, 1},
309         {ISD::CTPOP, MVT::i64,   4},
310         {ISD::CTPOP, MVT::v2i32, 3},
311         {ISD::CTPOP, MVT::v4i16, 2},
312         {ISD::CTPOP, MVT::v8i8,  1},
313         {ISD::CTPOP, MVT::i32,   5},
314     };
315     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
316     MVT MTy = LT.second;
317     if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
318       // Extra cost of +1 when illegal vector types are legalized by promoting
319       // the integer type.
320       int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
321                                             RetTy->getScalarSizeInBits()
322                           ? 1
323                           : 0;
324       return LT.first * Entry->Cost + ExtraCost;
325     }
326     break;
327   }
328   default:
329     break;
330   }
331   return BaseT::getIntrinsicInstrCost(ICA, CostKind);
332 }
333 
334 /// The function will remove redundant reinterprets casting in the presence
335 /// of the control flow
336 static Optional<Instruction *> processPhiNode(InstCombiner &IC,
337                                               IntrinsicInst &II) {
338   SmallVector<Instruction *, 32> Worklist;
339   auto RequiredType = II.getType();
340 
341   auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
342   assert(PN && "Expected Phi Node!");
343 
344   // Don't create a new Phi unless we can remove the old one.
345   if (!PN->hasOneUse())
346     return None;
347 
348   for (Value *IncValPhi : PN->incoming_values()) {
349     auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
350     if (!Reinterpret ||
351         Reinterpret->getIntrinsicID() !=
352             Intrinsic::aarch64_sve_convert_to_svbool ||
353         RequiredType != Reinterpret->getArgOperand(0)->getType())
354       return None;
355   }
356 
357   // Create the new Phi
358   LLVMContext &Ctx = PN->getContext();
359   IRBuilder<> Builder(Ctx);
360   Builder.SetInsertPoint(PN);
361   PHINode *NPN = Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
362   Worklist.push_back(PN);
363 
364   for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
365     auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
366     NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
367     Worklist.push_back(Reinterpret);
368   }
369 
370   // Cleanup Phi Node and reinterprets
371   return IC.replaceInstUsesWith(II, NPN);
372 }
373 
374 static Optional<Instruction *> instCombineConvertFromSVBool(InstCombiner &IC,
375                                                             IntrinsicInst &II) {
376   // If the reinterpret instruction operand is a PHI Node
377   if (isa<PHINode>(II.getArgOperand(0)))
378     return processPhiNode(IC, II);
379 
380   SmallVector<Instruction *, 32> CandidatesForRemoval;
381   Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;
382 
383   const auto *IVTy = cast<VectorType>(II.getType());
384 
385   // Walk the chain of conversions.
386   while (Cursor) {
387     // If the type of the cursor has fewer lanes than the final result, zeroing
388     // must take place, which breaks the equivalence chain.
389     const auto *CursorVTy = cast<VectorType>(Cursor->getType());
390     if (CursorVTy->getElementCount().getKnownMinValue() <
391         IVTy->getElementCount().getKnownMinValue())
392       break;
393 
394     // If the cursor has the same type as I, it is a viable replacement.
395     if (Cursor->getType() == IVTy)
396       EarliestReplacement = Cursor;
397 
398     auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);
399 
400     // If this is not an SVE conversion intrinsic, this is the end of the chain.
401     if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
402                                   Intrinsic::aarch64_sve_convert_to_svbool ||
403                               IntrinsicCursor->getIntrinsicID() ==
404                                   Intrinsic::aarch64_sve_convert_from_svbool))
405       break;
406 
407     CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
408     Cursor = IntrinsicCursor->getOperand(0);
409   }
410 
411   // If no viable replacement in the conversion chain was found, there is
412   // nothing to do.
413   if (!EarliestReplacement)
414     return None;
415 
416   return IC.replaceInstUsesWith(II, EarliestReplacement);
417 }
418 
419 static Optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
420                                                  IntrinsicInst &II) {
421   IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
422   if (!Pg)
423     return None;
424 
425   if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
426     return None;
427 
428   const auto PTruePattern =
429       cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
430   if (PTruePattern != AArch64SVEPredPattern::vl1)
431     return None;
432 
433   // The intrinsic is inserting into lane zero so use an insert instead.
434   auto *IdxTy = Type::getInt64Ty(II.getContext());
435   auto *Insert = InsertElementInst::Create(
436       II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
437   Insert->insertBefore(&II);
438   Insert->takeName(&II);
439 
440   return IC.replaceInstUsesWith(II, Insert);
441 }
442 
/// Fold an all-active sve.cmpne of a dupq-replicated constant vector against
/// zero into a ptrue of the widest predicate granularity the constant's bit
/// pattern allows.
static Optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
                                                   IntrinsicInst &II) {
  LLVMContext &Ctx = II.getContext();
  IRBuilder<> Builder(Ctx);
  Builder.SetInsertPoint(&II);

  // Check that the predicate is all active
  auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
  if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
    return None;

  const auto PTruePattern =
      cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
  if (PTruePattern != AArch64SVEPredPattern::all)
    return None;

  // Check that we have a compare of zero..
  auto *DupX = dyn_cast<IntrinsicInst>(II.getArgOperand(2));
  if (!DupX || DupX->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
    return None;

  auto *DupXArg = dyn_cast<ConstantInt>(DupX->getArgOperand(0));
  if (!DupXArg || !DupXArg->isZero())
    return None;

  // ..against a dupq
  auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
  if (!DupQLane ||
      DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
    return None;

  // Where the dupq is a lane 0 replicate of a vector insert
  if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero())
    return None;

  auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
  if (!VecIns ||
      VecIns->getIntrinsicID() != Intrinsic::experimental_vector_insert)
    return None;

  // Where the vector insert is a fixed constant vector insert into undef at
  // index zero
  if (!isa<UndefValue>(VecIns->getArgOperand(0)))
    return None;

  if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
    return None;

  auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
  if (!ConstVec)
    return None;

  // The fixed vector must cover exactly one 128-bit block's worth of the
  // scalable result type.
  auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
  auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
  if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
    return None;

  unsigned NumElts = VecTy->getNumElements();
  unsigned PredicateBits = 0;

  // Expand intrinsic operands to a 16-bit byte level predicate: each nonzero
  // element sets the bit for the first byte of its lane.
  for (unsigned I = 0; I < NumElts; ++I) {
    auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
    if (!Arg)
      return None;
    if (!Arg->isZero())
      PredicateBits |= 1 << (I * (16 / NumElts));
  }

  // If all bits are zero bail early with an empty predicate
  if (PredicateBits == 0) {
    auto *PFalse = Constant::getNullValue(II.getType());
    PFalse->takeName(&II);
    return IC.replaceInstUsesWith(II, PFalse);
  }

  // Calculate largest predicate type used (where byte predicate is largest)
  unsigned Mask = 8;
  for (unsigned I = 0; I < 16; ++I)
    if ((PredicateBits & (1 << I)) != 0)
      // NOTE(review): ORs the bit index modulo 8 into the mask; together with
      // Mask & -Mask below this picks the element-size granularity. Verify
      // this computes the intended stride for all sparse bit patterns.
      Mask |= (I % 8);

  // Lowest set bit of Mask is the predicate element size in bytes.
  unsigned PredSize = Mask & -Mask;
  auto *PredType = ScalableVectorType::get(
      Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));

  // Ensure all relevant bits are set
  for (unsigned I = 0; I < 16; I += PredSize)
    if ((PredicateBits & (1 << I)) == 0)
      return None;

  // Emit ptrue(all) at the derived granularity and reinterpret it to the
  // original predicate type via svbool.
  auto *PTruePat =
      ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
  auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
                                        {PredType}, {PTruePat});
  auto *ConvertToSVBool = Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
  auto *ConvertFromSVBool =
      Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
                              {II.getType()}, {ConvertToSVBool});

  ConvertFromSVBool->takeName(&II);
  return IC.replaceInstUsesWith(II, ConvertFromSVBool);
}
547 
/// Fold sve.lasta/sve.lastb with a known-constant predicate into a plain
/// extractelement at a fixed lane.
static Optional<Instruction *> instCombineSVELast(InstCombiner &IC,
                                                  IntrinsicInst &II) {
  Value *Pg = II.getArgOperand(0);
  Value *Vec = II.getArgOperand(1);
  // lasta reads the element AFTER the last active lane; lastb reads the last
  // active lane itself.
  bool IsAfter = II.getIntrinsicID() == Intrinsic::aarch64_sve_lasta;

  // lasta with an all-false predicate reads lane 0.
  auto *C = dyn_cast<Constant>(Pg);
  if (IsAfter && C && C->isNullValue()) {
    // The intrinsic is extracting lane 0 so use an extract instead.
    auto *IdxTy = Type::getInt64Ty(II.getContext());
    auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
    Extract->insertBefore(&II);
    Extract->takeName(&II);
    return IC.replaceInstUsesWith(II, Extract);
  }

  auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
  if (!IntrPG)
    return None;

  if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
    return None;

  const auto PTruePattern =
      cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();

  // Can the intrinsic's predicate be converted to a known constant index?
  // vlN patterns make the last active lane N-1.
  unsigned Idx;
  switch (PTruePattern) {
  default:
    return None;
  case AArch64SVEPredPattern::vl1:
    Idx = 0;
    break;
  case AArch64SVEPredPattern::vl2:
    Idx = 1;
    break;
  case AArch64SVEPredPattern::vl3:
    Idx = 2;
    break;
  case AArch64SVEPredPattern::vl4:
    Idx = 3;
    break;
  case AArch64SVEPredPattern::vl5:
    Idx = 4;
    break;
  case AArch64SVEPredPattern::vl6:
    Idx = 5;
    break;
  case AArch64SVEPredPattern::vl7:
    Idx = 6;
    break;
  case AArch64SVEPredPattern::vl8:
    Idx = 7;
    break;
  case AArch64SVEPredPattern::vl16:
    Idx = 15;
    break;
  }

  // Increment the index if extracting the element after the last active
  // predicate element.
  if (IsAfter)
    ++Idx;

  // Ignore extracts whose index is larger than the known minimum vector
  // length. NOTE: This is an artificial constraint where we prefer to
  // maintain what the user asked for until an alternative is proven faster.
  auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
  if (Idx >= PgVTy->getMinNumElements())
    return None;

  // The intrinsic is extracting a fixed lane so use an extract instead.
  auto *IdxTy = Type::getInt64Ty(II.getContext());
  auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
  Extract->insertBefore(&II);
  Extract->takeName(&II);
  return IC.replaceInstUsesWith(II, Extract);
}
627 
628 static Optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
629                                                 IntrinsicInst &II) {
630   LLVMContext &Ctx = II.getContext();
631   IRBuilder<> Builder(Ctx);
632   Builder.SetInsertPoint(&II);
633   // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
634   // can work with RDFFR_PP for ptest elimination.
635   auto *AllPat =
636       ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
637   auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
638                                         {II.getType()}, {AllPat});
639   auto *RDFFR =
640       Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue});
641   RDFFR->takeName(&II);
642   return IC.replaceInstUsesWith(II, RDFFR);
643 }
644 
645 static Optional<Instruction *>
646 instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
647   const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
648 
649   if (Pattern == AArch64SVEPredPattern::all) {
650     LLVMContext &Ctx = II.getContext();
651     IRBuilder<> Builder(Ctx);
652     Builder.SetInsertPoint(&II);
653 
654     Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
655     auto *VScale = Builder.CreateVScale(StepVal);
656     VScale->takeName(&II);
657     return IC.replaceInstUsesWith(II, VScale);
658   }
659 
660   unsigned MinNumElts = 0;
661   switch (Pattern) {
662   default:
663     return None;
664   case AArch64SVEPredPattern::vl1:
665   case AArch64SVEPredPattern::vl2:
666   case AArch64SVEPredPattern::vl3:
667   case AArch64SVEPredPattern::vl4:
668   case AArch64SVEPredPattern::vl5:
669   case AArch64SVEPredPattern::vl6:
670   case AArch64SVEPredPattern::vl7:
671   case AArch64SVEPredPattern::vl8:
672     MinNumElts = Pattern;
673     break;
674   case AArch64SVEPredPattern::vl16:
675     MinNumElts = 16;
676     break;
677   }
678 
679   return NumElts >= MinNumElts
680              ? Optional<Instruction *>(IC.replaceInstUsesWith(
681                    II, ConstantInt::get(II.getType(), MinNumElts)))
682              : None;
683 }
684 
/// Target hook: dispatch AArch64-specific InstCombine folds for the SVE
/// intrinsics handled above. Returns None to fall back to generic combines.
Optional<Instruction *>
AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
                                     IntrinsicInst &II) const {
  Intrinsic::ID IID = II.getIntrinsicID();
  switch (IID) {
  default:
    break;
  case Intrinsic::aarch64_sve_convert_from_svbool:
    return instCombineConvertFromSVBool(IC, II);
  case Intrinsic::aarch64_sve_dup:
    return instCombineSVEDup(IC, II);
  case Intrinsic::aarch64_sve_cmpne:
  case Intrinsic::aarch64_sve_cmpne_wide:
    return instCombineSVECmpNE(IC, II);
  case Intrinsic::aarch64_sve_rdffr:
    return instCombineRDFFR(IC, II);
  case Intrinsic::aarch64_sve_lasta:
  case Intrinsic::aarch64_sve_lastb:
    return instCombineSVELast(IC, II);
  // cntb/cnth/cntw/cntd: the literal is the element count per 128-bit block.
  case Intrinsic::aarch64_sve_cntd:
    return instCombineSVECntElts(IC, II, 2);
  case Intrinsic::aarch64_sve_cntw:
    return instCombineSVECntElts(IC, II, 4);
  case Intrinsic::aarch64_sve_cnth:
    return instCombineSVECntElts(IC, II, 8);
  case Intrinsic::aarch64_sve_cntb:
    return instCombineSVECntElts(IC, II, 16);
  }

  return None;
}
716 
/// Returns true when an instruction with the given opcode, destination type
/// and operands can be lowered to an AArch64 widening instruction (the "long"
/// or "wide" add/sub forms), making the extend of its second operand free.
bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
                                           ArrayRef<const Value *> Args) {

  // A helper that returns a vector type from the given type. The number of
  // elements in type Ty determine the vector width.
  auto toVectorTy = [&](Type *ArgTy) {
    return VectorType::get(ArgTy->getScalarType(),
                           cast<VectorType>(DstTy)->getElementCount());
  };

  // Exit early if DstTy is not a vector type whose elements are at least
  // 16-bits wide.
  if (!DstTy->isVectorTy() || DstTy->getScalarSizeInBits() < 16)
    return false;

  // Determine if the operation has a widening variant. We consider both the
  // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
  // instructions.
  //
  // TODO: Add additional widening operations (e.g., mul, shl, etc.) once we
  //       verify that their extending operands are eliminated during code
  //       generation.
  switch (Opcode) {
  case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
  case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
    break;
  default:
    return false;
  }

  // To be a widening instruction (either the "wide" or "long" versions), the
  // second operand must be a sign- or zero extend having a single user. We
  // only consider extends having a single user because they may otherwise not
  // be eliminated.
  if (Args.size() != 2 ||
      (!isa<SExtInst>(Args[1]) && !isa<ZExtInst>(Args[1])) ||
      !Args[1]->hasOneUse())
    return false;
  auto *Extend = cast<CastInst>(Args[1]);

  // Legalize the destination type and ensure it can be used in a widening
  // operation.
  auto DstTyL = TLI->getTypeLegalizationCost(DL, DstTy);
  unsigned DstElTySize = DstTyL.second.getScalarSizeInBits();
  if (!DstTyL.second.isVector() || DstElTySize != DstTy->getScalarSizeInBits())
    return false;

  // Legalize the source type and ensure it can be used in a widening
  // operation.
  auto *SrcTy = toVectorTy(Extend->getSrcTy());
  auto SrcTyL = TLI->getTypeLegalizationCost(DL, SrcTy);
  unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
  if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
    return false;

  // Get the total number of vector elements in the legalized types.
  InstructionCost NumDstEls =
      DstTyL.first * DstTyL.second.getVectorMinNumElements();
  InstructionCost NumSrcEls =
      SrcTyL.first * SrcTyL.second.getVectorMinNumElements();

  // Return true if the legalized types have the same number of vector elements
  // and the destination element type size is twice that of the source type.
  return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize;
}
782 
/// Return the throughput-oriented cost of a cast from \p Src to \p Dst.
///
/// Three tiers are tried in order:
///   1. Casts that feed an AArch64 widening instruction (uaddl, saddw, ...)
///      may be completely free, since the extend is folded into the
///      instruction itself.
///   2. A hand-tuned table keyed on (ISD opcode, legal dst VT, legal src VT)
///      covering NEON and SVE conversions.
///   3. The generic BaseT implementation as a fallback.
///
/// \param Opcode   IR cast opcode (trunc/zext/sext/fptosi/... ).
/// \param Dst      Destination IR type.
/// \param Src      Source IR type.
/// \param CCH      Context hint, forwarded to the generic implementation.
/// \param CostKind The kind of cost to compute.
/// \param I        The instruction being costed, if available (used to detect
///                 the widening-instruction pattern).
InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
                                                 Type *Src,
                                                 TTI::CastContextHint CCH,
                                                 TTI::TargetCostKind CostKind,
                                                 const Instruction *I) {
  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  assert(ISD && "Invalid opcode");

  // If the cast is observable, and it is used by a widening instruction (e.g.,
  // uaddl, saddw, etc.), it may be free.
  if (I && I->hasOneUse()) {
    auto *SingleUser = cast<Instruction>(*I->user_begin());
    SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
    if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands)) {
      // If the cast is the second operand, it is free. We will generate either
      // a "wide" or "long" version of the widening instruction.
      if (I == SingleUser->getOperand(1))
        return 0;
      // If the cast is not the second operand, it will be free if it looks the
      // same as the second operand. In this case, we will generate a "long"
      // version of the widening instruction.
      if (auto *Cast = dyn_cast<CastInst>(SingleUser->getOperand(1)))
        if (I->getOpcode() == unsigned(Cast->getOpcode()) &&
            cast<CastInst>(I)->getSrcTy() == Cast->getSrcTy())
          return 0;
    }
  }

  // TODO: Allow non-throughput costs that aren't binary.
  // For non-throughput cost kinds, collapse any nonzero cost to 1 so callers
  // only see "free" vs. "not free".
  auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
    if (CostKind != TTI::TCK_RecipThroughput)
      return Cost == 0 ? 0 : 1;
    return Cost;
  };

  EVT SrcTy = TLI->getValueType(DL, Src);
  EVT DstTy = TLI->getValueType(DL, Dst);

  // The cost table below requires simple (MVT-representable) types.
  if (!SrcTy.isSimple() || !DstTy.isSimple())
    return AdjustCost(
        BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));

  // Hand-tuned per-conversion costs. Entries are matched on
  // (ISD opcode, destination MVT, source MVT); the cost is an estimated
  // instruction count for the expansion.
  static const TypeConversionCostTblEntry
  ConversionTbl[] = {
    { ISD::TRUNCATE, MVT::v4i16, MVT::v4i32,  1 },
    { ISD::TRUNCATE, MVT::v4i32, MVT::v4i64,  0 },
    { ISD::TRUNCATE, MVT::v8i8,  MVT::v8i32,  3 },
    { ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 6 },

    // Truncations on nxvmiN
    { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 1 },
    { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 1 },
    { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 1 },
    { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 1 },
    { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 1 },
    { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 2 },
    { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 1 },
    { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 3 },
    { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 5 },
    { ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 1 },
    { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 1 },
    { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 1 },
    { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 1 },
    { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 2 },
    { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 3 },
    { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 6 },

    // The number of shll instructions for the extension.
    { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
    { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
    { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
    { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
    { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
    { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
    { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
    { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
    { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
    { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
    { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
    { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
    { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
    { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
    { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
    { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6 },

    // LowerVectorINT_TO_FP:
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },

    // Complex: to v2f32
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },

    // Complex: to v4f32
    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8,  4 },
    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8,  3 },
    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },

    // Complex: to v8f32
    { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
    { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
    { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
    { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },

    // Complex: to v16f32
    { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
    { ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },

    // Complex: to v2f64
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },


    // LowerVectorFP_TO_INT
    { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1 },
    { ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1 },
    { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1 },
    { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1 },
    { ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1 },
    { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1 },

    // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext).
    { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2 },
    { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1 },
    { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f32, 1 },
    { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2 },
    { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1 },
    { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f32, 1 },

    // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2
    { ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2 },
    { ISD::FP_TO_SINT, MVT::v4i8,  MVT::v4f32, 2 },
    { ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2 },
    { ISD::FP_TO_UINT, MVT::v4i8,  MVT::v4f32, 2 },

    // Complex, from nxv2f32.
    { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },

    // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
    { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2 },
    { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2 },
    { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f64, 2 },
    { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2 },
    { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2 },
    { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f64, 2 },

    // Complex, from nxv2f64.
    { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },

    // Complex, from nxv4f32.
    { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
    { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
    { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },

    // Complex, from nxv8f64. Illegal -> illegal conversions not required.
    { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
    { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
    { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
    { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },

    // Complex, from nxv4f64. Illegal -> illegal conversions not required.
    { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
    { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
    { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
    { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
    { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
    { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },

    // Complex, from nxv8f32. Illegal -> illegal conversions not required.
    { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
    { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
    { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
    { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },

    // Complex, from nxv8f16.
    { ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
    { ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
    { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
    { ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
    { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },

    // Complex, from nxv4f16.
    { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
    { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
    { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },

    // Complex, from nxv2f16.
    { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },

    // Truncate from nxvmf32 to nxvmf16.
    { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1 },
    { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1 },
    { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3 },

    // Truncate from nxvmf64 to nxvmf16.
    { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1 },
    { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3 },
    { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7 },

    // Truncate from nxvmf64 to nxvmf32.
    { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 },
    { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3 },
    { ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6 },

    // Extend from nxvmf16 to nxvmf32.
    { ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
    { ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
    { ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},

    // Extend from nxvmf16 to nxvmf64.
    { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
    { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
    { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},

    // Extend from nxvmf32 to nxvmf64.
    { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
    { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
    { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},

  };

  // Table hit: use the tuned cost.
  if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
                                                 DstTy.getSimpleVT(),
                                                 SrcTy.getSimpleVT()))
    return AdjustCost(Entry->Cost);

  // No table entry: defer to the generic cost model.
  return AdjustCost(
      BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
}
1059 
1060 InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode,
1061                                                          Type *Dst,
1062                                                          VectorType *VecTy,
1063                                                          unsigned Index) {
1064 
1065   // Make sure we were given a valid extend opcode.
1066   assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
1067          "Invalid opcode");
1068 
1069   // We are extending an element we extract from a vector, so the source type
1070   // of the extend is the element type of the vector.
1071   auto *Src = VecTy->getElementType();
1072 
1073   // Sign- and zero-extends are for integer types only.
1074   assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
1075 
1076   // Get the cost for the extract. We compute the cost (if any) for the extend
1077   // below.
1078   InstructionCost Cost =
1079       getVectorInstrCost(Instruction::ExtractElement, VecTy, Index);
1080 
1081   // Legalize the types.
1082   auto VecLT = TLI->getTypeLegalizationCost(DL, VecTy);
1083   auto DstVT = TLI->getValueType(DL, Dst);
1084   auto SrcVT = TLI->getValueType(DL, Src);
1085   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1086 
1087   // If the resulting type is still a vector and the destination type is legal,
1088   // we may get the extension for free. If not, get the default cost for the
1089   // extend.
1090   if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
1091     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1092                                    CostKind);
1093 
1094   // The destination type should be larger than the element type. If not, get
1095   // the default cost for the extend.
1096   if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits())
1097     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1098                                    CostKind);
1099 
1100   switch (Opcode) {
1101   default:
1102     llvm_unreachable("Opcode should be either SExt or ZExt");
1103 
1104   // For sign-extends, we only need a smov, which performs the extension
1105   // automatically.
1106   case Instruction::SExt:
1107     return Cost;
1108 
1109   // For zero-extends, the extend is performed automatically by a umov unless
1110   // the destination type is i64 and the element type is i8 or i16.
1111   case Instruction::ZExt:
1112     if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
1113       return Cost;
1114   }
1115 
1116   // If we are unable to perform the extend for free, get the default cost.
1117   return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1118                                  CostKind);
1119 }
1120 
1121 InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
1122                                                TTI::TargetCostKind CostKind,
1123                                                const Instruction *I) {
1124   if (CostKind != TTI::TCK_RecipThroughput)
1125     return Opcode == Instruction::PHI ? 0 : 1;
1126   assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
1127   // Branches are assumed to be predicted.
1128   return 0;
1129 }
1130 
1131 InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
1132                                                    unsigned Index) {
1133   assert(Val->isVectorTy() && "This must be a vector type");
1134 
1135   if (Index != -1U) {
1136     // Legalize the type.
1137     std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Val);
1138 
1139     // This type is legalized to a scalar type.
1140     if (!LT.second.isVector())
1141       return 0;
1142 
1143     // The type may be split. Normalize the index to the new type.
1144     unsigned Width = LT.second.getVectorNumElements();
1145     Index = Index % Width;
1146 
1147     // The element at index zero is already inside the vector.
1148     if (Index == 0)
1149       return 0;
1150   }
1151 
1152   // All other insert/extracts cost this much.
1153   return ST->getVectorInsertExtractBaseCost();
1154 }
1155 
/// Compute the throughput cost of an arithmetic instruction, modeling the
/// AArch64-specific cases: extends folded into widening instructions,
/// constant-divisor expansions for SDIV/UDIV, scalarized vector division,
/// the missing 64-bit vector multiply, and SVE-custom-lowered nodes that are
/// actually cheap.
InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
    unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
    TTI::OperandValueKind Opd1Info, TTI::OperandValueKind Opd2Info,
    TTI::OperandValueProperties Opd1PropInfo,
    TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
    const Instruction *CxtI) {
  // TODO: Handle more cost kinds.
  if (CostKind != TTI::TCK_RecipThroughput)
    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
                                         Opd2Info, Opd1PropInfo,
                                         Opd2PropInfo, Args, CxtI);

  // Legalize the type.
  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);

  // If the instruction is a widening instruction (e.g., uaddl, saddw, etc.),
  // add in the widening overhead specified by the sub-target. Since the
  // extends feeding widening instructions are performed automatically, they
  // aren't present in the generated code and have a zero cost. By adding a
  // widening overhead here, we attach the total cost of the combined operation
  // to the widening instruction.
  InstructionCost Cost = 0;
  if (isWideningInstruction(Ty, Opcode, Args))
    Cost += ST->getWideningBaseCost();

  int ISD = TLI->InstructionOpcodeToISD(Opcode);

  switch (ISD) {
  default:
    // No AArch64-specific modeling; any accumulated widening overhead still
    // applies on top of the generic cost.
    return Cost + BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
                                                Opd2Info,
                                                Opd1PropInfo, Opd2PropInfo);
  case ISD::SDIV:
    if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue &&
        Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) {
      // On AArch64, scalar signed division by constants power-of-two are
      // normally expanded to the sequence ADD + CMP + SELECT + SRA.
      // The OperandValue properties many not be same as that of previous
      // operation; conservatively assume OP_None.
      Cost += getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
                                     Opd1Info, Opd2Info,
                                     TargetTransformInfo::OP_None,
                                     TargetTransformInfo::OP_None);
      // The Sub below stands in for the CMP of the expansion.
      Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
                                     Opd1Info, Opd2Info,
                                     TargetTransformInfo::OP_None,
                                     TargetTransformInfo::OP_None);
      Cost += getArithmeticInstrCost(Instruction::Select, Ty, CostKind,
                                     Opd1Info, Opd2Info,
                                     TargetTransformInfo::OP_None,
                                     TargetTransformInfo::OP_None);
      Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
                                     Opd1Info, Opd2Info,
                                     TargetTransformInfo::OP_None,
                                     TargetTransformInfo::OP_None);
      return Cost;
    }
    // Non-power-of-two SDIV shares the constant-divisor handling below.
    LLVM_FALLTHROUGH;
  case ISD::UDIV:
    if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue) {
      auto VT = TLI->getValueType(DL, Ty);
      // NOTE(review): MULHU legality is used as the proxy for both the signed
      // and unsigned magic-number expansions described below — confirm this
      // is intended rather than checking MULHS for the SDIV path.
      if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
        // Vector signed division by constant are expanded to the
        // sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division
        // to MULHS + SUB + SRL + ADD + SRL.
        InstructionCost MulCost = getArithmeticInstrCost(
            Instruction::Mul, Ty, CostKind, Opd1Info, Opd2Info,
            TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
        InstructionCost AddCost = getArithmeticInstrCost(
            Instruction::Add, Ty, CostKind, Opd1Info, Opd2Info,
            TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
        InstructionCost ShrCost = getArithmeticInstrCost(
            Instruction::AShr, Ty, CostKind, Opd1Info, Opd2Info,
            TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
        return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
      }
    }

    // Generic divide cost, plus scalarization overhead for vectors.
    Cost += BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
                                          Opd2Info,
                                          Opd1PropInfo, Opd2PropInfo);
    if (Ty->isVectorTy()) {
      // On AArch64, vector divisions are not supported natively and are
      // expanded into scalar divisions of each pair of elements.
      Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty, CostKind,
                                     Opd1Info, Opd2Info, Opd1PropInfo,
                                     Opd2PropInfo);
      Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind,
                                     Opd1Info, Opd2Info, Opd1PropInfo,
                                     Opd2PropInfo);
      // TODO: if one of the arguments is scalar, then it's not necessary to
      // double the cost of handling the vector elements.
      Cost += Cost;
    }
    return Cost;

  case ISD::MUL:
    if (LT.second != MVT::v2i64)
      return (Cost + 1) * LT.first;
    // Since we do not have a MUL.2d instruction, a mul <2 x i64> is expensive
    // as elements are extracted from the vectors and the muls scalarized.
    // As getScalarizationOverhead is a bit too pessimistic, we estimate the
    // cost for a i64 vector directly here, which is:
    // - four i64 extracts,
    // - two i64 inserts, and
    // - two muls.
    // So, for a v2i64 with LT.First = 1 the cost is 8, and for a v4i64 with
    // LT.first = 2 the cost is 16.
    return LT.first * 8;
  case ISD::ADD:
  case ISD::XOR:
  case ISD::OR:
  case ISD::AND:
    // These nodes are marked as 'custom' for combining purposes only.
    // We know that they are legal. See LowerAdd in ISelLowering.
    return (Cost + 1) * LT.first;

  case ISD::FADD:
    // These nodes are marked as 'custom' just to lower them to SVE.
    // We know said lowering will incur no additional cost.
    if (isa<FixedVectorType>(Ty) && !Ty->getScalarType()->isFP128Ty())
      return (Cost + 2) * LT.first;

    return Cost + BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
                                                Opd2Info,
                                                Opd1PropInfo, Opd2PropInfo);
  }
}
1284 
1285 InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
1286                                                           ScalarEvolution *SE,
1287                                                           const SCEV *Ptr) {
1288   // Address computations in vectorized code with non-consecutive addresses will
1289   // likely result in more instructions compared to scalar code where the
1290   // computation can more often be merged into the index mode. The resulting
1291   // extra micro-ops can significantly decrease throughput.
1292   unsigned NumVectorInstToHideOverhead = 10;
1293   int MaxMergeDistance = 64;
1294 
1295   if (Ty->isVectorTy() && SE &&
1296       !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
1297     return NumVectorInstToHideOverhead;
1298 
1299   // In many cases the address computation is not merged into the instruction
1300   // addressing mode.
1301   return 1;
1302 }
1303 
/// Compute the cost of a compare or select instruction. Fixed-width vector
/// selects get special handling: chains lowerable to CMxx+BFI are cheap,
/// while wide selects that must be scalarized are heavily penalized.
InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
                                                   Type *CondTy,
                                                   CmpInst::Predicate VecPred,
                                                   TTI::TargetCostKind CostKind,
                                                   const Instruction *I) {
  // TODO: Handle other cost kinds.
  if (CostKind != TTI::TCK_RecipThroughput)
    return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
                                     I);

  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  // We don't lower some vector selects well that are wider than the register
  // width.
  if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
    // We would need this many instructions to hide the scalarization happening.
    const int AmortizationCost = 20;

    // If VecPred is not set, check if we can get a predicate from the context
    // instruction, if its type matches the requested ValTy.
    if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
      CmpInst::Predicate CurrentPred;
      // Match a select whose condition is a compare, and take its predicate.
      if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
                            m_Value())))
        VecPred = CurrentPred;
    }
    // Check if we have a compare/select chain that can be lowered using CMxx &
    // BFI pair.
    if (CmpInst::isIntPredicate(VecPred)) {
      // Legal NEON integer vector types for which the CMxx+BFI lowering
      // applies; the cost is then just the legalization split factor.
      static const auto ValidMinMaxTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
                                          MVT::v8i16, MVT::v2i32, MVT::v4i32,
                                          MVT::v2i64};
      auto LT = TLI->getTypeLegalizationCost(DL, ValTy);
      if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
        return LT.first;
    }

    // Hand-tuned costs for wide selects, keyed on (condition VT, value VT).
    // Entries with AmortizationCost model selects that are scalarized.
    static const TypeConversionCostTblEntry
    VectorSelectTbl[] = {
      { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
      { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
      { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
      { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
      { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
      { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
    };

    EVT SelCondTy = TLI->getValueType(DL, CondTy);
    EVT SelValTy = TLI->getValueType(DL, ValTy);
    if (SelCondTy.isSimple() && SelValTy.isSimple()) {
      if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
                                                     SelCondTy.getSimpleVT(),
                                                     SelValTy.getSimpleVT()))
        return Entry->Cost;
    }
  }
  // The base case handles scalable vectors fine for now, since it treats the
  // cost as 1 * legalization cost.
  return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
}
1363 
1364 AArch64TTIImpl::TTI::MemCmpExpansionOptions
1365 AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
1366   TTI::MemCmpExpansionOptions Options;
1367   if (ST->requiresStrictAlign()) {
1368     // TODO: Add cost modeling for strict align. Misaligned loads expand to
1369     // a bunch of instructions when strict align is enabled.
1370     return Options;
1371   }
1372   Options.AllowOverlappingLoads = true;
1373   Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
1374   Options.NumLoadsPerBlock = Options.MaxNumLoads;
1375   // TODO: Though vector loads usually perform well on AArch64, in some targets
1376   // they may wake up the FP unit, which raises the power consumption.  Perhaps
1377   // they could be used with no holds barred (-O3).
1378   Options.LoadSizes = {8, 4, 2, 1};
1379   return Options;
1380 }
1381 
1382 InstructionCost
1383 AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
1384                                       Align Alignment, unsigned AddressSpace,
1385                                       TTI::TargetCostKind CostKind) {
1386   if (!isa<ScalableVectorType>(Src))
1387     return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
1388                                         CostKind);
1389   auto LT = TLI->getTypeLegalizationCost(DL, Src);
1390   if (!LT.first.isValid())
1391     return InstructionCost::getInvalid();
1392   return LT.first * 2;
1393 }
1394 
1395 InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
1396     unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
1397     Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
1398 
1399   if (!isa<ScalableVectorType>(DataTy))
1400     return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
1401                                          Alignment, CostKind, I);
1402   auto *VT = cast<VectorType>(DataTy);
1403   auto LT = TLI->getTypeLegalizationCost(DL, DataTy);
1404   if (!LT.first.isValid())
1405     return InstructionCost::getInvalid();
1406 
1407   ElementCount LegalVF = LT.second.getVectorElementCount();
1408   Optional<unsigned> MaxNumVScale = getMaxVScale();
1409   assert(MaxNumVScale && "Expected valid max vscale value");
1410 
1411   InstructionCost MemOpCost =
1412       getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind, I);
1413   unsigned MaxNumElementsPerGather =
1414       MaxNumVScale.getValue() * LegalVF.getKnownMinValue();
1415   return LT.first * MaxNumElementsPerGather * MemOpCost;
1416 }
1417 
1418 bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
1419   return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
1420 }
1421 
/// Cost of a plain (unmasked) load or store of type \p Ty. Models slow
/// misaligned 128-bit stores on some subtargets and the scalarization of
/// NEON truncating stores / extending loads. The early-return order matters:
/// struct types first, then non-throughput cost kinds, then the
/// subtarget-specific cases.
InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
                                                MaybeAlign Alignment,
                                                unsigned AddressSpace,
                                                TTI::TargetCostKind CostKind,
                                                const Instruction *I) {
  EVT VT = TLI->getValueType(DL, Ty, true);
  // Type legalization can't handle structs
  if (VT == MVT::Other)
    return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
                                  CostKind);

  auto LT = TLI->getTypeLegalizationCost(DL, Ty);
  if (!LT.first.isValid())
    return InstructionCost::getInvalid();

  // TODO: consider latency as well for TCK_SizeAndLatency.
  if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
    return LT.first;

  // Any other non-throughput cost kind (e.g. latency) is a flat one unit.
  if (CostKind != TTI::TCK_RecipThroughput)
    return 1;

  if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
      LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) {
    // Unaligned stores are extremely inefficient. We don't split all
    // unaligned 128-bit stores because the negative impact that has shown in
    // practice on inlined block copy code.
    // We make such stores expensive so that we will only vectorize if there
    // are 6 other instructions getting vectorized.
    const int AmortizationCost = 6;

    return LT.first * 2 * AmortizationCost;
  }

  // Check truncating stores and extending loads.
  if (useNeonVector(Ty) &&
      Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
    // v4i8 types are lowered to scalar a load/store and sshll/xtn.
    if (VT == MVT::v4i8)
      return 2;
    // Otherwise we need to scalarize.
    return cast<FixedVectorType>(Ty)->getNumElements() * 2;
  }

  return LT.first;
}
1468 
1469 InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
1470     unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
1471     Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
1472     bool UseMaskForCond, bool UseMaskForGaps) {
1473   assert(Factor >= 2 && "Invalid interleave factor");
1474   auto *VecVTy = cast<FixedVectorType>(VecTy);
1475 
1476   if (!UseMaskForCond && !UseMaskForGaps &&
1477       Factor <= TLI->getMaxSupportedInterleaveFactor()) {
1478     unsigned NumElts = VecVTy->getNumElements();
1479     auto *SubVecTy =
1480         FixedVectorType::get(VecTy->getScalarType(), NumElts / Factor);
1481 
1482     // ldN/stN only support legal vector types of size 64 or 128 in bits.
1483     // Accesses having vector types that are a multiple of 128 bits can be
1484     // matched to more than one ldN/stN instruction.
1485     if (NumElts % Factor == 0 &&
1486         TLI->isLegalInterleavedAccessType(SubVecTy, DL))
1487       return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL);
1488   }
1489 
1490   return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
1491                                            Alignment, AddressSpace, CostKind,
1492                                            UseMaskForCond, UseMaskForGaps);
1493 }
1494 
1495 InstructionCost
1496 AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
1497   InstructionCost Cost = 0;
1498   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1499   for (auto *I : Tys) {
1500     if (!I->isVectorTy())
1501       continue;
1502     if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() ==
1503         128)
1504       Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) +
1505               getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind);
1506   }
1507   return Cost;
1508 }
1509 
1510 unsigned AArch64TTIImpl::getMaxInterleaveFactor(unsigned VF) {
1511   return ST->getMaxInterleaveFactor();
1512 }
1513 
1514 // For Falkor, we want to avoid having too many strided loads in a loop since
1515 // that can exhaust the HW prefetcher resources.  We adjust the unroller
1516 // MaxCount preference below to attempt to ensure unrolling doesn't create too
1517 // many strided loads.
1518 static void
1519 getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
1520                               TargetTransformInfo::UnrollingPreferences &UP) {
1521   enum { MaxStridedLoads = 7 };
1522   auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) {
1523     int StridedLoads = 0;
1524     // FIXME? We could make this more precise by looking at the CFG and
1525     // e.g. not counting loads in each side of an if-then-else diamond.
1526     for (const auto BB : L->blocks()) {
1527       for (auto &I : *BB) {
1528         LoadInst *LMemI = dyn_cast<LoadInst>(&I);
1529         if (!LMemI)
1530           continue;
1531 
1532         Value *PtrValue = LMemI->getPointerOperand();
1533         if (L->isLoopInvariant(PtrValue))
1534           continue;
1535 
1536         const SCEV *LSCEV = SE.getSCEV(PtrValue);
1537         const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
1538         if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
1539           continue;
1540 
1541         // FIXME? We could take pairing of unrolled load copies into account
1542         // by looking at the AddRec, but we would probably have to limit this
1543         // to loops with no stores or other memory optimization barriers.
1544         ++StridedLoads;
1545         // We've seen enough strided loads that seeing more won't make a
1546         // difference.
1547         if (StridedLoads > MaxStridedLoads / 2)
1548           return StridedLoads;
1549       }
1550     }
1551     return StridedLoads;
1552   };
1553 
1554   int StridedLoads = countStridedLoads(L, SE);
1555   LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads
1556                     << " strided loads\n");
1557   // Pick the largest power of 2 unroll count that won't result in too many
1558   // strided loads.
1559   if (StridedLoads) {
1560     UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads);
1561     LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to "
1562                       << UP.MaxCount << '\n');
1563   }
1564 }
1565 
1566 void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
1567                                              TTI::UnrollingPreferences &UP) {
1568   // Enable partial unrolling and runtime unrolling.
1569   BaseT::getUnrollingPreferences(L, SE, UP);
1570 
1571   // For inner loop, it is more likely to be a hot one, and the runtime check
1572   // can be promoted out from LICM pass, so the overhead is less, let's try
1573   // a larger threshold to unroll more loops.
1574   if (L->getLoopDepth() > 1)
1575     UP.PartialThreshold *= 2;
1576 
1577   // Disable partial & runtime unrolling on -Os.
1578   UP.PartialOptSizeThreshold = 0;
1579 
1580   if (ST->getProcFamily() == AArch64Subtarget::Falkor &&
1581       EnableFalkorHWPFUnrollFix)
1582     getFalkorUnrollingPreferences(L, SE, UP);
1583 
1584   // Scan the loop: don't unroll loops with calls as this could prevent
1585   // inlining. Don't unroll vector loops either, as they don't benefit much from
1586   // unrolling.
1587   for (auto *BB : L->getBlocks()) {
1588     for (auto &I : *BB) {
1589       // Don't unroll vectorised loop.
1590       if (I.getType()->isVectorTy())
1591         return;
1592 
1593       if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
1594         if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
1595           if (!isLoweredToCall(F))
1596             continue;
1597         }
1598         return;
1599       }
1600     }
1601   }
1602 
1603   // Enable runtime unrolling for in-order models
1604   // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so by
1605   // checking for that case, we can ensure that the default behaviour is
1606   // unchanged
1607   if (ST->getProcFamily() != AArch64Subtarget::Others &&
1608       !ST->getSchedModel().isOutOfOrder()) {
1609     UP.Runtime = true;
1610     UP.Partial = true;
1611     UP.UpperBound = true;
1612     UP.UnrollRemainder = true;
1613     UP.DefaultUnrollRuntimeCount = 4;
1614 
1615     UP.UnrollAndJam = true;
1616     UP.UnrollAndJamInnerLoopThreshold = 60;
1617   }
1618 }
1619 
// AArch64 adds nothing on top of the generic peeling preferences.
void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
                                           TTI::PeelingPreferences &PP) {
  BaseT::getPeelingPreferences(L, SE, PP);
}
1624 
1625 Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
1626                                                          Type *ExpectedType) {
1627   switch (Inst->getIntrinsicID()) {
1628   default:
1629     return nullptr;
1630   case Intrinsic::aarch64_neon_st2:
1631   case Intrinsic::aarch64_neon_st3:
1632   case Intrinsic::aarch64_neon_st4: {
1633     // Create a struct type
1634     StructType *ST = dyn_cast<StructType>(ExpectedType);
1635     if (!ST)
1636       return nullptr;
1637     unsigned NumElts = Inst->getNumArgOperands() - 1;
1638     if (ST->getNumElements() != NumElts)
1639       return nullptr;
1640     for (unsigned i = 0, e = NumElts; i != e; ++i) {
1641       if (Inst->getArgOperand(i)->getType() != ST->getElementType(i))
1642         return nullptr;
1643     }
1644     Value *Res = UndefValue::get(ExpectedType);
1645     IRBuilder<> Builder(Inst);
1646     for (unsigned i = 0, e = NumElts; i != e; ++i) {
1647       Value *L = Inst->getArgOperand(i);
1648       Res = Builder.CreateInsertValue(Res, L, i);
1649     }
1650     return Res;
1651   }
1652   case Intrinsic::aarch64_neon_ld2:
1653   case Intrinsic::aarch64_neon_ld3:
1654   case Intrinsic::aarch64_neon_ld4:
1655     if (Inst->getType() == ExpectedType)
1656       return Inst;
1657     return nullptr;
1658   }
1659 }
1660 
1661 bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
1662                                         MemIntrinsicInfo &Info) {
1663   switch (Inst->getIntrinsicID()) {
1664   default:
1665     break;
1666   case Intrinsic::aarch64_neon_ld2:
1667   case Intrinsic::aarch64_neon_ld3:
1668   case Intrinsic::aarch64_neon_ld4:
1669     Info.ReadMem = true;
1670     Info.WriteMem = false;
1671     Info.PtrVal = Inst->getArgOperand(0);
1672     break;
1673   case Intrinsic::aarch64_neon_st2:
1674   case Intrinsic::aarch64_neon_st3:
1675   case Intrinsic::aarch64_neon_st4:
1676     Info.ReadMem = false;
1677     Info.WriteMem = true;
1678     Info.PtrVal = Inst->getArgOperand(Inst->getNumArgOperands() - 1);
1679     break;
1680   }
1681 
1682   switch (Inst->getIntrinsicID()) {
1683   default:
1684     return false;
1685   case Intrinsic::aarch64_neon_ld2:
1686   case Intrinsic::aarch64_neon_st2:
1687     Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
1688     break;
1689   case Intrinsic::aarch64_neon_ld3:
1690   case Intrinsic::aarch64_neon_st3:
1691     Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
1692     break;
1693   case Intrinsic::aarch64_neon_ld4:
1694   case Intrinsic::aarch64_neon_st4:
1695     Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
1696     break;
1697   }
1698   return true;
1699 }
1700 
1701 /// See if \p I should be considered for address type promotion. We check if \p
1702 /// I is a sext with right type and used in memory accesses. If it used in a
1703 /// "complex" getelementptr, we allow it to be promoted without finding other
1704 /// sext instructions that sign extended the same initial value. A getelementptr
1705 /// is considered as "complex" if it has more than 2 operands.
1706 bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
1707     const Instruction &I, bool &AllowPromotionWithoutCommonHeader) {
1708   bool Considerable = false;
1709   AllowPromotionWithoutCommonHeader = false;
1710   if (!isa<SExtInst>(&I))
1711     return false;
1712   Type *ConsideredSExtType =
1713       Type::getInt64Ty(I.getParent()->getParent()->getContext());
1714   if (I.getType() != ConsideredSExtType)
1715     return false;
1716   // See if the sext is the one with the right type and used in at least one
1717   // GetElementPtrInst.
1718   for (const User *U : I.users()) {
1719     if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
1720       Considerable = true;
1721       // A getelementptr is considered as "complex" if it has more than 2
1722       // operands. We will promote a SExt used in such complex GEP as we
1723       // expect some computation to be merged if they are done on 64 bits.
1724       if (GEPInst->getNumOperands() > 2) {
1725         AllowPromotionWithoutCommonHeader = true;
1726         break;
1727       }
1728     }
1729   }
1730   return Considerable;
1731 }
1732 
1733 bool AArch64TTIImpl::isLegalToVectorizeReduction(
1734     const RecurrenceDescriptor &RdxDesc, ElementCount VF) const {
1735   if (!VF.isScalable())
1736     return true;
1737 
1738   Type *Ty = RdxDesc.getRecurrenceType();
1739   if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty))
1740     return false;
1741 
1742   switch (RdxDesc.getRecurrenceKind()) {
1743   case RecurKind::Add:
1744   case RecurKind::FAdd:
1745   case RecurKind::And:
1746   case RecurKind::Or:
1747   case RecurKind::Xor:
1748   case RecurKind::SMin:
1749   case RecurKind::SMax:
1750   case RecurKind::UMin:
1751   case RecurKind::UMax:
1752   case RecurKind::FMin:
1753   case RecurKind::FMax:
1754     return true;
1755   default:
1756     return false;
1757   }
1758 }
1759 
1760 InstructionCost
1761 AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
1762                                        bool IsUnsigned,
1763                                        TTI::TargetCostKind CostKind) {
1764   if (!isa<ScalableVectorType>(Ty))
1765     return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
1766   assert((isa<ScalableVectorType>(Ty) && isa<ScalableVectorType>(CondTy)) &&
1767          "Both vector needs to be scalable");
1768 
1769   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
1770   InstructionCost LegalizationCost = 0;
1771   if (LT.first > 1) {
1772     Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
1773     unsigned CmpOpcode =
1774         Ty->isFPOrFPVectorTy() ? Instruction::FCmp : Instruction::ICmp;
1775     LegalizationCost =
1776         getCmpSelInstrCost(CmpOpcode, LegalVTy, LegalVTy,
1777                            CmpInst::BAD_ICMP_PREDICATE, CostKind) +
1778         getCmpSelInstrCost(Instruction::Select, LegalVTy, LegalVTy,
1779                            CmpInst::BAD_ICMP_PREDICATE, CostKind);
1780     LegalizationCost *= LT.first - 1;
1781   }
1782 
1783   return LegalizationCost + /*Cost of horizontal reduction*/ 2;
1784 }
1785 
1786 InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
1787     unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) {
1788   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
1789   InstructionCost LegalizationCost = 0;
1790   if (LT.first > 1) {
1791     Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
1792     LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
1793     LegalizationCost *= LT.first - 1;
1794   }
1795 
1796   int ISD = TLI->InstructionOpcodeToISD(Opcode);
1797   assert(ISD && "Invalid opcode");
1798   // Add the final reduction cost for the legal horizontal reduction
1799   switch (ISD) {
1800   case ISD::ADD:
1801   case ISD::AND:
1802   case ISD::OR:
1803   case ISD::XOR:
1804   case ISD::FADD:
1805     return LegalizationCost + 2;
1806   default:
1807     return InstructionCost::getInvalid();
1808   }
1809 }
1810 
InstructionCost
AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
                                           TTI::TargetCostKind CostKind) {
  // Scalable vectors are costed by the SVE-specific helper.
  if (isa<ScalableVectorType>(ValTy))
    return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind);

  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
  MVT MTy = LT.second;
  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  assert(ISD && "Invalid opcode");

  // Horizontal adds can use the 'addv' instruction. We model the cost of these
  // instructions as normal vector adds. This is the only arithmetic vector
  // reduction operation for which we have an instruction.
  // OR, XOR and AND costs should match the codegen from:
  // OR: llvm/test/CodeGen/AArch64/reduce-or.ll
  // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll
  // AND: llvm/test/CodeGen/AArch64/reduce-and.ll
  static const CostTblEntry CostTblNoPairwise[]{
      {ISD::ADD, MVT::v8i8,   1},
      {ISD::ADD, MVT::v16i8,  1},
      {ISD::ADD, MVT::v4i16,  1},
      {ISD::ADD, MVT::v8i16,  1},
      {ISD::ADD, MVT::v4i32,  1},
      {ISD::OR,  MVT::v8i8,  15},
      {ISD::OR,  MVT::v16i8, 17},
      {ISD::OR,  MVT::v4i16,  7},
      {ISD::OR,  MVT::v8i16,  9},
      {ISD::OR,  MVT::v2i32,  3},
      {ISD::OR,  MVT::v4i32,  5},
      {ISD::OR,  MVT::v2i64,  3},
      {ISD::XOR, MVT::v8i8,  15},
      {ISD::XOR, MVT::v16i8, 17},
      {ISD::XOR, MVT::v4i16,  7},
      {ISD::XOR, MVT::v8i16,  9},
      {ISD::XOR, MVT::v2i32,  3},
      {ISD::XOR, MVT::v4i32,  5},
      {ISD::XOR, MVT::v2i64,  3},
      {ISD::AND, MVT::v8i8,  15},
      {ISD::AND, MVT::v16i8, 17},
      {ISD::AND, MVT::v4i16,  7},
      {ISD::AND, MVT::v8i16,  9},
      {ISD::AND, MVT::v2i32,  3},
      {ISD::AND, MVT::v4i32,  5},
      {ISD::AND, MVT::v2i64,  3},
  };
  switch (ISD) {
  default:
    break;
  case ISD::ADD:
    if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
      return LT.first * Entry->Cost;
    break;
  case ISD::XOR:
  case ISD::AND:
  case ISD::OR:
    const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy);
    if (!Entry)
      break;
    auto *ValVTy = cast<FixedVectorType>(ValTy);
    // The table costs only apply to non-i1 elements whose count is a power
    // of two and fits within the legalized type; otherwise fall back to the
    // generic cost below.
    if (!ValVTy->getElementType()->isIntegerTy(1) &&
        MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
        isPowerOf2_32(ValVTy->getNumElements())) {
      InstructionCost ExtraCost = 0;
      if (LT.first != 1) {
        // Type needs to be split, so there is an extra cost of LT.first - 1
        // arithmetic ops.
        auto *Ty = FixedVectorType::get(ValTy->getElementType(),
                                        MTy.getVectorNumElements());
        ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
        ExtraCost *= LT.first - 1;
      }
      return Entry->Cost + ExtraCost;
    }
    break;
  }
  return BaseT::getArithmeticReductionCost(Opcode, ValTy, CostKind);
}
1889 
/// Cost of a TTI::SK_Splice shuffle on scalable vector type \p Tp starting
/// at element \p Index.
InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
  static const CostTblEntry ShuffleTbl[] = {
      { TTI::SK_Splice, MVT::nxv16i8,  1 },
      { TTI::SK_Splice, MVT::nxv8i16,  1 },
      { TTI::SK_Splice, MVT::nxv4i32,  1 },
      { TTI::SK_Splice, MVT::nxv2i64,  1 },
      { TTI::SK_Splice, MVT::nxv2f16,  1 },
      { TTI::SK_Splice, MVT::nxv4f16,  1 },
      { TTI::SK_Splice, MVT::nxv8f16,  1 },
      { TTI::SK_Splice, MVT::nxv2bf16, 1 },
      { TTI::SK_Splice, MVT::nxv4bf16, 1 },
      { TTI::SK_Splice, MVT::nxv8bf16, 1 },
      { TTI::SK_Splice, MVT::nxv2f32,  1 },
      { TTI::SK_Splice, MVT::nxv4f32,  1 },
      { TTI::SK_Splice, MVT::nxv2f64,  1 },
  };

  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
  Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext());
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  // i1 vectors are promoted before lowering; cost on the promoted type.
  EVT PromotedVT = LT.second.getScalarType() == MVT::i1
                       ? TLI->getPromotedVTForPredicate(EVT(LT.second))
                       : LT.second;
  Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext());
  InstructionCost LegalizationCost = 0;
  // NOTE(review): a negative index appears to be lowered via a compare plus
  // select to build the result; this models that extra work — confirm
  // against AArch64ISelLowering.
  if (Index < 0) {
    LegalizationCost =
        getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy,
                           CmpInst::BAD_ICMP_PREDICATE, CostKind) +
        getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy,
                           CmpInst::BAD_ICMP_PREDICATE, CostKind);
  }

  // Predicated splice are promoted when lowering. See AArch64ISelLowering.cpp
  // Cost performed on a promoted type.
  if (LT.second.getScalarType() == MVT::i1) {
    // Account for converting between the i1 type and its promoted form.
    LegalizationCost +=
        getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy,
                         TTI::CastContextHint::None, CostKind) +
        getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy,
                         TTI::CastContextHint::None, CostKind);
  }
  const auto *Entry =
      CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT());
  assert(Entry && "Illegal Type for Splice");
  LegalizationCost += Entry->Cost;
  return LegalizationCost * LT.first;
}
1938 
InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
                                               VectorType *Tp,
                                               ArrayRef<int> Mask, int Index,
                                               VectorType *SubTp) {
  // Refine the shuffle kind from the concrete mask when possible, then try
  // the per-kind cost table; unsupported kinds fall back to the base impl.
  Kind = improveShuffleKindFromMask(Kind, Mask);
  if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose ||
      Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc ||
      Kind == TTI::SK_Reverse) {
    static const CostTblEntry ShuffleTbl[] = {
      // Broadcast shuffle kinds can be performed with 'dup'.
      { TTI::SK_Broadcast, MVT::v8i8,  1 },
      { TTI::SK_Broadcast, MVT::v16i8, 1 },
      { TTI::SK_Broadcast, MVT::v4i16, 1 },
      { TTI::SK_Broadcast, MVT::v8i16, 1 },
      { TTI::SK_Broadcast, MVT::v2i32, 1 },
      { TTI::SK_Broadcast, MVT::v4i32, 1 },
      { TTI::SK_Broadcast, MVT::v2i64, 1 },
      { TTI::SK_Broadcast, MVT::v2f32, 1 },
      { TTI::SK_Broadcast, MVT::v4f32, 1 },
      { TTI::SK_Broadcast, MVT::v2f64, 1 },
      // Transpose shuffle kinds can be performed with 'trn1/trn2' and
      // 'zip1/zip2' instructions.
      { TTI::SK_Transpose, MVT::v8i8,  1 },
      { TTI::SK_Transpose, MVT::v16i8, 1 },
      { TTI::SK_Transpose, MVT::v4i16, 1 },
      { TTI::SK_Transpose, MVT::v8i16, 1 },
      { TTI::SK_Transpose, MVT::v2i32, 1 },
      { TTI::SK_Transpose, MVT::v4i32, 1 },
      { TTI::SK_Transpose, MVT::v2i64, 1 },
      { TTI::SK_Transpose, MVT::v2f32, 1 },
      { TTI::SK_Transpose, MVT::v4f32, 1 },
      { TTI::SK_Transpose, MVT::v2f64, 1 },
      // Select shuffle kinds.
      // TODO: handle vXi8/vXi16.
      { TTI::SK_Select, MVT::v2i32, 1 }, // mov.
      { TTI::SK_Select, MVT::v4i32, 2 }, // rev+trn (or similar).
      { TTI::SK_Select, MVT::v2i64, 1 }, // mov.
      { TTI::SK_Select, MVT::v2f32, 1 }, // mov.
      { TTI::SK_Select, MVT::v4f32, 2 }, // rev+trn (or similar).
      { TTI::SK_Select, MVT::v2f64, 1 }, // mov.
      // PermuteSingleSrc shuffle kinds.
      // TODO: handle vXi8/vXi16.
      { TTI::SK_PermuteSingleSrc, MVT::v2i32, 1 }, // mov.
      { TTI::SK_PermuteSingleSrc, MVT::v4i32, 3 }, // perfectshuffle worst case.
      { TTI::SK_PermuteSingleSrc, MVT::v2i64, 1 }, // mov.
      { TTI::SK_PermuteSingleSrc, MVT::v2f32, 1 }, // mov.
      { TTI::SK_PermuteSingleSrc, MVT::v4f32, 3 }, // perfectshuffle worst case.
      { TTI::SK_PermuteSingleSrc, MVT::v2f64, 1 }, // mov.
      // Reverse can be lowered with `rev`.
      { TTI::SK_Reverse, MVT::v2i32, 1 }, // mov.
      { TTI::SK_Reverse, MVT::v4i32, 2 }, // REV64; EXT
      { TTI::SK_Reverse, MVT::v2i64, 1 }, // mov.
      { TTI::SK_Reverse, MVT::v2f32, 1 }, // mov.
      { TTI::SK_Reverse, MVT::v4f32, 2 }, // REV64; EXT
      { TTI::SK_Reverse, MVT::v2f64, 1 }, // mov.
      // Broadcast shuffle kinds for scalable vectors
      { TTI::SK_Broadcast, MVT::nxv16i8,  1 },
      { TTI::SK_Broadcast, MVT::nxv8i16,  1 },
      { TTI::SK_Broadcast, MVT::nxv4i32,  1 },
      { TTI::SK_Broadcast, MVT::nxv2i64,  1 },
      { TTI::SK_Broadcast, MVT::nxv2f16,  1 },
      { TTI::SK_Broadcast, MVT::nxv4f16,  1 },
      { TTI::SK_Broadcast, MVT::nxv8f16,  1 },
      { TTI::SK_Broadcast, MVT::nxv2bf16, 1 },
      { TTI::SK_Broadcast, MVT::nxv4bf16, 1 },
      { TTI::SK_Broadcast, MVT::nxv8bf16, 1 },
      { TTI::SK_Broadcast, MVT::nxv2f32,  1 },
      { TTI::SK_Broadcast, MVT::nxv4f32,  1 },
      { TTI::SK_Broadcast, MVT::nxv2f64,  1 },
      { TTI::SK_Broadcast, MVT::nxv16i1,  1 },
      { TTI::SK_Broadcast, MVT::nxv8i1,   1 },
      { TTI::SK_Broadcast, MVT::nxv4i1,   1 },
      { TTI::SK_Broadcast, MVT::nxv2i1,   1 },
      // Handle the cases for vector.reverse with scalable vectors
      { TTI::SK_Reverse, MVT::nxv16i8,  1 },
      { TTI::SK_Reverse, MVT::nxv8i16,  1 },
      { TTI::SK_Reverse, MVT::nxv4i32,  1 },
      { TTI::SK_Reverse, MVT::nxv2i64,  1 },
      { TTI::SK_Reverse, MVT::nxv2f16,  1 },
      { TTI::SK_Reverse, MVT::nxv4f16,  1 },
      { TTI::SK_Reverse, MVT::nxv8f16,  1 },
      { TTI::SK_Reverse, MVT::nxv2bf16, 1 },
      { TTI::SK_Reverse, MVT::nxv4bf16, 1 },
      { TTI::SK_Reverse, MVT::nxv8bf16, 1 },
      { TTI::SK_Reverse, MVT::nxv2f32,  1 },
      { TTI::SK_Reverse, MVT::nxv4f32,  1 },
      { TTI::SK_Reverse, MVT::nxv2f64,  1 },
      { TTI::SK_Reverse, MVT::nxv16i1,  1 },
      { TTI::SK_Reverse, MVT::nxv8i1,   1 },
      { TTI::SK_Reverse, MVT::nxv4i1,   1 },
      { TTI::SK_Reverse, MVT::nxv2i1,   1 },
    };
    std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
    if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second))
      return LT.first * Entry->Cost;
  }
  // Splices of scalable vectors have a dedicated cost model.
  if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(Tp))
    return getSpliceCost(Tp, Index);
  return BaseT::getShuffleCost(Kind, Tp, Mask, Index, SubTp);
}
2039