1 //===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "AArch64TargetTransformInfo.h"
10 #include "AArch64ExpandImm.h"
11 #include "MCTargetDesc/AArch64AddressingModes.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/TargetTransformInfo.h"
14 #include "llvm/CodeGen/BasicTTIImpl.h"
15 #include "llvm/CodeGen/CostTable.h"
16 #include "llvm/CodeGen/TargetLowering.h"
17 #include "llvm/IR/Intrinsics.h"
18 #include "llvm/IR/IntrinsicInst.h"
19 #include "llvm/IR/IntrinsicsAArch64.h"
20 #include "llvm/IR/PatternMatch.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Transforms/InstCombine/InstCombiner.h"
23 #include <algorithm>
24 using namespace llvm;
25 using namespace llvm::PatternMatch;
26 
27 #define DEBUG_TYPE "aarch64tti"
28 
29 static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix",
30                                                cl::init(true), cl::Hidden);
31 
32 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
33                                          const Function *Callee) const {
34   const TargetMachine &TM = getTLI()->getTargetMachine();
35 
36   const FeatureBitset &CallerBits =
37       TM.getSubtargetImpl(*Caller)->getFeatureBits();
38   const FeatureBitset &CalleeBits =
39       TM.getSubtargetImpl(*Callee)->getFeatureBits();
40 
  // Inline a callee if its target-features are a subset of the caller's
  // target-features.
43   return (CallerBits & CalleeBits) == CalleeBits;
44 }
45 
46 /// Calculate the cost of materializing a 64-bit value. This helper
47 /// method might only calculate a fraction of a larger immediate. Therefore it
48 /// is valid to return a cost of ZERO.
49 InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) {
50   // Check if the immediate can be encoded within an instruction.
51   if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
52     return 0;
53 
54   if (Val < 0)
55     Val = ~Val;
56 
57   // Calculate how many moves we will need to materialize this constant.
58   SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
59   AArch64_IMM::expandMOVImm(Val, 64, Insn);
60   return Insn.size();
61 }
62 
63 /// Calculate the cost of materializing the given constant.
64 InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
65                                               TTI::TargetCostKind CostKind) {
66   assert(Ty->isIntegerTy());
67 
68   unsigned BitSize = Ty->getPrimitiveSizeInBits();
69   if (BitSize == 0)
70     return ~0U;
71 
72   // Sign-extend all constants to a multiple of 64-bit.
73   APInt ImmVal = Imm;
74   if (BitSize & 0x3f)
75     ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);
76 
77   // Split the constant into 64-bit chunks and calculate the cost for each
78   // chunk.
79   InstructionCost Cost = 0;
80   for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
81     APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
82     int64_t Val = Tmp.getSExtValue();
83     Cost += getIntImmCost(Val);
84   }
  // We need at least one instruction to materialize the constant.
86   return std::max<InstructionCost>(1, Cost);
87 }
88 
89 InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
90                                                   const APInt &Imm, Type *Ty,
91                                                   TTI::TargetCostKind CostKind,
92                                                   Instruction *Inst) {
93   assert(Ty->isIntegerTy());
94 
95   unsigned BitSize = Ty->getPrimitiveSizeInBits();
96   // There is no cost model for constants with a bit size of 0. Return TCC_Free
97   // here, so that constant hoisting will ignore this constant.
98   if (BitSize == 0)
99     return TTI::TCC_Free;
100 
101   unsigned ImmIdx = ~0U;
102   switch (Opcode) {
103   default:
104     return TTI::TCC_Free;
105   case Instruction::GetElementPtr:
106     // Always hoist the base address of a GetElementPtr.
107     if (Idx == 0)
108       return 2 * TTI::TCC_Basic;
109     return TTI::TCC_Free;
110   case Instruction::Store:
111     ImmIdx = 0;
112     break;
113   case Instruction::Add:
114   case Instruction::Sub:
115   case Instruction::Mul:
116   case Instruction::UDiv:
117   case Instruction::SDiv:
118   case Instruction::URem:
119   case Instruction::SRem:
120   case Instruction::And:
121   case Instruction::Or:
122   case Instruction::Xor:
123   case Instruction::ICmp:
124     ImmIdx = 1;
125     break;
126   // Always return TCC_Free for the shift value of a shift instruction.
127   case Instruction::Shl:
128   case Instruction::LShr:
129   case Instruction::AShr:
130     if (Idx == 1)
131       return TTI::TCC_Free;
132     break;
133   case Instruction::Trunc:
134   case Instruction::ZExt:
135   case Instruction::SExt:
136   case Instruction::IntToPtr:
137   case Instruction::PtrToInt:
138   case Instruction::BitCast:
139   case Instruction::PHI:
140   case Instruction::Call:
141   case Instruction::Select:
142   case Instruction::Ret:
143   case Instruction::Load:
144     break;
145   }
146 
147   if (Idx == ImmIdx) {
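    // For the typical immediate operand position, report the constant as free
    // when it is cheap to materialize (at most one instruction per 64-bit
    // chunk), so that constant hoisting leaves it in place.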
148     int NumConstants = (BitSize + 63) / 64;
149     InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
150     return (Cost <= NumConstants * TTI::TCC_Basic)
151                ? static_cast<int>(TTI::TCC_Free)
152                : Cost;
153   }
154   return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
155 }
156 
157 InstructionCost
158 AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
159                                     const APInt &Imm, Type *Ty,
160                                     TTI::TargetCostKind CostKind) {
161   assert(Ty->isIntegerTy());
162 
163   unsigned BitSize = Ty->getPrimitiveSizeInBits();
164   // There is no cost model for constants with a bit size of 0. Return TCC_Free
165   // here, so that constant hoisting will ignore this constant.
166   if (BitSize == 0)
167     return TTI::TCC_Free;
168 
169   // Most (all?) AArch64 intrinsics do not support folding immediates into the
170   // selected instruction, so we compute the materialization cost for the
171   // immediate directly.
172   if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
173     return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
174 
175   switch (IID) {
176   default:
177     return TTI::TCC_Free;
178   case Intrinsic::sadd_with_overflow:
179   case Intrinsic::uadd_with_overflow:
180   case Intrinsic::ssub_with_overflow:
181   case Intrinsic::usub_with_overflow:
182   case Intrinsic::smul_with_overflow:
183   case Intrinsic::umul_with_overflow:
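    // Operand 1 is the right-hand side of the with-overflow arithmetic; if
    // the constant there is cheap to materialize it is reported as free so
    // that constant hoisting leaves it alone.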
184     if (Idx == 1) {
185       int NumConstants = (BitSize + 63) / 64;
186       InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
187       return (Cost <= NumConstants * TTI::TCC_Basic)
188                  ? static_cast<int>(TTI::TCC_Free)
189                  : Cost;
190     }
191     break;
192   case Intrinsic::experimental_stackmap:
193     if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
194       return TTI::TCC_Free;
195     break;
196   case Intrinsic::experimental_patchpoint_void:
197   case Intrinsic::experimental_patchpoint_i64:
198     if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
199       return TTI::TCC_Free;
200     break;
201   case Intrinsic::experimental_gc_statepoint:
202     if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
203       return TTI::TCC_Free;
204     break;
205   }
206   return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
207 }
208 
209 TargetTransformInfo::PopcntSupportKind
210 AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
211   assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
212   if (TyWidth == 32 || TyWidth == 64)
213     return TTI::PSK_FastHardware;
  // TODO: AArch64TargetLowering::LowerCTPOP() supports 128-bit popcount.
215   return TTI::PSK_Software;
216 }
217 
218 InstructionCost
219 AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
220                                       TTI::TargetCostKind CostKind) {
221   auto *RetTy = ICA.getReturnType();
222   switch (ICA.getID()) {
223   case Intrinsic::umin:
224   case Intrinsic::umax:
225   case Intrinsic::smin:
226   case Intrinsic::smax: {
227     static const auto ValidMinMaxTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
228                                         MVT::v8i16, MVT::v2i32, MVT::v4i32};
229     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
    // v2i64 types get converted to cmp+bif, hence the cost of 2.
231     if (LT.second == MVT::v2i64)
232       return LT.first * 2;
233     if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
234       return LT.first;
235     break;
236   }
237   case Intrinsic::sadd_sat:
238   case Intrinsic::ssub_sat:
239   case Intrinsic::uadd_sat:
240   case Intrinsic::usub_sat: {
241     static const auto ValidSatTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
242                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
243                                      MVT::v2i64};
244     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
245     // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
246     // need to extend the type, as it uses shr(qadd(shl, shl)).
247     unsigned Instrs =
248         LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
249     if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
250       return LT.first * Instrs;
251     break;
252   }
253   case Intrinsic::abs: {
254     static const auto ValidAbsTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
255                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
256                                      MVT::v2i64};
257     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
258     if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
259       return LT.first;
260     break;
261   }
262   case Intrinsic::experimental_stepvector: {
263     InstructionCost Cost = 1; // Cost of the `index' instruction
264     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
265     // Legalisation of illegal vectors involves an `index' instruction plus
266     // (LT.first - 1) vector adds.
267     if (LT.first > 1) {
268       Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
269       InstructionCost AddCost =
270           getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
271       Cost += AddCost * (LT.first - 1);
272     }
273     return Cost;
274   }
275   case Intrinsic::bitreverse: {
276     static const CostTblEntry BitreverseTbl[] = {
277         {Intrinsic::bitreverse, MVT::i32, 1},
278         {Intrinsic::bitreverse, MVT::i64, 1},
279         {Intrinsic::bitreverse, MVT::v8i8, 1},
280         {Intrinsic::bitreverse, MVT::v16i8, 1},
281         {Intrinsic::bitreverse, MVT::v4i16, 2},
282         {Intrinsic::bitreverse, MVT::v8i16, 2},
283         {Intrinsic::bitreverse, MVT::v2i32, 2},
284         {Intrinsic::bitreverse, MVT::v4i32, 2},
285         {Intrinsic::bitreverse, MVT::v1i64, 2},
286         {Intrinsic::bitreverse, MVT::v2i64, 2},
287     };
288     const auto LegalisationCost = TLI->getTypeLegalizationCost(DL, RetTy);
    const auto *Entry =
        CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
    if (Entry) {
      // Scalar i8 and i16 are legalized to i32; add 1 to the legal-type cost
      // so that we match the actual lowering cost.
      if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
          TLI->getValueType(DL, RetTy, true) == MVT::i16)
        return LegalisationCost.first * Entry->Cost + 1;
      return LegalisationCost.first * Entry->Cost;
    }
298     break;
299   }
300   case Intrinsic::ctpop: {
301     static const CostTblEntry CtpopCostTbl[] = {
302         {ISD::CTPOP, MVT::v2i64, 4},
303         {ISD::CTPOP, MVT::v4i32, 3},
304         {ISD::CTPOP, MVT::v8i16, 2},
305         {ISD::CTPOP, MVT::v16i8, 1},
306         {ISD::CTPOP, MVT::i64,   4},
307         {ISD::CTPOP, MVT::v2i32, 3},
308         {ISD::CTPOP, MVT::v4i16, 2},
309         {ISD::CTPOP, MVT::v8i8,  1},
310         {ISD::CTPOP, MVT::i32,   5},
311     };
312     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
313     MVT MTy = LT.second;
314     if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
315       // Extra cost of +1 when illegal vector types are legalized by promoting
316       // the integer type.
317       int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
318                                             RetTy->getScalarSizeInBits()
319                           ? 1
320                           : 0;
321       return LT.first * Entry->Cost + ExtraCost;
322     }
323     break;
324   }
325   default:
326     break;
327   }
328   return BaseT::getIntrinsicInstrCost(ICA, CostKind);
329 }
330 
/// Remove redundant reinterpret (svbool conversion) casts in the presence of
/// control flow: if the operand of a convert.from.svbool is a PHI whose
/// incoming values are all convert.to.svbool of the required type, replace it
/// with a PHI over the original predicate values.
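/// For example (schematic, types illustrative):
///   %phi = phi <vscale x 16 x i1> [ %a.bool, %bb0 ], [ %b.bool, %bb1 ]
///   %res = call <vscale x 4 x i1>
///       @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %phi)
/// where %a.bool and %b.bool are convert.to.svbool of <vscale x 4 x i1>
/// values; the PHI is rewritten over those original values and the
/// conversions become dead.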
333 static Optional<Instruction *> processPhiNode(InstCombiner &IC,
334                                               IntrinsicInst &II) {
335   SmallVector<Instruction *, 32> Worklist;
336   auto RequiredType = II.getType();
337 
338   auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
339   assert(PN && "Expected Phi Node!");
340 
341   // Don't create a new Phi unless we can remove the old one.
342   if (!PN->hasOneUse())
343     return None;
344 
345   for (Value *IncValPhi : PN->incoming_values()) {
346     auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
347     if (!Reinterpret ||
348         Reinterpret->getIntrinsicID() !=
349             Intrinsic::aarch64_sve_convert_to_svbool ||
350         RequiredType != Reinterpret->getArgOperand(0)->getType())
351       return None;
352   }
353 
354   // Create the new Phi
355   LLVMContext &Ctx = PN->getContext();
356   IRBuilder<> Builder(Ctx);
357   Builder.SetInsertPoint(PN);
358   PHINode *NPN = Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
359   Worklist.push_back(PN);
360 
361   for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
362     auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
363     NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
364     Worklist.push_back(Reinterpret);
365   }
366 
367   // Cleanup Phi Node and reinterprets
368   return IC.replaceInstUsesWith(II, NPN);
369 }
370 
371 static Optional<Instruction *> instCombineConvertFromSVBool(InstCombiner &IC,
372                                                             IntrinsicInst &II) {
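  // Fold away redundant chains of svbool conversions, e.g.
  //   from_svbool(to_svbool(X)) --> X
  // provided no intermediate conversion narrows the predicate, since
  // narrowing zeroes the extra lanes and would change the result.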
373   // If the reinterpret instruction operand is a PHI Node
374   if (isa<PHINode>(II.getArgOperand(0)))
375     return processPhiNode(IC, II);
376 
377   SmallVector<Instruction *, 32> CandidatesForRemoval;
378   Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;
379 
380   const auto *IVTy = cast<VectorType>(II.getType());
381 
382   // Walk the chain of conversions.
383   while (Cursor) {
384     // If the type of the cursor has fewer lanes than the final result, zeroing
385     // must take place, which breaks the equivalence chain.
386     const auto *CursorVTy = cast<VectorType>(Cursor->getType());
387     if (CursorVTy->getElementCount().getKnownMinValue() <
388         IVTy->getElementCount().getKnownMinValue())
389       break;
390 
    // If the cursor has the same type as II, it is a viable replacement.
392     if (Cursor->getType() == IVTy)
393       EarliestReplacement = Cursor;
394 
395     auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);
396 
397     // If this is not an SVE conversion intrinsic, this is the end of the chain.
398     if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
399                                   Intrinsic::aarch64_sve_convert_to_svbool ||
400                               IntrinsicCursor->getIntrinsicID() ==
401                                   Intrinsic::aarch64_sve_convert_from_svbool))
402       break;
403 
404     CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
405     Cursor = IntrinsicCursor->getOperand(0);
406   }
407 
408   // If no viable replacement in the conversion chain was found, there is
409   // nothing to do.
410   if (!EarliestReplacement)
411     return None;
412 
413   return IC.replaceInstUsesWith(II, EarliestReplacement);
414 }
415 
416 static Optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
417                                                  IntrinsicInst &II) {
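  // dup(V, ptrue(vl1), X) writes X only into lane 0 of V, so when the
  // governing predicate is a known single-lane ptrue the intrinsic can be
  // replaced by a plain insertelement.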
418   IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
419   if (!Pg)
420     return None;
421 
422   if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
423     return None;
424 
425   const auto PTruePattern =
426       cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
427   if (PTruePattern != AArch64SVEPredPattern::vl1)
428     return None;
429 
430   // The intrinsic is inserting into lane zero so use an insert instead.
431   auto *IdxTy = Type::getInt64Ty(II.getContext());
432   auto *Insert = InsertElementInst::Create(
433       II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
434   Insert->insertBefore(&II);
435   Insert->takeName(&II);
436 
437   return IC.replaceInstUsesWith(II, Insert);
438 }
439 
440 static Optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
441                                                    IntrinsicInst &II) {
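  // Match a cmpne, under an all-active predicate, of a dupq-splatted constant
  // vector against zero, and fold it to an all-false constant or to a
  // ptrue-based predicate when the constant's bit pattern allows it.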
442   LLVMContext &Ctx = II.getContext();
443   IRBuilder<> Builder(Ctx);
444   Builder.SetInsertPoint(&II);
445 
446   // Check that the predicate is all active
447   auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
448   if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
449     return None;
450 
451   const auto PTruePattern =
452       cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
453   if (PTruePattern != AArch64SVEPredPattern::all)
454     return None;
455 
456   // Check that we have a compare of zero..
457   auto *DupX = dyn_cast<IntrinsicInst>(II.getArgOperand(2));
458   if (!DupX || DupX->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
459     return None;
460 
461   auto *DupXArg = dyn_cast<ConstantInt>(DupX->getArgOperand(0));
462   if (!DupXArg || !DupXArg->isZero())
463     return None;
464 
465   // ..against a dupq
466   auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
467   if (!DupQLane ||
468       DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
469     return None;
470 
471   // Where the dupq is a lane 0 replicate of a vector insert
472   if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero())
473     return None;
474 
475   auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
476   if (!VecIns ||
477       VecIns->getIntrinsicID() != Intrinsic::experimental_vector_insert)
478     return None;
479 
480   // Where the vector insert is a fixed constant vector insert into undef at
481   // index zero
482   if (!isa<UndefValue>(VecIns->getArgOperand(0)))
483     return None;
484 
485   if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
486     return None;
487 
488   auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
489   if (!ConstVec)
490     return None;
491 
492   auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
493   auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
494   if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
495     return None;
496 
497   unsigned NumElts = VecTy->getNumElements();
498   unsigned PredicateBits = 0;
499 
500   // Expand intrinsic operands to a 16-bit byte level predicate
501   for (unsigned I = 0; I < NumElts; ++I) {
502     auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
503     if (!Arg)
504       return None;
505     if (!Arg->isZero())
506       PredicateBits |= 1 << (I * (16 / NumElts));
507   }
508 
  // If all bits are zero, bail out early with an all-false predicate
510   if (PredicateBits == 0) {
511     auto *PFalse = Constant::getNullValue(II.getType());
512     PFalse->takeName(&II);
513     return IC.replaceInstUsesWith(II, PFalse);
514   }
515 
516   // Calculate largest predicate type used (where byte predicate is largest)
517   unsigned Mask = 8;
518   for (unsigned I = 0; I < 16; ++I)
519     if ((PredicateBits & (1 << I)) != 0)
520       Mask |= (I % 8);
521 
522   unsigned PredSize = Mask & -Mask;
523   auto *PredType = ScalableVectorType::get(
524       Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));
525 
526   // Ensure all relevant bits are set
527   for (unsigned I = 0; I < 16; I += PredSize)
528     if ((PredicateBits & (1 << I)) == 0)
529       return None;
530 
531   auto *PTruePat =
532       ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
533   auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
534                                         {PredType}, {PTruePat});
535   auto *ConvertToSVBool = Builder.CreateIntrinsic(
536       Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
537   auto *ConvertFromSVBool =
538       Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
539                               {II.getType()}, {ConvertToSVBool});
540 
541   ConvertFromSVBool->takeName(&II);
542   return IC.replaceInstUsesWith(II, ConvertFromSVBool);
543 }
544 
545 static Optional<Instruction *> instCombineSVELast(InstCombiner &IC,
546                                                   IntrinsicInst &II) {
547   IRBuilder<> Builder(II.getContext());
548   Builder.SetInsertPoint(&II);
549   Value *Pg = II.getArgOperand(0);
550   Value *Vec = II.getArgOperand(1);
551   auto IntrinsicID = II.getIntrinsicID();
552   bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta;
553 
554   // lastX(splat(X)) --> X
555   if (auto *SplatVal = getSplatValue(Vec))
556     return IC.replaceInstUsesWith(II, SplatVal);
557 
558   // If x and/or y is a splat value then:
559   // lastX (binop (x, y)) --> binop(lastX(x), lastX(y))
560   Value *LHS, *RHS;
561   if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) {
562     if (isSplatValue(LHS) || isSplatValue(RHS)) {
563       auto *OldBinOp = cast<BinaryOperator>(Vec);
564       auto OpC = OldBinOp->getOpcode();
565       auto *NewLHS =
566           Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS});
567       auto *NewRHS =
568           Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS});
569       auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags(
570           OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), &II);
571       return IC.replaceInstUsesWith(II, NewBinOp);
572     }
573   }
574 
575   auto *C = dyn_cast<Constant>(Pg);
576   if (IsAfter && C && C->isNullValue()) {
577     // The intrinsic is extracting lane 0 so use an extract instead.
578     auto *IdxTy = Type::getInt64Ty(II.getContext());
579     auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
580     Extract->insertBefore(&II);
581     Extract->takeName(&II);
582     return IC.replaceInstUsesWith(II, Extract);
583   }
584 
585   auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
586   if (!IntrPG)
587     return None;
588 
589   if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
590     return None;
591 
592   const auto PTruePattern =
593       cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();
594 
595   // Can the intrinsic's predicate be converted to a known constant index?
596   unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
597   if (!MinNumElts)
598     return None;
599 
600   unsigned Idx = MinNumElts - 1;
601   // Increment the index if extracting the element after the last active
602   // predicate element.
603   if (IsAfter)
604     ++Idx;
605 
606   // Ignore extracts whose index is larger than the known minimum vector
607   // length. NOTE: This is an artificial constraint where we prefer to
608   // maintain what the user asked for until an alternative is proven faster.
609   auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
610   if (Idx >= PgVTy->getMinNumElements())
611     return None;
612 
613   // The intrinsic is extracting a fixed lane so use an extract instead.
614   auto *IdxTy = Type::getInt64Ty(II.getContext());
615   auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
616   Extract->insertBefore(&II);
617   Extract->takeName(&II);
618   return IC.replaceInstUsesWith(II, Extract);
619 }
620 
621 static Optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
622                                                 IntrinsicInst &II) {
623   LLVMContext &Ctx = II.getContext();
624   IRBuilder<> Builder(Ctx);
625   Builder.SetInsertPoint(&II);
626   // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
627   // can work with RDFFR_PP for ptest elimination.
628   auto *AllPat =
629       ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
630   auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
631                                         {II.getType()}, {AllPat});
632   auto *RDFFR =
633       Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue});
634   RDFFR->takeName(&II);
635   return IC.replaceInstUsesWith(II, RDFFR);
636 }
637 
638 static Optional<Instruction *>
639 instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
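  // cntb/cnth/cntw/cntd with the 'all' pattern counts every element, i.e.
  // NumElts * vscale. For an explicit 'vlN' pattern the result is the
  // constant N, provided N does not exceed the known minimum element count.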
640   const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
641 
642   if (Pattern == AArch64SVEPredPattern::all) {
643     LLVMContext &Ctx = II.getContext();
644     IRBuilder<> Builder(Ctx);
645     Builder.SetInsertPoint(&II);
646 
647     Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
648     auto *VScale = Builder.CreateVScale(StepVal);
649     VScale->takeName(&II);
650     return IC.replaceInstUsesWith(II, VScale);
651   }
652 
653   unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
654 
655   return MinNumElts && NumElts >= MinNumElts
656              ? Optional<Instruction *>(IC.replaceInstUsesWith(
657                    II, ConstantInt::get(II.getType(), MinNumElts)))
658              : None;
659 }
660 
661 static Optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
662                                                    IntrinsicInst &II) {
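  // ptest(to_svbool(A), to_svbool(B)) can be performed directly on A and B
  // when both predicates have the same unpacked type, since the extra lanes
  // introduced by the conversion are zero and do not affect the result.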
663   IntrinsicInst *Op1 = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
664   IntrinsicInst *Op2 = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
665 
666   if (Op1 && Op2 &&
667       Op1->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
668       Op2->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
669       Op1->getArgOperand(0)->getType() == Op2->getArgOperand(0)->getType()) {
670 
671     IRBuilder<> Builder(II.getContext());
672     Builder.SetInsertPoint(&II);
673 
674     Value *Ops[] = {Op1->getArgOperand(0), Op2->getArgOperand(0)};
675     Type *Tys[] = {Op1->getArgOperand(0)->getType()};
676 
677     auto *PTest = Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
678 
679     PTest->takeName(&II);
680     return IC.replaceInstUsesWith(II, PTest);
681   }
682 
683   return None;
684 }
685 
686 static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
687                                                        IntrinsicInst &II) {
688   auto *OpPredicate = II.getOperand(0);
689   auto *OpMultiplicand = II.getOperand(1);
690   auto *OpMultiplier = II.getOperand(2);
691 
692   IRBuilder<> Builder(II.getContext());
693   Builder.SetInsertPoint(&II);
694 
695   // Return true if a given instruction is an aarch64_sve_dup_x intrinsic call
696   // with a unit splat value, false otherwise.
697   auto IsUnitDupX = [](auto *I) {
698     auto *IntrI = dyn_cast<IntrinsicInst>(I);
699     if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
700       return false;
701 
702     auto *SplatValue = IntrI->getOperand(0);
703     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
704   };
705 
706   // Return true if a given instruction is an aarch64_sve_dup intrinsic call
707   // with a unit splat value, false otherwise.
708   auto IsUnitDup = [](auto *I) {
709     auto *IntrI = dyn_cast<IntrinsicInst>(I);
710     if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
711       return false;
712 
713     auto *SplatValue = IntrI->getOperand(2);
714     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
715   };
716 
717   // The OpMultiplier variable should always point to the dup (if any), so
718   // swap if necessary.
719   if (IsUnitDup(OpMultiplicand) || IsUnitDupX(OpMultiplicand))
720     std::swap(OpMultiplier, OpMultiplicand);
721 
722   if (IsUnitDupX(OpMultiplier)) {
723     // [f]mul pg (dupx 1) %n => %n
724     OpMultiplicand->takeName(&II);
725     return IC.replaceInstUsesWith(II, OpMultiplicand);
726   } else if (IsUnitDup(OpMultiplier)) {
727     // [f]mul pg (dup pg 1) %n => %n
728     auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
729     auto *DupPg = DupInst->getOperand(1);
730     // TODO: this is naive. The optimization is still valid if DupPg
731     // 'encompasses' OpPredicate, not only if they're the same predicate.
732     if (OpPredicate == DupPg) {
733       OpMultiplicand->takeName(&II);
734       return IC.replaceInstUsesWith(II, OpMultiplicand);
735     }
736   }
737 
738   return None;
739 }
740 
741 static Optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
742                                                     IntrinsicInst &II) {
743   IRBuilder<> Builder(II.getContext());
744   Builder.SetInsertPoint(&II);
745   Value *UnpackArg = II.getArgOperand(0);
746   auto *RetTy = cast<ScalableVectorType>(II.getType());
747   bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi ||
748                   II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo;
749 
750   // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X))
751   // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X))
752   if (auto *ScalarArg = getSplatValue(UnpackArg)) {
753     ScalarArg =
754         Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned);
755     Value *NewVal =
756         Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg);
757     NewVal->takeName(&II);
758     return IC.replaceInstUsesWith(II, NewVal);
759   }
760 
761   return None;
762 }
763 static Optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
764                                                  IntrinsicInst &II) {
765   auto *OpVal = II.getOperand(0);
766   auto *OpIndices = II.getOperand(1);
767   VectorType *VTy = cast<VectorType>(II.getType());
768 
  // Check whether OpIndices is an aarch64_sve_dup_x intrinsic call with a
  // constant splat value smaller than the minimum element count of the result.
771   auto *DupXIntrI = dyn_cast<IntrinsicInst>(OpIndices);
772   if (!DupXIntrI || DupXIntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
773     return None;
774 
775   auto *SplatValue = dyn_cast<ConstantInt>(DupXIntrI->getOperand(0));
776   if (!SplatValue ||
777       SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
778     return None;
779 
780   // Convert sve_tbl(OpVal sve_dup_x(SplatValue)) to
781   // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
782   IRBuilder<> Builder(II.getContext());
783   Builder.SetInsertPoint(&II);
784   auto *Extract = Builder.CreateExtractElement(OpVal, SplatValue);
785   auto *VectorSplat =
786       Builder.CreateVectorSplat(VTy->getElementCount(), Extract);
787 
788   VectorSplat->takeName(&II);
789   return IC.replaceInstUsesWith(II, VectorSplat);
790 }
791 
792 Optional<Instruction *>
793 AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
794                                      IntrinsicInst &II) const {
795   Intrinsic::ID IID = II.getIntrinsicID();
796   switch (IID) {
797   default:
798     break;
799   case Intrinsic::aarch64_sve_convert_from_svbool:
800     return instCombineConvertFromSVBool(IC, II);
801   case Intrinsic::aarch64_sve_dup:
802     return instCombineSVEDup(IC, II);
803   case Intrinsic::aarch64_sve_cmpne:
804   case Intrinsic::aarch64_sve_cmpne_wide:
805     return instCombineSVECmpNE(IC, II);
806   case Intrinsic::aarch64_sve_rdffr:
807     return instCombineRDFFR(IC, II);
808   case Intrinsic::aarch64_sve_lasta:
809   case Intrinsic::aarch64_sve_lastb:
810     return instCombineSVELast(IC, II);
811   case Intrinsic::aarch64_sve_cntd:
812     return instCombineSVECntElts(IC, II, 2);
813   case Intrinsic::aarch64_sve_cntw:
814     return instCombineSVECntElts(IC, II, 4);
815   case Intrinsic::aarch64_sve_cnth:
816     return instCombineSVECntElts(IC, II, 8);
817   case Intrinsic::aarch64_sve_cntb:
818     return instCombineSVECntElts(IC, II, 16);
819   case Intrinsic::aarch64_sve_ptest_any:
820   case Intrinsic::aarch64_sve_ptest_first:
821   case Intrinsic::aarch64_sve_ptest_last:
822     return instCombineSVEPTest(IC, II);
823   case Intrinsic::aarch64_sve_mul:
824   case Intrinsic::aarch64_sve_fmul:
825     return instCombineSVEVectorMul(IC, II);
826   case Intrinsic::aarch64_sve_tbl:
827     return instCombineSVETBL(IC, II);
828   case Intrinsic::aarch64_sve_uunpkhi:
829   case Intrinsic::aarch64_sve_uunpklo:
830   case Intrinsic::aarch64_sve_sunpkhi:
831   case Intrinsic::aarch64_sve_sunpklo:
832     return instCombineSVEUnpack(IC, II);
833   }
834 
835   return None;
836 }
837 
838 bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
839                                            ArrayRef<const Value *> Args) {
840 
  // A helper that returns a vector type from the given type. The element
  // count of DstTy determines the vector width.
843   auto toVectorTy = [&](Type *ArgTy) {
844     return VectorType::get(ArgTy->getScalarType(),
845                            cast<VectorType>(DstTy)->getElementCount());
846   };
847 
848   // Exit early if DstTy is not a vector type whose elements are at least
849   // 16-bits wide.
850   if (!DstTy->isVectorTy() || DstTy->getScalarSizeInBits() < 16)
851     return false;
852 
853   // Determine if the operation has a widening variant. We consider both the
854   // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
855   // instructions.
856   //
857   // TODO: Add additional widening operations (e.g., mul, shl, etc.) once we
858   //       verify that their extending operands are eliminated during code
859   //       generation.
860   switch (Opcode) {
861   case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
862   case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
863     break;
864   default:
865     return false;
866   }
867 
868   // To be a widening instruction (either the "wide" or "long" versions), the
869   // second operand must be a sign- or zero extend having a single user. We
870   // only consider extends having a single user because they may otherwise not
871   // be eliminated.
872   if (Args.size() != 2 ||
873       (!isa<SExtInst>(Args[1]) && !isa<ZExtInst>(Args[1])) ||
874       !Args[1]->hasOneUse())
875     return false;
876   auto *Extend = cast<CastInst>(Args[1]);
877 
878   // Legalize the destination type and ensure it can be used in a widening
879   // operation.
880   auto DstTyL = TLI->getTypeLegalizationCost(DL, DstTy);
881   unsigned DstElTySize = DstTyL.second.getScalarSizeInBits();
882   if (!DstTyL.second.isVector() || DstElTySize != DstTy->getScalarSizeInBits())
883     return false;
884 
885   // Legalize the source type and ensure it can be used in a widening
886   // operation.
887   auto *SrcTy = toVectorTy(Extend->getSrcTy());
888   auto SrcTyL = TLI->getTypeLegalizationCost(DL, SrcTy);
889   unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
890   if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
891     return false;
892 
893   // Get the total number of vector elements in the legalized types.
894   InstructionCost NumDstEls =
895       DstTyL.first * DstTyL.second.getVectorMinNumElements();
896   InstructionCost NumSrcEls =
897       SrcTyL.first * SrcTyL.second.getVectorMinNumElements();
898 
899   // Return true if the legalized types have the same number of vector elements
900   // and the destination element type size is twice that of the source type.
901   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize;
902 }
903 
904 InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
905                                                  Type *Src,
906                                                  TTI::CastContextHint CCH,
907                                                  TTI::TargetCostKind CostKind,
908                                                  const Instruction *I) {
909   int ISD = TLI->InstructionOpcodeToISD(Opcode);
910   assert(ISD && "Invalid opcode");
911 
  // If the cast is observable and is used by a widening instruction (e.g.,
  // uaddl, saddw), it may be free.
914   if (I && I->hasOneUse()) {
915     auto *SingleUser = cast<Instruction>(*I->user_begin());
916     SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
917     if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands)) {
918       // If the cast is the second operand, it is free. We will generate either
919       // a "wide" or "long" version of the widening instruction.
920       if (I == SingleUser->getOperand(1))
921         return 0;
922       // If the cast is not the second operand, it will be free if it looks the
923       // same as the second operand. In this case, we will generate a "long"
924       // version of the widening instruction.
925       if (auto *Cast = dyn_cast<CastInst>(SingleUser->getOperand(1)))
926         if (I->getOpcode() == unsigned(Cast->getOpcode()) &&
927             cast<CastInst>(I)->getSrcTy() == Cast->getSrcTy())
928           return 0;
929     }
930   }
931 
932   // TODO: Allow non-throughput costs that aren't binary.
933   auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
934     if (CostKind != TTI::TCK_RecipThroughput)
935       return Cost == 0 ? 0 : 1;
936     return Cost;
937   };
938 
939   EVT SrcTy = TLI->getValueType(DL, Src);
940   EVT DstTy = TLI->getValueType(DL, Dst);
941 
942   if (!SrcTy.isSimple() || !DstTy.isSimple())
943     return AdjustCost(
944         BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
945 
946   static const TypeConversionCostTblEntry
947   ConversionTbl[] = {
948     { ISD::TRUNCATE, MVT::v4i16, MVT::v4i32,  1 },
949     { ISD::TRUNCATE, MVT::v4i32, MVT::v4i64,  0 },
950     { ISD::TRUNCATE, MVT::v8i8,  MVT::v8i32,  3 },
951     { ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 6 },
952 
953     // Truncations on nxvmiN
954     { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 1 },
955     { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 1 },
956     { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 1 },
957     { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 1 },
958     { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 1 },
959     { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 2 },
960     { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 1 },
961     { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 3 },
962     { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 5 },
963     { ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 1 },
964     { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 1 },
965     { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 1 },
966     { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 1 },
967     { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 2 },
968     { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 3 },
969     { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 6 },
970 
971     // The number of shll instructions for the extension.
972     { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
973     { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
974     { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
975     { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
976     { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
977     { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
978     { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
979     { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
980     { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
981     { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
982     { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
983     { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
984     { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
985     { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
986     { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
987     { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
988 
989     // LowerVectorINT_TO_FP:
990     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
991     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
992     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
993     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
994     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
995     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
996 
997     // Complex: to v2f32
998     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
999     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
1000     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
1001     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
1002     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
1003     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
1004 
1005     // Complex: to v4f32
1006     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8,  4 },
1007     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
1008     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8,  3 },
1009     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
1010 
1011     // Complex: to v8f32
1012     { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
1013     { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
1014     { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
1015     { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
1016 
1017     // Complex: to v16f32
1018     { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
1019     { ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
1020 
1021     // Complex: to v2f64
1022     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
1023     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
1024     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
1025     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
1026     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
1027     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
1028 
1029 
1030     // LowerVectorFP_TO_INT
1031     { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1 },
1032     { ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1 },
1033     { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1 },
1034     { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1 },
1035     { ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1 },
1036     { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1 },
1037 
1038     // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext).
1039     { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2 },
1040     { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1 },
1041     { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f32, 1 },
1042     { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2 },
1043     { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1 },
1044     { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f32, 1 },
1045 
1046     // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2
1047     { ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2 },
1048     { ISD::FP_TO_SINT, MVT::v4i8,  MVT::v4f32, 2 },
1049     { ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2 },
1050     { ISD::FP_TO_UINT, MVT::v4i8,  MVT::v4f32, 2 },
1051 
1052     // Complex, from nxv2f32.
1053     { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
1054     { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
1055     { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
1056     { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
1057     { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
1058     { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
1059     { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
1060     { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
1061 
1062     // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
1063     { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2 },
1064     { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2 },
1065     { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f64, 2 },
1066     { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2 },
1067     { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2 },
1068     { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f64, 2 },
1069 
1070     // Complex, from nxv2f64.
1071     { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
1072     { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
1073     { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
1074     { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
1075     { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
1076     { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
1077     { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
1078     { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
1079 
1080     // Complex, from nxv4f32.
1081     { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
1082     { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
1083     { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
1084     { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
1085     { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
1086     { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
1087     { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
1088     { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
1089 
1090     // Complex, from nxv8f64. Illegal -> illegal conversions not required.
1091     { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
1092     { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
1093     { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
1094     { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
1095 
1096     // Complex, from nxv4f64. Illegal -> illegal conversions not required.
1097     { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
1098     { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
1099     { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
1100     { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
1101     { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
1102     { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
1103 
1104     // Complex, from nxv8f32. Illegal -> illegal conversions not required.
1105     { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
1106     { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
1107     { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
1108     { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
1109 
1110     // Complex, from nxv8f16.
1111     { ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
1112     { ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
1113     { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
1114     { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
1115     { ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
1116     { ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
1117     { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
1118     { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
1119 
1120     // Complex, from nxv4f16.
1121     { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
1122     { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
1123     { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
1124     { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
1125     { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
1126     { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
1127     { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
1128     { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
1129 
1130     // Complex, from nxv2f16.
1131     { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
1132     { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
1133     { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
1134     { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
1135     { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
1136     { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
1137     { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
1138     { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
1139 
1140     // Truncate from nxvmf32 to nxvmf16.
1141     { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1 },
1142     { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1 },
1143     { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3 },
1144 
1145     // Truncate from nxvmf64 to nxvmf16.
1146     { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1 },
1147     { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3 },
1148     { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7 },
1149 
1150     // Truncate from nxvmf64 to nxvmf32.
1151     { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 },
1152     { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3 },
1153     { ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6 },
1154 
1155     // Extend from nxvmf16 to nxvmf32.
1156     { ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
1157     { ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
1158     { ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},
1159 
1160     // Extend from nxvmf16 to nxvmf64.
1161     { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
1162     { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
1163     { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},
1164 
1165     // Extend from nxvmf32 to nxvmf64.
1166     { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
1167     { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
1168     { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},
1169 
1170   };
1171 
1172   if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
1173                                                  DstTy.getSimpleVT(),
1174                                                  SrcTy.getSimpleVT()))
1175     return AdjustCost(Entry->Cost);
1176 
1177   return AdjustCost(
1178       BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
1179 }
1180 
1181 InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode,
1182                                                          Type *Dst,
1183                                                          VectorType *VecTy,
1184                                                          unsigned Index) {
1185 
1186   // Make sure we were given a valid extend opcode.
1187   assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
1188          "Invalid opcode");
1189 
1190   // We are extending an element we extract from a vector, so the source type
1191   // of the extend is the element type of the vector.
1192   auto *Src = VecTy->getElementType();
1193 
1194   // Sign- and zero-extends are for integer types only.
1195   assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
1196 
1197   // Get the cost for the extract. We compute the cost (if any) for the extend
1198   // below.
1199   InstructionCost Cost =
1200       getVectorInstrCost(Instruction::ExtractElement, VecTy, Index);
1201 
1202   // Legalize the types.
1203   auto VecLT = TLI->getTypeLegalizationCost(DL, VecTy);
1204   auto DstVT = TLI->getValueType(DL, Dst);
1205   auto SrcVT = TLI->getValueType(DL, Src);
1206   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1207 
1208   // If the resulting type is still a vector and the destination type is legal,
1209   // we may get the extension for free. If not, get the default cost for the
1210   // extend.
1211   if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
1212     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1213                                    CostKind);
1214 
1215   // The destination type should be larger than the element type. If not, get
1216   // the default cost for the extend.
1217   if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits())
1218     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1219                                    CostKind);
1220 
1221   switch (Opcode) {
1222   default:
1223     llvm_unreachable("Opcode should be either SExt or ZExt");
1224 
1225   // For sign-extends, we only need a smov, which performs the extension
1226   // automatically.
1227   case Instruction::SExt:
1228     return Cost;
1229 
1230   // For zero-extends, the extend is performed automatically by a umov unless
1231   // the destination type is i64 and the element type is i8 or i16.
1232   case Instruction::ZExt:
1233     if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
1234       return Cost;
1235   }
1236 
1237   // If we are unable to perform the extend for free, get the default cost.
1238   return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1239                                  CostKind);
1240 }
1241 
1242 InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
1243                                                TTI::TargetCostKind CostKind,
1244                                                const Instruction *I) {
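  // For size- and latency-oriented cost kinds, model a PHI as free and any
  // other control-flow instruction as costing a single instruction.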
1245   if (CostKind != TTI::TCK_RecipThroughput)
1246     return Opcode == Instruction::PHI ? 0 : 1;
1247   assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
1248   // Branches are assumed to be predicted.
1249   return 0;
1250 }
1251 
1252 InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
1253                                                    unsigned Index) {
1254   assert(Val->isVectorTy() && "This must be a vector type");
1255 
1256   if (Index != -1U) {
1257     // Legalize the type.
1258     std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Val);
1259 
1260     // This type is legalized to a scalar type.
1261     if (!LT.second.isVector())
1262       return 0;
1263 
1264     // The type may be split. Normalize the index to the new type.
1265     unsigned Width = LT.second.getVectorNumElements();
1266     Index = Index % Width;
1267 
1268     // The element at index zero is already inside the vector.
1269     if (Index == 0)
1270       return 0;
1271   }
1272 
1273   // All other insert/extracts cost this much.
1274   return ST->getVectorInsertExtractBaseCost();
1275 }
1276 
1277 InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
1278     unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
1279     TTI::OperandValueKind Opd1Info, TTI::OperandValueKind Opd2Info,
1280     TTI::OperandValueProperties Opd1PropInfo,
1281     TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
1282     const Instruction *CxtI) {
1283   // TODO: Handle more cost kinds.
1284   if (CostKind != TTI::TCK_RecipThroughput)
1285     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
1286                                          Opd2Info, Opd1PropInfo,
1287                                          Opd2PropInfo, Args, CxtI);
1288 
1289   // Legalize the type.
1290   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
1291 
1292   // If the instruction is a widening instruction (e.g., uaddl, saddw, etc.),
1293   // add in the widening overhead specified by the sub-target. Since the
1294   // extends feeding widening instructions are performed automatically, they
1295   // aren't present in the generated code and have a zero cost. By adding a
1296   // widening overhead here, we attach the total cost of the combined operation
1297   // to the widening instruction.
1298   InstructionCost Cost = 0;
1299   if (isWideningInstruction(Ty, Opcode, Args))
1300     Cost += ST->getWideningBaseCost();
1301 
1302   int ISD = TLI->InstructionOpcodeToISD(Opcode);
1303 
1304   switch (ISD) {
1305   default:
1306     return Cost + BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
1307                                                 Opd2Info,
1308                                                 Opd1PropInfo, Opd2PropInfo);
1309   case ISD::SDIV:
1310     if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue &&
1311         Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) {
1312       // On AArch64, scalar signed division by a power-of-two constant is
1313       // normally expanded to the sequence ADD + CMP + SELECT + SRA.
1314       // The OperandValue properties may not be the same as those of the
1315       // previous operation; conservatively assume OP_None.
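      // For example, 'sdiv i32 %x, 4' is selected roughly as an add of 3, a
      // compare/select on the sign of %x, and an arithmetic shift right by 2,
      // which is why the costs of the four component operations are summed
      // below.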
1316       Cost += getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
1317                                      Opd1Info, Opd2Info,
1318                                      TargetTransformInfo::OP_None,
1319                                      TargetTransformInfo::OP_None);
1320       Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
1321                                      Opd1Info, Opd2Info,
1322                                      TargetTransformInfo::OP_None,
1323                                      TargetTransformInfo::OP_None);
1324       Cost += getArithmeticInstrCost(Instruction::Select, Ty, CostKind,
1325                                      Opd1Info, Opd2Info,
1326                                      TargetTransformInfo::OP_None,
1327                                      TargetTransformInfo::OP_None);
1328       Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
1329                                      Opd1Info, Opd2Info,
1330                                      TargetTransformInfo::OP_None,
1331                                      TargetTransformInfo::OP_None);
1332       return Cost;
1333     }
1334     LLVM_FALLTHROUGH;
1335   case ISD::UDIV:
1336     if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue) {
1337       auto VT = TLI->getValueType(DL, Ty);
1338       if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
1339         // Vector signed division by a constant is expanded to the
1340         // sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned
1341         // division to MULHU + SUB + SRL + ADD + SRL.
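        // This is the usual "magic number" expansion: a high multiply by a
        // precomputed reciprocal in the same vector type, followed by add/sub
        // and shift fix-ups, rather than scalarizing the division.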
1342         InstructionCost MulCost = getArithmeticInstrCost(
1343             Instruction::Mul, Ty, CostKind, Opd1Info, Opd2Info,
1344             TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
1345         InstructionCost AddCost = getArithmeticInstrCost(
1346             Instruction::Add, Ty, CostKind, Opd1Info, Opd2Info,
1347             TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
1348         InstructionCost ShrCost = getArithmeticInstrCost(
1349             Instruction::AShr, Ty, CostKind, Opd1Info, Opd2Info,
1350             TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
1351         return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
1352       }
1353     }
1354 
1355     Cost += BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
1356                                           Opd2Info,
1357                                           Opd1PropInfo, Opd2PropInfo);
1358     if (Ty->isVectorTy()) {
1359       // On AArch64, vector divisions are not supported natively and are
1360       // expanded into scalar divisions of each pair of elements.
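      // For example, a udiv of <4 x i32> with a non-constant divisor has no
      // NEON instruction, so it is charged as the scalarized base cost plus
      // the element extract/insert traffic added below (doubled, since both
      // operands must be unpacked and the results repacked).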
1361       Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty, CostKind,
1362                                      Opd1Info, Opd2Info, Opd1PropInfo,
1363                                      Opd2PropInfo);
1364       Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind,
1365                                      Opd1Info, Opd2Info, Opd1PropInfo,
1366                                      Opd2PropInfo);
1367       // TODO: if one of the arguments is scalar, then it's not necessary to
1368       // double the cost of handling the vector elements.
1369       Cost += Cost;
1370     }
1371     return Cost;
1372 
1373   case ISD::MUL:
1374     if (LT.second != MVT::v2i64)
1375       return (Cost + 1) * LT.first;
1376     // Since we do not have a MUL.2d instruction, a mul <2 x i64> is expensive
1377     // as elements are extracted from the vectors and the muls scalarized.
1378     // As getScalarizationOverhead is a bit too pessimistic, we estimate the
1379     // cost for an i64 vector directly here, which is:
1380     // - four i64 extracts,
1381     // - two i64 inserts, and
1382     // - two muls.
1383     // So, for a v2i64 with LT.first = 1 the cost is 8, and for a v4i64 with
1384     // LT.first = 2 the cost is 16.
1385     return LT.first * 8;
1386   case ISD::ADD:
1387   case ISD::XOR:
1388   case ISD::OR:
1389   case ISD::AND:
1390     // These nodes are marked as 'custom' for combining purposes only.
1391     // We know that they are legal. See LowerAdd in ISelLowering.
1392     return (Cost + 1) * LT.first;
1393 
1394   case ISD::FADD:
1395     // These nodes are marked as 'custom' just to lower them to SVE.
1396     // We know said lowering will incur no additional cost.
1397     if (isa<FixedVectorType>(Ty) && !Ty->getScalarType()->isFP128Ty())
1398       return (Cost + 2) * LT.first;
1399 
1400     return Cost + BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
1401                                                 Opd2Info,
1402                                                 Opd1PropInfo, Opd2PropInfo);
1403   }
1404 }
1405 
1406 InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
1407                                                           ScalarEvolution *SE,
1408                                                           const SCEV *Ptr) {
1409   // Address computations in vectorized code with non-consecutive addresses will
1410   // likely result in more instructions compared to scalar code where the
1411   // computation can more often be merged into the index mode. The resulting
1412   // extra micro-ops can significantly decrease throughput.
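  // For example, a vectorized loop whose loads stride by an unknown or large
  // amount cannot fold the address arithmetic into the load's addressing mode,
  // so the address computation is charged NumVectorInstToHideOverhead to
  // discourage vectorization unless enough other work amortizes it.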
1413   unsigned NumVectorInstToHideOverhead = 10;
1414   int MaxMergeDistance = 64;
1415 
1416   if (Ty->isVectorTy() && SE &&
1417       !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
1418     return NumVectorInstToHideOverhead;
1419 
1420   // In many cases the address computation is not merged into the instruction
1421   // addressing mode.
1422   return 1;
1423 }
1424 
1425 InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
1426                                                    Type *CondTy,
1427                                                    CmpInst::Predicate VecPred,
1428                                                    TTI::TargetCostKind CostKind,
1429                                                    const Instruction *I) {
1430   // TODO: Handle other cost kinds.
1431   if (CostKind != TTI::TCK_RecipThroughput)
1432     return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
1433                                      I);
1434 
1435   int ISD = TLI->InstructionOpcodeToISD(Opcode);
1436   // Some vector selects that are wider than the register width are not
1437   // lowered well.
1438   if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
1439     // We would need this many instructions to hide the scalarization happening.
1440     const int AmortizationCost = 20;
1441 
1442     // If VecPred is not set, check if we can get a predicate from the context
1443     // instruction, if its type matches the requested ValTy.
1444     if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
1445       CmpInst::Predicate CurrentPred;
1446       if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
1447                             m_Value())))
1448         VecPred = CurrentPred;
1449     }
1450     // Check if we have a compare/select chain that can be lowered using CMxx &
1451     // BFI pair.
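    // For example, on a legal type such as v4i32 an integer compare feeding a
    // select lowers to a CMxx producing a mask plus a BSL/BIF, so the select
    // is charged only the number of legalized vector operations (LT.first).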
1452     if (CmpInst::isIntPredicate(VecPred)) {
1453       static const auto ValidMinMaxTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
1454                                           MVT::v8i16, MVT::v2i32, MVT::v4i32,
1455                                           MVT::v2i64};
1456       auto LT = TLI->getTypeLegalizationCost(DL, ValTy);
1457       if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
1458         return LT.first;
1459     }
1460 
1461     static const TypeConversionCostTblEntry
1462     VectorSelectTbl[] = {
1463       { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
1464       { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
1465       { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
1466       { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
1467       { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
1468       { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
1469     };
1470 
1471     EVT SelCondTy = TLI->getValueType(DL, CondTy);
1472     EVT SelValTy = TLI->getValueType(DL, ValTy);
1473     if (SelCondTy.isSimple() && SelValTy.isSimple()) {
1474       if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
1475                                                      SelCondTy.getSimpleVT(),
1476                                                      SelValTy.getSimpleVT()))
1477         return Entry->Cost;
1478     }
1479   }
1480   // The base case handles scalable vectors fine for now, since it treats the
1481   // cost as 1 * legalization cost.
1482   return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
1483 }
1484 
1485 AArch64TTIImpl::TTI::MemCmpExpansionOptions
1486 AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
1487   TTI::MemCmpExpansionOptions Options;
1488   if (ST->requiresStrictAlign()) {
1489     // TODO: Add cost modeling for strict align. Misaligned loads expand to
1490     // a bunch of instructions when strict align is enabled.
1491     return Options;
1492   }
1493   Options.AllowOverlappingLoads = true;
1494   Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
1495   Options.NumLoadsPerBlock = Options.MaxNumLoads;
1496   // TODO: Though vector loads usually perform well on AArch64, on some
1497   // targets they may wake up the FP unit, which raises power consumption.
1498   // Perhaps they could be used with no holds barred (-O3).
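  // The load sizes below allow 8-, 4-, 2- and 1-byte accesses; combined with
  // AllowOverlappingLoads, a 7-byte compare can, for example, be expanded as
  // two 4-byte loads that overlap by one byte instead of three loads.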
1499   Options.LoadSizes = {8, 4, 2, 1};
1500   return Options;
1501 }
1502 
1503 InstructionCost
1504 AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
1505                                       Align Alignment, unsigned AddressSpace,
1506                                       TTI::TargetCostKind CostKind) {
1507   if (useNeonVector(Src))
1508     return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
1509                                         CostKind);
1510   auto LT = TLI->getTypeLegalizationCost(DL, Src);
1511   if (!LT.first.isValid())
1512     return InstructionCost::getInvalid();
1513 
1514   // The code-generator is currently not able to handle scalable vectors
1515   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
1516   // it. This change will be removed when code-generation for these types is
1517   // sufficiently reliable.
1518   if (cast<VectorType>(Src)->getElementCount() == ElementCount::getScalable(1))
1519     return InstructionCost::getInvalid();
1520 
1521   return LT.first * 2;
1522 }
1523 
1524 InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
1525     unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
1526     Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
1527   if (useNeonVector(DataTy))
1528     return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
1529                                          Alignment, CostKind, I);
1530   auto *VT = cast<VectorType>(DataTy);
1531   auto LT = TLI->getTypeLegalizationCost(DL, DataTy);
1532   if (!LT.first.isValid())
1533     return InstructionCost::getInvalid();
1534 
1535   // The code-generator is currently not able to handle scalable vectors
1536   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
1537   // it. This change will be removed when code-generation for these types is
1538   // sufficiently reliable.
1539   if (cast<VectorType>(DataTy)->getElementCount() ==
1540       ElementCount::getScalable(1))
1541     return InstructionCost::getInvalid();
1542 
1543   ElementCount LegalVF = LT.second.getVectorElementCount();
1544   InstructionCost MemOpCost =
1545       getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind, I);
1546   return LT.first * MemOpCost * getMaxNumElements(LegalVF, I->getFunction());
1547 }
1548 
1549 bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
1550   return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
1551 }
1552 
1553 InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
1554                                                 MaybeAlign Alignment,
1555                                                 unsigned AddressSpace,
1556                                                 TTI::TargetCostKind CostKind,
1557                                                 const Instruction *I) {
1558   EVT VT = TLI->getValueType(DL, Ty, true);
1559   // Type legalization can't handle structs
1560   if (VT == MVT::Other)
1561     return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
1562                                   CostKind);
1563 
1564   auto LT = TLI->getTypeLegalizationCost(DL, Ty);
1565   if (!LT.first.isValid())
1566     return InstructionCost::getInvalid();
1567 
1568   // The code-generator is currently not able to handle scalable vectors
1569   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
1570   // it. This change will be removed when code-generation for these types is
1571   // sufficiently reliable.
1572   if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
1573     if (VTy->getElementCount() == ElementCount::getScalable(1))
1574       return InstructionCost::getInvalid();
1575 
1576   // TODO: consider latency as well for TCK_SizeAndLatency.
1577   if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
1578     return LT.first;
1579 
1580   if (CostKind != TTI::TCK_RecipThroughput)
1581     return 1;
1582 
1583   if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
1584       LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) {
1585     // Unaligned stores are extremely inefficient. We don't split all
1586     // unaligned 128-bit stores because of the negative impact that has been
1587     // observed in practice on inlined block copy code.
1588     // We make such stores expensive so that we will only vectorize if there
1589     // are 6 other instructions getting vectorized.
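    // For example, a misaligned store of <4 x i32> (LT.first == 1) is costed
    // as 1 * 2 * 6 == 12, rather than the usual LT.first of 1.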
1590     const int AmortizationCost = 6;
1591 
1592     return LT.first * 2 * AmortizationCost;
1593   }
1594 
1595   // Check truncating stores and extending loads.
1596   if (useNeonVector(Ty) &&
1597       Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
1598     // v4i8 types are lowered to a scalar load/store and sshll/xtn.
1599     if (VT == MVT::v4i8)
1600       return 2;
1601     // Otherwise we need to scalarize.
1602     return cast<FixedVectorType>(Ty)->getNumElements() * 2;
1603   }
1604 
1605   return LT.first;
1606 }
1607 
1608 InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
1609     unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
1610     Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
1611     bool UseMaskForCond, bool UseMaskForGaps) {
1612   assert(Factor >= 2 && "Invalid interleave factor");
1613   auto *VecVTy = cast<FixedVectorType>(VecTy);
1614 
1615   if (!UseMaskForCond && !UseMaskForGaps &&
1616       Factor <= TLI->getMaxSupportedInterleaveFactor()) {
1617     unsigned NumElts = VecVTy->getNumElements();
1618     auto *SubVecTy =
1619         FixedVectorType::get(VecTy->getScalarType(), NumElts / Factor);
1620 
1621     // ldN/stN only support legal vector types of size 64 or 128 in bits.
1622     // Accesses having vector types that are a multiple of 128 bits can be
1623     // matched to more than one ldN/stN instruction.
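    // For example, an interleaved group of <16 x i32> with Factor == 2 uses a
    // <8 x i32> subvector type spanning two 128-bit registers, so the cost is
    // 2 (Factor) * 2 (interleaved accesses) == 4 ld2/st2 instructions.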
1624     if (NumElts % Factor == 0 &&
1625         TLI->isLegalInterleavedAccessType(SubVecTy, DL))
1626       return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL);
1627   }
1628 
1629   return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
1630                                            Alignment, AddressSpace, CostKind,
1631                                            UseMaskForCond, UseMaskForGaps);
1632 }
1633 
1634 InstructionCost
1635 AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
1636   InstructionCost Cost = 0;
1637   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1638   for (auto *I : Tys) {
1639     if (!I->isVectorTy())
1640       continue;
1641     if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() ==
1642         128)
1643       Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) +
1644               getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind);
1645   }
1646   return Cost;
1647 }
1648 
1649 unsigned AArch64TTIImpl::getMaxInterleaveFactor(unsigned VF) {
1650   return ST->getMaxInterleaveFactor();
1651 }
1652 
1653 // For Falkor, we want to avoid having too many strided loads in a loop since
1654 // that can exhaust the HW prefetcher resources.  We adjust the unroller
1655 // MaxCount preference below to attempt to ensure unrolling doesn't create too
1656 // many strided loads.
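// For example, with MaxStridedLoads == 7, a loop body containing 3 strided
// loads gets UP.MaxCount = 1 << Log2_32(7 / 3) == 2, so unrolling cannot
// produce more than 6 strided loads.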
1657 static void
1658 getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
1659                               TargetTransformInfo::UnrollingPreferences &UP) {
1660   enum { MaxStridedLoads = 7 };
1661   auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) {
1662     int StridedLoads = 0;
1663     // FIXME? We could make this more precise by looking at the CFG and
1664     // e.g. not counting loads in each side of an if-then-else diamond.
1665     for (const auto BB : L->blocks()) {
1666       for (auto &I : *BB) {
1667         LoadInst *LMemI = dyn_cast<LoadInst>(&I);
1668         if (!LMemI)
1669           continue;
1670 
1671         Value *PtrValue = LMemI->getPointerOperand();
1672         if (L->isLoopInvariant(PtrValue))
1673           continue;
1674 
1675         const SCEV *LSCEV = SE.getSCEV(PtrValue);
1676         const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
1677         if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
1678           continue;
1679 
1680         // FIXME? We could take pairing of unrolled load copies into account
1681         // by looking at the AddRec, but we would probably have to limit this
1682         // to loops with no stores or other memory optimization barriers.
1683         ++StridedLoads;
1684         // We've seen enough strided loads that seeing more won't make a
1685         // difference.
1686         if (StridedLoads > MaxStridedLoads / 2)
1687           return StridedLoads;
1688       }
1689     }
1690     return StridedLoads;
1691   };
1692 
1693   int StridedLoads = countStridedLoads(L, SE);
1694   LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads
1695                     << " strided loads\n");
1696   // Pick the largest power of 2 unroll count that won't result in too many
1697   // strided loads.
1698   if (StridedLoads) {
1699     UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads);
1700     LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to "
1701                       << UP.MaxCount << '\n');
1702   }
1703 }
1704 
1705 void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
1706                                              TTI::UnrollingPreferences &UP,
1707                                              OptimizationRemarkEmitter *ORE) {
1708   // Enable partial unrolling and runtime unrolling.
1709   BaseT::getUnrollingPreferences(L, SE, UP, ORE);
1710 
1711   UP.UpperBound = true;
1712 
1713   // Inner loops are more likely to be hot, and their runtime checks can be
1714   // hoisted out by the LICM pass, so the overhead is lower; try a larger
1715   // threshold to unroll more loops.
1716   if (L->getLoopDepth() > 1)
1717     UP.PartialThreshold *= 2;
1718 
1719   // Disable partial & runtime unrolling on -Os.
1720   UP.PartialOptSizeThreshold = 0;
1721 
1722   if (ST->getProcFamily() == AArch64Subtarget::Falkor &&
1723       EnableFalkorHWPFUnrollFix)
1724     getFalkorUnrollingPreferences(L, SE, UP);
1725 
1726   // Scan the loop: don't unroll loops with calls as this could prevent
1727   // inlining. Don't unroll vector loops either, as they don't benefit much from
1728   // unrolling.
1729   for (auto *BB : L->getBlocks()) {
1730     for (auto &I : *BB) {
1731       // Don't unroll vectorized loops.
1732       if (I.getType()->isVectorTy())
1733         return;
1734 
1735       if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
1736         if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
1737           if (!isLoweredToCall(F))
1738             continue;
1739         }
1740         return;
1741       }
1742     }
1743   }
1744 
1745   // Enable runtime unrolling for in-order models.
1746   // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so
1747   // by checking for that case we can ensure that the default behaviour is
1748   // unchanged.
1749   if (ST->getProcFamily() != AArch64Subtarget::Others &&
1750       !ST->getSchedModel().isOutOfOrder()) {
1751     UP.Runtime = true;
1752     UP.Partial = true;
1753     UP.UnrollRemainder = true;
1754     UP.DefaultUnrollRuntimeCount = 4;
1755 
1756     UP.UnrollAndJam = true;
1757     UP.UnrollAndJamInnerLoopThreshold = 60;
1758   }
1759 }
1760 
1761 void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
1762                                            TTI::PeelingPreferences &PP) {
1763   BaseT::getPeelingPreferences(L, SE, PP);
1764 }
1765 
1766 Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
1767                                                          Type *ExpectedType) {
1768   switch (Inst->getIntrinsicID()) {
1769   default:
1770     return nullptr;
1771   case Intrinsic::aarch64_neon_st2:
1772   case Intrinsic::aarch64_neon_st3:
1773   case Intrinsic::aarch64_neon_st4: {
1774     // Check that the expected type is a struct matching the stored values.
1775     StructType *ST = dyn_cast<StructType>(ExpectedType);
1776     if (!ST)
1777       return nullptr;
1778     unsigned NumElts = Inst->getNumArgOperands() - 1;
1779     if (ST->getNumElements() != NumElts)
1780       return nullptr;
1781     for (unsigned i = 0, e = NumElts; i != e; ++i) {
1782       if (Inst->getArgOperand(i)->getType() != ST->getElementType(i))
1783         return nullptr;
1784     }
1785     Value *Res = UndefValue::get(ExpectedType);
1786     IRBuilder<> Builder(Inst);
1787     for (unsigned i = 0, e = NumElts; i != e; ++i) {
1788       Value *L = Inst->getArgOperand(i);
1789       Res = Builder.CreateInsertValue(Res, L, i);
1790     }
1791     return Res;
1792   }
1793   case Intrinsic::aarch64_neon_ld2:
1794   case Intrinsic::aarch64_neon_ld3:
1795   case Intrinsic::aarch64_neon_ld4:
1796     if (Inst->getType() == ExpectedType)
1797       return Inst;
1798     return nullptr;
1799   }
1800 }
1801 
1802 bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
1803                                         MemIntrinsicInfo &Info) {
1804   switch (Inst->getIntrinsicID()) {
1805   default:
1806     break;
1807   case Intrinsic::aarch64_neon_ld2:
1808   case Intrinsic::aarch64_neon_ld3:
1809   case Intrinsic::aarch64_neon_ld4:
1810     Info.ReadMem = true;
1811     Info.WriteMem = false;
1812     Info.PtrVal = Inst->getArgOperand(0);
1813     break;
1814   case Intrinsic::aarch64_neon_st2:
1815   case Intrinsic::aarch64_neon_st3:
1816   case Intrinsic::aarch64_neon_st4:
1817     Info.ReadMem = false;
1818     Info.WriteMem = true;
1819     Info.PtrVal = Inst->getArgOperand(Inst->getNumArgOperands() - 1);
1820     break;
1821   }
1822 
1823   switch (Inst->getIntrinsicID()) {
1824   default:
1825     return false;
1826   case Intrinsic::aarch64_neon_ld2:
1827   case Intrinsic::aarch64_neon_st2:
1828     Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
1829     break;
1830   case Intrinsic::aarch64_neon_ld3:
1831   case Intrinsic::aarch64_neon_st3:
1832     Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
1833     break;
1834   case Intrinsic::aarch64_neon_ld4:
1835   case Intrinsic::aarch64_neon_st4:
1836     Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
1837     break;
1838   }
1839   return true;
1840 }
1841 
1842 /// See if \p I should be considered for address type promotion. We check if
1843 /// \p I is a sext with the right type that is used in memory accesses. If it
1844 /// is used in a "complex" getelementptr, we allow it to be promoted without
1845 /// finding other sext instructions that sign extended the same initial value.
1846 /// A getelementptr is considered "complex" if it has more than 2 operands.
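/// For example, a '%e = sext i32 %i to i64' that is used by
/// 'getelementptr inbounds [64 x i32], [64 x i32]* %a, i64 0, i64 %e' is
/// considered, and because that GEP has more than 2 operands the promotion is
/// allowed without a common header.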
1847 bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
1848     const Instruction &I, bool &AllowPromotionWithoutCommonHeader) {
1849   bool Considerable = false;
1850   AllowPromotionWithoutCommonHeader = false;
1851   if (!isa<SExtInst>(&I))
1852     return false;
1853   Type *ConsideredSExtType =
1854       Type::getInt64Ty(I.getParent()->getParent()->getContext());
1855   if (I.getType() != ConsideredSExtType)
1856     return false;
1857   // The sext has the right type; now see if it is used in at least one
1858   // GetElementPtrInst.
1859   for (const User *U : I.users()) {
1860     if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
1861       Considerable = true;
1862       // A getelementptr is considered "complex" if it has more than 2
1863       // operands. We will promote a SExt used in such a complex GEP as we
1864       // expect some computation to be merged if it is done on 64 bits.
1865       if (GEPInst->getNumOperands() > 2) {
1866         AllowPromotionWithoutCommonHeader = true;
1867         break;
1868       }
1869     }
1870   }
1871   return Considerable;
1872 }
1873 
1874 bool AArch64TTIImpl::isLegalToVectorizeReduction(
1875     const RecurrenceDescriptor &RdxDesc, ElementCount VF) const {
1876   if (!VF.isScalable())
1877     return true;
1878 
1879   Type *Ty = RdxDesc.getRecurrenceType();
1880   if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty))
1881     return false;
1882 
1883   switch (RdxDesc.getRecurrenceKind()) {
1884   case RecurKind::Add:
1885   case RecurKind::FAdd:
1886   case RecurKind::And:
1887   case RecurKind::Or:
1888   case RecurKind::Xor:
1889   case RecurKind::SMin:
1890   case RecurKind::SMax:
1891   case RecurKind::UMin:
1892   case RecurKind::UMax:
1893   case RecurKind::FMin:
1894   case RecurKind::FMax:
1895     return true;
1896   default:
1897     return false;
1898   }
1899 }
1900 
1901 InstructionCost
1902 AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
1903                                        bool IsUnsigned,
1904                                        TTI::TargetCostKind CostKind) {
1905   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
1906 
1907   if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
1908     return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
1909 
1910   assert((isa<ScalableVectorType>(Ty) == isa<ScalableVectorType>(CondTy)) &&
1911          "Both vectors need to be equally scalable");
1912 
1913   InstructionCost LegalizationCost = 0;
1914   if (LT.first > 1) {
1915     Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
1916     unsigned MinMaxOpcode =
1917         Ty->isFPOrFPVectorTy()
1918             ? Intrinsic::maxnum
1919             : (IsUnsigned ? Intrinsic::umin : Intrinsic::smin);
1920     IntrinsicCostAttributes Attrs(MinMaxOpcode, LegalVTy, {LegalVTy, LegalVTy});
1921     LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1);
1922   }
1923 
1924   return LegalizationCost + /*Cost of horizontal reduction*/ 2;
1925 }
1926 
1927 InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
1928     unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) {
1929   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
1930   InstructionCost LegalizationCost = 0;
1931   if (LT.first > 1) {
1932     Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
1933     LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
1934     LegalizationCost *= LT.first - 1;
1935   }
1936 
1937   int ISD = TLI->InstructionOpcodeToISD(Opcode);
1938   assert(ISD && "Invalid opcode");
1939   // Add the final reduction cost for the legal horizontal reduction
1940   switch (ISD) {
1941   case ISD::ADD:
1942   case ISD::AND:
1943   case ISD::OR:
1944   case ISD::XOR:
1945   case ISD::FADD:
1946     return LegalizationCost + 2;
1947   default:
1948     return InstructionCost::getInvalid();
1949   }
1950 }
1951 
1952 InstructionCost
1953 AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
1954                                            Optional<FastMathFlags> FMF,
1955                                            TTI::TargetCostKind CostKind) {
1956   if (TTI::requiresOrderedReduction(FMF)) {
1957     if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) {
1958       InstructionCost BaseCost =
1959           BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
1960       // Add on extra cost to reflect the extra overhead on some CPUs. We still
1961       // end up vectorizing for more computationally intensive loops.
1962       return BaseCost + FixedVTy->getNumElements();
1963     }
1964 
1965     if (Opcode != Instruction::FAdd)
1966       return InstructionCost::getInvalid();
1967 
1968     auto *VTy = cast<ScalableVectorType>(ValTy);
1969     InstructionCost Cost =
1970         getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind);
1971     Cost *= getMaxNumElements(VTy->getElementCount());
1972     return Cost;
1973   }
1974 
1975   if (isa<ScalableVectorType>(ValTy))
1976     return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind);
1977 
1978   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
1979   MVT MTy = LT.second;
1980   int ISD = TLI->InstructionOpcodeToISD(Opcode);
1981   assert(ISD && "Invalid opcode");
1982 
1983   // Horizontal adds can use the 'addv' instruction. We model the cost of these
1984   // instructions as twice a normal vector add, plus 1 for each legalization
1985   // step (LT.first). This is the only arithmetic vector reduction operation for
1986   // which we have an instruction.
1987   // OR, XOR and AND costs should match the codegen from:
1988   // OR: llvm/test/CodeGen/AArch64/reduce-or.ll
1989   // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll
1990   // AND: llvm/test/CodeGen/AArch64/reduce-and.ll
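  // For example, an add reduction of v8i32 legalizes to two v4i32 halves
  // (LT.first == 2), so the cost is (2 - 1) + 2 == 3: one vector add to
  // combine the halves plus the v4i32 'addv' entry below.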
1991   static const CostTblEntry CostTblNoPairwise[]{
1992       {ISD::ADD, MVT::v8i8,   2},
1993       {ISD::ADD, MVT::v16i8,  2},
1994       {ISD::ADD, MVT::v4i16,  2},
1995       {ISD::ADD, MVT::v8i16,  2},
1996       {ISD::ADD, MVT::v4i32,  2},
1997       {ISD::OR,  MVT::v8i8,  15},
1998       {ISD::OR,  MVT::v16i8, 17},
1999       {ISD::OR,  MVT::v4i16,  7},
2000       {ISD::OR,  MVT::v8i16,  9},
2001       {ISD::OR,  MVT::v2i32,  3},
2002       {ISD::OR,  MVT::v4i32,  5},
2003       {ISD::OR,  MVT::v2i64,  3},
2004       {ISD::XOR, MVT::v8i8,  15},
2005       {ISD::XOR, MVT::v16i8, 17},
2006       {ISD::XOR, MVT::v4i16,  7},
2007       {ISD::XOR, MVT::v8i16,  9},
2008       {ISD::XOR, MVT::v2i32,  3},
2009       {ISD::XOR, MVT::v4i32,  5},
2010       {ISD::XOR, MVT::v2i64,  3},
2011       {ISD::AND, MVT::v8i8,  15},
2012       {ISD::AND, MVT::v16i8, 17},
2013       {ISD::AND, MVT::v4i16,  7},
2014       {ISD::AND, MVT::v8i16,  9},
2015       {ISD::AND, MVT::v2i32,  3},
2016       {ISD::AND, MVT::v4i32,  5},
2017       {ISD::AND, MVT::v2i64,  3},
2018   };
2019   switch (ISD) {
2020   default:
2021     break;
2022   case ISD::ADD:
2023     if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
2024       return (LT.first - 1) + Entry->Cost;
2025     break;
2026   case ISD::XOR:
2027   case ISD::AND:
2028   case ISD::OR:
2029     const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy);
2030     if (!Entry)
2031       break;
2032     auto *ValVTy = cast<FixedVectorType>(ValTy);
2033     if (!ValVTy->getElementType()->isIntegerTy(1) &&
2034         MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
2035         isPowerOf2_32(ValVTy->getNumElements())) {
2036       InstructionCost ExtraCost = 0;
2037       if (LT.first != 1) {
2038         // Type needs to be split, so there is an extra cost of LT.first - 1
2039         // arithmetic ops.
2040         auto *Ty = FixedVectorType::get(ValTy->getElementType(),
2041                                         MTy.getVectorNumElements());
2042         ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
2043         ExtraCost *= LT.first - 1;
2044       }
2045       return Entry->Cost + ExtraCost;
2046     }
2047     break;
2048   }
2049   return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
2050 }
2051 
2052 InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
2053   static const CostTblEntry ShuffleTbl[] = {
2054       { TTI::SK_Splice, MVT::nxv16i8,  1 },
2055       { TTI::SK_Splice, MVT::nxv8i16,  1 },
2056       { TTI::SK_Splice, MVT::nxv4i32,  1 },
2057       { TTI::SK_Splice, MVT::nxv2i64,  1 },
2058       { TTI::SK_Splice, MVT::nxv2f16,  1 },
2059       { TTI::SK_Splice, MVT::nxv4f16,  1 },
2060       { TTI::SK_Splice, MVT::nxv8f16,  1 },
2061       { TTI::SK_Splice, MVT::nxv2bf16, 1 },
2062       { TTI::SK_Splice, MVT::nxv4bf16, 1 },
2063       { TTI::SK_Splice, MVT::nxv8bf16, 1 },
2064       { TTI::SK_Splice, MVT::nxv2f32,  1 },
2065       { TTI::SK_Splice, MVT::nxv4f32,  1 },
2066       { TTI::SK_Splice, MVT::nxv2f64,  1 },
2067   };
2068 
2069   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
2070   Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext());
2071   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2072   EVT PromotedVT = LT.second.getScalarType() == MVT::i1
2073                        ? TLI->getPromotedVTForPredicate(EVT(LT.second))
2074                        : LT.second;
2075   Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext());
2076   InstructionCost LegalizationCost = 0;
2077   if (Index < 0) {
2078     LegalizationCost =
2079         getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy,
2080                            CmpInst::BAD_ICMP_PREDICATE, CostKind) +
2081         getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy,
2082                            CmpInst::BAD_ICMP_PREDICATE, CostKind);
2083   }
2084 
2085   // Predicated splices are promoted during lowering; see
2086   // AArch64ISelLowering.cpp. The cost is computed on the promoted type.
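  // For example, splicing two nxv16i1 predicates is costed as a zext to the
  // promoted integer type, the splice itself on that promoted type, and a
  // trunc back to the predicate type.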
2087   if (LT.second.getScalarType() == MVT::i1) {
2088     LegalizationCost +=
2089         getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy,
2090                          TTI::CastContextHint::None, CostKind) +
2091         getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy,
2092                          TTI::CastContextHint::None, CostKind);
2093   }
2094   const auto *Entry =
2095       CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT());
2096   assert(Entry && "Illegal Type for Splice");
2097   LegalizationCost += Entry->Cost;
2098   return LegalizationCost * LT.first;
2099 }
2100 
2101 InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
2102                                                VectorType *Tp,
2103                                                ArrayRef<int> Mask, int Index,
2104                                                VectorType *SubTp) {
2105   Kind = improveShuffleKindFromMask(Kind, Mask);
2106   if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose ||
2107       Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc ||
2108       Kind == TTI::SK_Reverse) {
2109     static const CostTblEntry ShuffleTbl[] = {
2110       // Broadcast shuffle kinds can be performed with 'dup'.
2111       { TTI::SK_Broadcast, MVT::v8i8,  1 },
2112       { TTI::SK_Broadcast, MVT::v16i8, 1 },
2113       { TTI::SK_Broadcast, MVT::v4i16, 1 },
2114       { TTI::SK_Broadcast, MVT::v8i16, 1 },
2115       { TTI::SK_Broadcast, MVT::v2i32, 1 },
2116       { TTI::SK_Broadcast, MVT::v4i32, 1 },
2117       { TTI::SK_Broadcast, MVT::v2i64, 1 },
2118       { TTI::SK_Broadcast, MVT::v2f32, 1 },
2119       { TTI::SK_Broadcast, MVT::v4f32, 1 },
2120       { TTI::SK_Broadcast, MVT::v2f64, 1 },
2121       // Transpose shuffle kinds can be performed with 'trn1/trn2' and
2122       // 'zip1/zip2' instructions.
2123       { TTI::SK_Transpose, MVT::v8i8,  1 },
2124       { TTI::SK_Transpose, MVT::v16i8, 1 },
2125       { TTI::SK_Transpose, MVT::v4i16, 1 },
2126       { TTI::SK_Transpose, MVT::v8i16, 1 },
2127       { TTI::SK_Transpose, MVT::v2i32, 1 },
2128       { TTI::SK_Transpose, MVT::v4i32, 1 },
2129       { TTI::SK_Transpose, MVT::v2i64, 1 },
2130       { TTI::SK_Transpose, MVT::v2f32, 1 },
2131       { TTI::SK_Transpose, MVT::v4f32, 1 },
2132       { TTI::SK_Transpose, MVT::v2f64, 1 },
2133       // Select shuffle kinds.
2134       // TODO: handle vXi8/vXi16.
2135       { TTI::SK_Select, MVT::v2i32, 1 }, // mov.
2136       { TTI::SK_Select, MVT::v4i32, 2 }, // rev+trn (or similar).
2137       { TTI::SK_Select, MVT::v2i64, 1 }, // mov.
2138       { TTI::SK_Select, MVT::v2f32, 1 }, // mov.
2139       { TTI::SK_Select, MVT::v4f32, 2 }, // rev+trn (or similar).
2140       { TTI::SK_Select, MVT::v2f64, 1 }, // mov.
2141       // PermuteSingleSrc shuffle kinds.
2142       { TTI::SK_PermuteSingleSrc, MVT::v2i32, 1 }, // mov.
2143       { TTI::SK_PermuteSingleSrc, MVT::v4i32, 3 }, // perfectshuffle worst case.
2144       { TTI::SK_PermuteSingleSrc, MVT::v2i64, 1 }, // mov.
2145       { TTI::SK_PermuteSingleSrc, MVT::v2f32, 1 }, // mov.
2146       { TTI::SK_PermuteSingleSrc, MVT::v4f32, 3 }, // perfectshuffle worst case.
2147       { TTI::SK_PermuteSingleSrc, MVT::v2f64, 1 }, // mov.
2148       { TTI::SK_PermuteSingleSrc, MVT::v4i16, 3 }, // perfectshuffle worst case.
2149       { TTI::SK_PermuteSingleSrc, MVT::v4f16, 3 }, // perfectshuffle worst case.
2150       { TTI::SK_PermuteSingleSrc, MVT::v4bf16, 3 }, // perfectshuffle worst case.
2151       { TTI::SK_PermuteSingleSrc, MVT::v8i16, 8 }, // constpool + load + tbl
2152       { TTI::SK_PermuteSingleSrc, MVT::v8f16, 8 }, // constpool + load + tbl
2153       { TTI::SK_PermuteSingleSrc, MVT::v8bf16, 8 }, // constpool + load + tbl
2154       { TTI::SK_PermuteSingleSrc, MVT::v8i8, 8 }, // constpool + load + tbl
2155       { TTI::SK_PermuteSingleSrc, MVT::v16i8, 8 }, // constpool + load + tbl
2156       // Reverse can be lowered with `rev`.
2157       { TTI::SK_Reverse, MVT::v2i32, 1 }, // mov.
2158       { TTI::SK_Reverse, MVT::v4i32, 2 }, // REV64; EXT
2159       { TTI::SK_Reverse, MVT::v2i64, 1 }, // mov.
2160       { TTI::SK_Reverse, MVT::v2f32, 1 }, // mov.
2161       { TTI::SK_Reverse, MVT::v4f32, 2 }, // REV64; EXT
2162       { TTI::SK_Reverse, MVT::v2f64, 1 }, // mov.
2163       // Broadcast shuffle kinds for scalable vectors
2164       { TTI::SK_Broadcast, MVT::nxv16i8,  1 },
2165       { TTI::SK_Broadcast, MVT::nxv8i16,  1 },
2166       { TTI::SK_Broadcast, MVT::nxv4i32,  1 },
2167       { TTI::SK_Broadcast, MVT::nxv2i64,  1 },
2168       { TTI::SK_Broadcast, MVT::nxv2f16,  1 },
2169       { TTI::SK_Broadcast, MVT::nxv4f16,  1 },
2170       { TTI::SK_Broadcast, MVT::nxv8f16,  1 },
2171       { TTI::SK_Broadcast, MVT::nxv2bf16, 1 },
2172       { TTI::SK_Broadcast, MVT::nxv4bf16, 1 },
2173       { TTI::SK_Broadcast, MVT::nxv8bf16, 1 },
2174       { TTI::SK_Broadcast, MVT::nxv2f32,  1 },
2175       { TTI::SK_Broadcast, MVT::nxv4f32,  1 },
2176       { TTI::SK_Broadcast, MVT::nxv2f64,  1 },
2177       { TTI::SK_Broadcast, MVT::nxv16i1,  1 },
2178       { TTI::SK_Broadcast, MVT::nxv8i1,   1 },
2179       { TTI::SK_Broadcast, MVT::nxv4i1,   1 },
2180       { TTI::SK_Broadcast, MVT::nxv2i1,   1 },
2181       // Handle the cases for vector.reverse with scalable vectors
2182       { TTI::SK_Reverse, MVT::nxv16i8,  1 },
2183       { TTI::SK_Reverse, MVT::nxv8i16,  1 },
2184       { TTI::SK_Reverse, MVT::nxv4i32,  1 },
2185       { TTI::SK_Reverse, MVT::nxv2i64,  1 },
2186       { TTI::SK_Reverse, MVT::nxv2f16,  1 },
2187       { TTI::SK_Reverse, MVT::nxv4f16,  1 },
2188       { TTI::SK_Reverse, MVT::nxv8f16,  1 },
2189       { TTI::SK_Reverse, MVT::nxv2bf16, 1 },
2190       { TTI::SK_Reverse, MVT::nxv4bf16, 1 },
2191       { TTI::SK_Reverse, MVT::nxv8bf16, 1 },
2192       { TTI::SK_Reverse, MVT::nxv2f32,  1 },
2193       { TTI::SK_Reverse, MVT::nxv4f32,  1 },
2194       { TTI::SK_Reverse, MVT::nxv2f64,  1 },
2195       { TTI::SK_Reverse, MVT::nxv16i1,  1 },
2196       { TTI::SK_Reverse, MVT::nxv8i1,   1 },
2197       { TTI::SK_Reverse, MVT::nxv4i1,   1 },
2198       { TTI::SK_Reverse, MVT::nxv2i1,   1 },
2199     };
2200     std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
2201     if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second))
2202       return LT.first * Entry->Cost;
2203   }
2204   if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(Tp))
2205     return getSpliceCost(Tp, Index);
2206   return BaseT::getShuffleCost(Kind, Tp, Mask, Index, SubTp);
2207 }
2208