1 //===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "AArch64TargetTransformInfo.h"
10 #include "AArch64ExpandImm.h"
11 #include "MCTargetDesc/AArch64AddressingModes.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/TargetTransformInfo.h"
14 #include "llvm/CodeGen/BasicTTIImpl.h"
15 #include "llvm/CodeGen/CostTable.h"
16 #include "llvm/CodeGen/TargetLowering.h"
17 #include "llvm/IR/Intrinsics.h"
18 #include "llvm/IR/IntrinsicInst.h"
19 #include "llvm/IR/IntrinsicsAArch64.h"
20 #include "llvm/IR/PatternMatch.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Transforms/InstCombine/InstCombiner.h"
23 #include <algorithm>
24 using namespace llvm;
25 using namespace llvm::PatternMatch;
26 
27 #define DEBUG_TYPE "aarch64tti"
28 
29 static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix",
30                                                cl::init(true), cl::Hidden);
31 
32 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
33                                          const Function *Callee) const {
34   const TargetMachine &TM = getTLI()->getTargetMachine();
35 
36   const FeatureBitset &CallerBits =
37       TM.getSubtargetImpl(*Caller)->getFeatureBits();
38   const FeatureBitset &CalleeBits =
39       TM.getSubtargetImpl(*Callee)->getFeatureBits();
40 
41   // Inline a callee if its target-features are a subset of the callers
42   // target-features.
43   return (CallerBits & CalleeBits) == CalleeBits;
44 }
45 
46 /// Calculate the cost of materializing a 64-bit value. This helper
47 /// method might only calculate a fraction of a larger immediate. Therefore it
48 /// is valid to return a cost of ZERO.
49 InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) {
50   // Check if the immediate can be encoded within an instruction.
51   if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
52     return 0;
53 
54   if (Val < 0)
55     Val = ~Val;
56 
57   // Calculate how many moves we will need to materialize this constant.
58   SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
59   AArch64_IMM::expandMOVImm(Val, 64, Insn);
60   return Insn.size();
61 }
62 
63 /// Calculate the cost of materializing the given constant.
64 InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
65                                               TTI::TargetCostKind CostKind) {
66   assert(Ty->isIntegerTy());
67 
68   unsigned BitSize = Ty->getPrimitiveSizeInBits();
69   if (BitSize == 0)
70     return ~0U;
71 
72   // Sign-extend all constants to a multiple of 64-bit.
73   APInt ImmVal = Imm;
74   if (BitSize & 0x3f)
75     ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);
76 
77   // Split the constant into 64-bit chunks and calculate the cost for each
78   // chunk.
79   InstructionCost Cost = 0;
80   for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
81     APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
82     int64_t Val = Tmp.getSExtValue();
83     Cost += getIntImmCost(Val);
84   }
85   // We need at least one instruction to materialze the constant.
86   return std::max<InstructionCost>(1, Cost);
87 }
88 
/// Return the cost of \p Imm when it appears as operand \p Idx of an
/// instruction with opcode \p Opcode. Returning TCC_Free tells constant
/// hoisting to leave the immediate in place (it can be folded into the
/// instruction); any larger cost makes hoisting it into a register
/// worthwhile.
InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
                                                  const APInt &Imm, Type *Ty,
                                                  TTI::TargetCostKind CostKind,
                                                  Instruction *Inst) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  // There is no cost model for constants with a bit size of 0. Return TCC_Free
  // here, so that constant hoisting will ignore this constant.
  if (BitSize == 0)
    return TTI::TCC_Free;

  // ImmIdx records which operand index of this opcode can accept a folded
  // immediate; ~0U means no operand qualifies and the plain materialization
  // cost is returned below.
  unsigned ImmIdx = ~0U;
  switch (Opcode) {
  default:
    return TTI::TCC_Free;
  case Instruction::GetElementPtr:
    // Always hoist the base address of a GetElementPtr.
    if (Idx == 0)
      return 2 * TTI::TCC_Basic;
    return TTI::TCC_Free;
  case Instruction::Store:
    ImmIdx = 0;
    break;
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::UDiv:
  case Instruction::SDiv:
  case Instruction::URem:
  case Instruction::SRem:
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::ICmp:
    ImmIdx = 1;
    break;
  // Always return TCC_Free for the shift value of a shift instruction.
  case Instruction::Shl:
  case Instruction::LShr:
  case Instruction::AShr:
    if (Idx == 1)
      return TTI::TCC_Free;
    break;
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
  case Instruction::IntToPtr:
  case Instruction::PtrToInt:
  case Instruction::BitCast:
  case Instruction::PHI:
  case Instruction::Call:
  case Instruction::Select:
  case Instruction::Ret:
  case Instruction::Load:
    break;
  }

  if (Idx == ImmIdx) {
    // Immediates costing no more than one basic op per 64-bit chunk are
    // treated as foldable (free); anything pricier reports its real cost so
    // hoisting can kick in.
    int NumConstants = (BitSize + 63) / 64;
    InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
    return (Cost <= NumConstants * TTI::TCC_Basic)
               ? static_cast<int>(TTI::TCC_Free)
               : Cost;
  }
  return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
}
156 
/// Return the cost of \p Imm when it appears as operand \p Idx of the
/// intrinsic \p IID. TCC_Free means constant hoisting should leave the
/// immediate alone.
InstructionCost
AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
                                    const APInt &Imm, Type *Ty,
                                    TTI::TargetCostKind CostKind) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  // There is no cost model for constants with a bit size of 0. Return TCC_Free
  // here, so that constant hoisting will ignore this constant.
  if (BitSize == 0)
    return TTI::TCC_Free;

  // Most (all?) AArch64 intrinsics do not support folding immediates into the
  // selected instruction, so we compute the materialization cost for the
  // immediate directly.
  if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
    return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);

  switch (IID) {
  default:
    return TTI::TCC_Free;
  case Intrinsic::sadd_with_overflow:
  case Intrinsic::uadd_with_overflow:
  case Intrinsic::ssub_with_overflow:
  case Intrinsic::usub_with_overflow:
  case Intrinsic::smul_with_overflow:
  case Intrinsic::umul_with_overflow:
    // Same folding heuristic as getIntImmCostInst: free when no more
    // expensive than one basic op per 64-bit chunk.
    if (Idx == 1) {
      int NumConstants = (BitSize + 63) / 64;
      InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
      return (Cost <= NumConstants * TTI::TCC_Basic)
                 ? static_cast<int>(TTI::TCC_Free)
                 : Cost;
    }
    break;
  // For stackmap/patchpoint/statepoint, the leading operands (below the Idx
  // threshold) and any immediate that fits in a signed 64-bit value are free.
  case Intrinsic::experimental_stackmap:
    if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  case Intrinsic::experimental_patchpoint_void:
  case Intrinsic::experimental_patchpoint_i64:
    if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  case Intrinsic::experimental_gc_statepoint:
    if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  }
  return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
}
208 
209 TargetTransformInfo::PopcntSupportKind
210 AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
211   assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
212   if (TyWidth == 32 || TyWidth == 64)
213     return TTI::PSK_FastHardware;
214   // TODO: AArch64TargetLowering::LowerCTPOP() supports 128bit popcount.
215   return TTI::PSK_Software;
216 }
217 
218 InstructionCost
219 AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
220                                       TTI::TargetCostKind CostKind) {
221   auto *RetTy = ICA.getReturnType();
222   switch (ICA.getID()) {
223   case Intrinsic::umin:
224   case Intrinsic::umax:
225   case Intrinsic::smin:
226   case Intrinsic::smax: {
227     static const auto ValidMinMaxTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
228                                         MVT::v8i16, MVT::v2i32, MVT::v4i32};
229     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
230     // v2i64 types get converted to cmp+bif hence the cost of 2
231     if (LT.second == MVT::v2i64)
232       return LT.first * 2;
233     if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
234       return LT.first;
235     break;
236   }
237   case Intrinsic::sadd_sat:
238   case Intrinsic::ssub_sat:
239   case Intrinsic::uadd_sat:
240   case Intrinsic::usub_sat: {
241     static const auto ValidSatTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
242                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
243                                      MVT::v2i64};
244     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
245     // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
246     // need to extend the type, as it uses shr(qadd(shl, shl)).
247     unsigned Instrs =
248         LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
249     if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
250       return LT.first * Instrs;
251     break;
252   }
253   case Intrinsic::abs: {
254     static const auto ValidAbsTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
255                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
256                                      MVT::v2i64};
257     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
258     if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
259       return LT.first;
260     break;
261   }
262   case Intrinsic::experimental_stepvector: {
263     InstructionCost Cost = 1; // Cost of the `index' instruction
264     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
265     // Legalisation of illegal vectors involves an `index' instruction plus
266     // (LT.first - 1) vector adds.
267     if (LT.first > 1) {
268       Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
269       InstructionCost AddCost =
270           getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
271       Cost += AddCost * (LT.first - 1);
272     }
273     return Cost;
274   }
275   case Intrinsic::bitreverse: {
276     static const CostTblEntry BitreverseTbl[] = {
277         {Intrinsic::bitreverse, MVT::i32, 1},
278         {Intrinsic::bitreverse, MVT::i64, 1},
279         {Intrinsic::bitreverse, MVT::v8i8, 1},
280         {Intrinsic::bitreverse, MVT::v16i8, 1},
281         {Intrinsic::bitreverse, MVT::v4i16, 2},
282         {Intrinsic::bitreverse, MVT::v8i16, 2},
283         {Intrinsic::bitreverse, MVT::v2i32, 2},
284         {Intrinsic::bitreverse, MVT::v4i32, 2},
285         {Intrinsic::bitreverse, MVT::v1i64, 2},
286         {Intrinsic::bitreverse, MVT::v2i64, 2},
287     };
288     const auto LegalisationCost = TLI->getTypeLegalizationCost(DL, RetTy);
289     const auto *Entry =
290         CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
291     // Cost Model is using the legal type(i32) that i8 and i16 will be converted
292     // to +1 so that we match the actual lowering cost
293     if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
294         TLI->getValueType(DL, RetTy, true) == MVT::i16)
295       return LegalisationCost.first * Entry->Cost + 1;
296     if (Entry)
297       return LegalisationCost.first * Entry->Cost;
298     break;
299   }
300   case Intrinsic::ctpop: {
301     static const CostTblEntry CtpopCostTbl[] = {
302         {ISD::CTPOP, MVT::v2i64, 4},
303         {ISD::CTPOP, MVT::v4i32, 3},
304         {ISD::CTPOP, MVT::v8i16, 2},
305         {ISD::CTPOP, MVT::v16i8, 1},
306         {ISD::CTPOP, MVT::i64,   4},
307         {ISD::CTPOP, MVT::v2i32, 3},
308         {ISD::CTPOP, MVT::v4i16, 2},
309         {ISD::CTPOP, MVT::v8i8,  1},
310         {ISD::CTPOP, MVT::i32,   5},
311     };
312     auto LT = TLI->getTypeLegalizationCost(DL, RetTy);
313     MVT MTy = LT.second;
314     if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
315       // Extra cost of +1 when illegal vector types are legalized by promoting
316       // the integer type.
317       int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
318                                             RetTy->getScalarSizeInBits()
319                           ? 1
320                           : 0;
321       return LT.first * Entry->Cost + ExtraCost;
322     }
323     break;
324   }
325   default:
326     break;
327   }
328   return BaseT::getIntrinsicInstrCost(ICA, CostKind);
329 }
330 
/// The function will remove redundant reinterprets casting in the presence
/// of the control flow: when every incoming value of a single-use PHI is an
/// aarch64_sve_convert_to_svbool whose source already has the type this
/// intrinsic produces, replace the intrinsic with a new PHI over the
/// unconverted sources.
static Optional<Instruction *> processPhiNode(InstCombiner &IC,
                                              IntrinsicInst &II) {
  SmallVector<Instruction *, 32> Worklist;
  auto RequiredType = II.getType();

  auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
  assert(PN && "Expected Phi Node!");

  // Don't create a new Phi unless we can remove the old one.
  if (!PN->hasOneUse())
    return None;

  // Every incoming value must be a convert_to_svbool of a value that already
  // has the type the surrounding intrinsic wants.
  for (Value *IncValPhi : PN->incoming_values()) {
    auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
    if (!Reinterpret ||
        Reinterpret->getIntrinsicID() !=
            Intrinsic::aarch64_sve_convert_to_svbool ||
        RequiredType != Reinterpret->getArgOperand(0)->getType())
      return None;
  }

  // Create the new Phi
  LLVMContext &Ctx = PN->getContext();
  IRBuilder<> Builder(Ctx);
  Builder.SetInsertPoint(PN);
  PHINode *NPN = Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
  Worklist.push_back(PN);

  // The new PHI takes each reinterpret's *source* from the matching block.
  for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
    auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
    NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
    Worklist.push_back(Reinterpret);
  }

  // Cleanup Phi Node and reinterprets
  // NOTE(review): Worklist is populated above but never consumed here; the
  // old PHI and reinterprets appear to be left to InstCombine's dead-code
  // elimination once this replacement removes their last use — confirm.
  return IC.replaceInstUsesWith(II, NPN);
}
370 
371 static Optional<Instruction *> instCombineConvertFromSVBool(InstCombiner &IC,
372                                                             IntrinsicInst &II) {
373   // If the reinterpret instruction operand is a PHI Node
374   if (isa<PHINode>(II.getArgOperand(0)))
375     return processPhiNode(IC, II);
376 
377   SmallVector<Instruction *, 32> CandidatesForRemoval;
378   Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;
379 
380   const auto *IVTy = cast<VectorType>(II.getType());
381 
382   // Walk the chain of conversions.
383   while (Cursor) {
384     // If the type of the cursor has fewer lanes than the final result, zeroing
385     // must take place, which breaks the equivalence chain.
386     const auto *CursorVTy = cast<VectorType>(Cursor->getType());
387     if (CursorVTy->getElementCount().getKnownMinValue() <
388         IVTy->getElementCount().getKnownMinValue())
389       break;
390 
391     // If the cursor has the same type as I, it is a viable replacement.
392     if (Cursor->getType() == IVTy)
393       EarliestReplacement = Cursor;
394 
395     auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);
396 
397     // If this is not an SVE conversion intrinsic, this is the end of the chain.
398     if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
399                                   Intrinsic::aarch64_sve_convert_to_svbool ||
400                               IntrinsicCursor->getIntrinsicID() ==
401                                   Intrinsic::aarch64_sve_convert_from_svbool))
402       break;
403 
404     CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
405     Cursor = IntrinsicCursor->getOperand(0);
406   }
407 
408   // If no viable replacement in the conversion chain was found, there is
409   // nothing to do.
410   if (!EarliestReplacement)
411     return None;
412 
413   return IC.replaceInstUsesWith(II, EarliestReplacement);
414 }
415 
416 static Optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
417                                                  IntrinsicInst &II) {
418   IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
419   if (!Pg)
420     return None;
421 
422   if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
423     return None;
424 
425   const auto PTruePattern =
426       cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
427   if (PTruePattern != AArch64SVEPredPattern::vl1)
428     return None;
429 
430   // The intrinsic is inserting into lane zero so use an insert instead.
431   auto *IdxTy = Type::getInt64Ty(II.getContext());
432   auto *Insert = InsertElementInst::Create(
433       II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
434   Insert->insertBefore(&II);
435   Insert->takeName(&II);
436 
437   return IC.replaceInstUsesWith(II, Insert);
438 }
439 
/// Fold an all-active sve.cmpne of a dupq'd fixed constant vector against a
/// splatted zero into a ptrue (or an all-false constant) when the resulting
/// predicate corresponds to a native SVE predicate element size.
static Optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
                                                   IntrinsicInst &II) {
  LLVMContext &Ctx = II.getContext();
  IRBuilder<> Builder(Ctx);
  Builder.SetInsertPoint(&II);

  // Check that the predicate is all active
  auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
  if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
    return None;

  const auto PTruePattern =
      cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
  if (PTruePattern != AArch64SVEPredPattern::all)
    return None;

  // Check that we have a compare of zero..
  auto *DupX = dyn_cast<IntrinsicInst>(II.getArgOperand(2));
  if (!DupX || DupX->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
    return None;

  auto *DupXArg = dyn_cast<ConstantInt>(DupX->getArgOperand(0));
  if (!DupXArg || !DupXArg->isZero())
    return None;

  // ..against a dupq
  auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
  if (!DupQLane ||
      DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
    return None;

  // Where the dupq is a lane 0 replicate of a vector insert
  if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero())
    return None;

  auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
  if (!VecIns ||
      VecIns->getIntrinsicID() != Intrinsic::experimental_vector_insert)
    return None;

  // Where the vector insert is a fixed constant vector insert into undef at
  // index zero
  if (!isa<UndefValue>(VecIns->getArgOperand(0)))
    return None;

  if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
    return None;

  auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
  if (!ConstVec)
    return None;

  auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
  auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
  if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
    return None;

  unsigned NumElts = VecTy->getNumElements();
  unsigned PredicateBits = 0;

  // Expand intrinsic operands to a 16-bit byte level predicate
  // (one bit per byte of a 128-bit block; each element covers 16/NumElts
  // byte positions, and a non-zero element sets its lowest byte's bit).
  for (unsigned I = 0; I < NumElts; ++I) {
    auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
    if (!Arg)
      return None;
    if (!Arg->isZero())
      PredicateBits |= 1 << (I * (16 / NumElts));
  }

  // If all bits are zero bail early with an empty predicate
  if (PredicateBits == 0) {
    auto *PFalse = Constant::getNullValue(II.getType());
    PFalse->takeName(&II);
    return IC.replaceInstUsesWith(II, PFalse);
  }

  // Calculate largest predicate type used (where byte predicate is largest)
  unsigned Mask = 8;
  for (unsigned I = 0; I < 16; ++I)
    if ((PredicateBits & (1 << I)) != 0)
      Mask |= (I % 8);

  // Lowest set bit of Mask gives the element stride in bytes.
  unsigned PredSize = Mask & -Mask;
  auto *PredType = ScalableVectorType::get(
      Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));

  // Ensure all relevant bits are set
  for (unsigned I = 0; I < 16; I += PredSize)
    if ((PredicateBits & (1 << I)) == 0)
      return None;

  // Emit ptrue of the narrower element type and reinterpret it back to the
  // original predicate type via svbool.
  auto *PTruePat =
      ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
  auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
                                        {PredType}, {PTruePat});
  auto *ConvertToSVBool = Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
  auto *ConvertFromSVBool =
      Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
                              {II.getType()}, {ConvertToSVBool});

  ConvertFromSVBool->takeName(&II);
  return IC.replaceInstUsesWith(II, ConvertFromSVBool);
}
544 
/// Simplify sve.lasta/sve.lastb: fold lastX of a splat to the splatted
/// scalar, distribute lastX over a one-use binop when an operand is a splat,
/// and turn lastX with a compile-time-known predicate into a plain
/// extractelement of a fixed lane.
static Optional<Instruction *> instCombineSVELast(InstCombiner &IC,
                                                  IntrinsicInst &II) {
  IRBuilder<> Builder(II.getContext());
  Builder.SetInsertPoint(&II);
  Value *Pg = II.getArgOperand(0);
  Value *Vec = II.getArgOperand(1);
  auto IntrinsicID = II.getIntrinsicID();
  bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta;

  // lastX(splat(X)) --> X
  if (auto *SplatVal = getSplatValue(Vec))
    return IC.replaceInstUsesWith(II, SplatVal);

  // If x and/or y is a splat value then:
  // lastX (binop (x, y)) --> binop(lastX(x), lastX(y))
  Value *LHS, *RHS;
  if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) {
    if (isSplatValue(LHS) || isSplatValue(RHS)) {
      auto *OldBinOp = cast<BinaryOperator>(Vec);
      auto OpC = OldBinOp->getOpcode();
      auto *NewLHS =
          Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS});
      auto *NewRHS =
          Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS});
      auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags(
          OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), &II);
      return IC.replaceInstUsesWith(II, NewBinOp);
    }
  }

  // An all-false predicate: lasta extracts the element *after* the last
  // active one, i.e. lane 0 here.
  auto *C = dyn_cast<Constant>(Pg);
  if (IsAfter && C && C->isNullValue()) {
    // The intrinsic is extracting lane 0 so use an extract instead.
    auto *IdxTy = Type::getInt64Ty(II.getContext());
    auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
    Extract->insertBefore(&II);
    Extract->takeName(&II);
    return IC.replaceInstUsesWith(II, Extract);
  }

  auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
  if (!IntrPG)
    return None;

  if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
    return None;

  const auto PTruePattern =
      cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();

  // Can the intrinsic's predicate be converted to a known constant index?
  // Each fixed-length ptrue pattern vlN makes lane N-1 the last active one.
  unsigned Idx;
  switch (PTruePattern) {
  default:
    return None;
  case AArch64SVEPredPattern::vl1:
    Idx = 0;
    break;
  case AArch64SVEPredPattern::vl2:
    Idx = 1;
    break;
  case AArch64SVEPredPattern::vl3:
    Idx = 2;
    break;
  case AArch64SVEPredPattern::vl4:
    Idx = 3;
    break;
  case AArch64SVEPredPattern::vl5:
    Idx = 4;
    break;
  case AArch64SVEPredPattern::vl6:
    Idx = 5;
    break;
  case AArch64SVEPredPattern::vl7:
    Idx = 6;
    break;
  case AArch64SVEPredPattern::vl8:
    Idx = 7;
    break;
  case AArch64SVEPredPattern::vl16:
    Idx = 15;
    break;
  }

  // Increment the index if extracting the element after the last active
  // predicate element.
  if (IsAfter)
    ++Idx;

  // Ignore extracts whose index is larger than the known minimum vector
  // length. NOTE: This is an artificial constraint where we prefer to
  // maintain what the user asked for until an alternative is proven faster.
  auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
  if (Idx >= PgVTy->getMinNumElements())
    return None;

  // The intrinsic is extracting a fixed lane so use an extract instead.
  auto *IdxTy = Type::getInt64Ty(II.getContext());
  auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
  Extract->insertBefore(&II);
  Extract->takeName(&II);
  return IC.replaceInstUsesWith(II, Extract);
}
648 
649 static Optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
650                                                 IntrinsicInst &II) {
651   LLVMContext &Ctx = II.getContext();
652   IRBuilder<> Builder(Ctx);
653   Builder.SetInsertPoint(&II);
654   // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
655   // can work with RDFFR_PP for ptest elimination.
656   auto *AllPat =
657       ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
658   auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
659                                         {II.getType()}, {AllPat});
660   auto *RDFFR =
661       Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue});
662   RDFFR->takeName(&II);
663   return IC.replaceInstUsesWith(II, RDFFR);
664 }
665 
666 static Optional<Instruction *>
667 instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
668   const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
669 
670   if (Pattern == AArch64SVEPredPattern::all) {
671     LLVMContext &Ctx = II.getContext();
672     IRBuilder<> Builder(Ctx);
673     Builder.SetInsertPoint(&II);
674 
675     Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
676     auto *VScale = Builder.CreateVScale(StepVal);
677     VScale->takeName(&II);
678     return IC.replaceInstUsesWith(II, VScale);
679   }
680 
681   unsigned MinNumElts = 0;
682   switch (Pattern) {
683   default:
684     return None;
685   case AArch64SVEPredPattern::vl1:
686   case AArch64SVEPredPattern::vl2:
687   case AArch64SVEPredPattern::vl3:
688   case AArch64SVEPredPattern::vl4:
689   case AArch64SVEPredPattern::vl5:
690   case AArch64SVEPredPattern::vl6:
691   case AArch64SVEPredPattern::vl7:
692   case AArch64SVEPredPattern::vl8:
693     MinNumElts = Pattern;
694     break;
695   case AArch64SVEPredPattern::vl16:
696     MinNumElts = 16;
697     break;
698   }
699 
700   return NumElts >= MinNumElts
701              ? Optional<Instruction *>(IC.replaceInstUsesWith(
702                    II, ConstantInt::get(II.getType(), MinNumElts)))
703              : None;
704 }
705 
706 static Optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
707                                                    IntrinsicInst &II) {
708   IntrinsicInst *Op1 = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
709   IntrinsicInst *Op2 = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
710 
711   if (Op1 && Op2 &&
712       Op1->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
713       Op2->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
714       Op1->getArgOperand(0)->getType() == Op2->getArgOperand(0)->getType()) {
715 
716     IRBuilder<> Builder(II.getContext());
717     Builder.SetInsertPoint(&II);
718 
719     Value *Ops[] = {Op1->getArgOperand(0), Op2->getArgOperand(0)};
720     Type *Tys[] = {Op1->getArgOperand(0)->getType()};
721 
722     auto *PTest = Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
723 
724     PTest->takeName(&II);
725     return IC.replaceInstUsesWith(II, PTest);
726   }
727 
728   return None;
729 }
730 
731 static Optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
732                                                        IntrinsicInst &II) {
733   auto *OpPredicate = II.getOperand(0);
734   auto *OpMultiplicand = II.getOperand(1);
735   auto *OpMultiplier = II.getOperand(2);
736 
737   IRBuilder<> Builder(II.getContext());
738   Builder.SetInsertPoint(&II);
739 
740   // Return true if a given instruction is an aarch64_sve_dup_x intrinsic call
741   // with a unit splat value, false otherwise.
742   auto IsUnitDupX = [](auto *I) {
743     auto *IntrI = dyn_cast<IntrinsicInst>(I);
744     if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
745       return false;
746 
747     auto *SplatValue = IntrI->getOperand(0);
748     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
749   };
750 
751   // Return true if a given instruction is an aarch64_sve_dup intrinsic call
752   // with a unit splat value, false otherwise.
753   auto IsUnitDup = [](auto *I) {
754     auto *IntrI = dyn_cast<IntrinsicInst>(I);
755     if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
756       return false;
757 
758     auto *SplatValue = IntrI->getOperand(2);
759     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
760   };
761 
762   // The OpMultiplier variable should always point to the dup (if any), so
763   // swap if necessary.
764   if (IsUnitDup(OpMultiplicand) || IsUnitDupX(OpMultiplicand))
765     std::swap(OpMultiplier, OpMultiplicand);
766 
767   if (IsUnitDupX(OpMultiplier)) {
768     // [f]mul pg (dupx 1) %n => %n
769     OpMultiplicand->takeName(&II);
770     return IC.replaceInstUsesWith(II, OpMultiplicand);
771   } else if (IsUnitDup(OpMultiplier)) {
772     // [f]mul pg (dup pg 1) %n => %n
773     auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
774     auto *DupPg = DupInst->getOperand(1);
775     // TODO: this is naive. The optimization is still valid if DupPg
776     // 'encompasses' OpPredicate, not only if they're the same predicate.
777     if (OpPredicate == DupPg) {
778       OpMultiplicand->takeName(&II);
779       return IC.replaceInstUsesWith(II, OpMultiplicand);
780     }
781   }
782 
783   return None;
784 }
785 
786 static Optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
787                                                     IntrinsicInst &II) {
788   IRBuilder<> Builder(II.getContext());
789   Builder.SetInsertPoint(&II);
790   Value *UnpackArg = II.getArgOperand(0);
791   auto *RetTy = cast<ScalableVectorType>(II.getType());
792   bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi ||
793                   II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo;
794 
795   // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X))
796   // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X))
797   if (auto *ScalarArg = getSplatValue(UnpackArg)) {
798     ScalarArg =
799         Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned);
800     Value *NewVal =
801         Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg);
802     NewVal->takeName(&II);
803     return IC.replaceInstUsesWith(II, NewVal);
804   }
805 
806   return None;
807 }
808 static Optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
809                                                  IntrinsicInst &II) {
810   auto *OpVal = II.getOperand(0);
811   auto *OpIndices = II.getOperand(1);
812   VectorType *VTy = cast<VectorType>(II.getType());
813 
814   // Check whether OpIndices is an aarch64_sve_dup_x intrinsic call with
815   // constant splat value < minimal element count of result.
816   auto *DupXIntrI = dyn_cast<IntrinsicInst>(OpIndices);
817   if (!DupXIntrI || DupXIntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
818     return None;
819 
820   auto *SplatValue = dyn_cast<ConstantInt>(DupXIntrI->getOperand(0));
821   if (!SplatValue ||
822       SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
823     return None;
824 
825   // Convert sve_tbl(OpVal sve_dup_x(SplatValue)) to
826   // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
827   IRBuilder<> Builder(II.getContext());
828   Builder.SetInsertPoint(&II);
829   auto *Extract = Builder.CreateExtractElement(OpVal, SplatValue);
830   auto *VectorSplat =
831       Builder.CreateVectorSplat(VTy->getElementCount(), Extract);
832 
833   VectorSplat->takeName(&II);
834   return IC.replaceInstUsesWith(II, VectorSplat);
835 }
836 
837 Optional<Instruction *>
838 AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
839                                      IntrinsicInst &II) const {
840   Intrinsic::ID IID = II.getIntrinsicID();
841   switch (IID) {
842   default:
843     break;
844   case Intrinsic::aarch64_sve_convert_from_svbool:
845     return instCombineConvertFromSVBool(IC, II);
846   case Intrinsic::aarch64_sve_dup:
847     return instCombineSVEDup(IC, II);
848   case Intrinsic::aarch64_sve_cmpne:
849   case Intrinsic::aarch64_sve_cmpne_wide:
850     return instCombineSVECmpNE(IC, II);
851   case Intrinsic::aarch64_sve_rdffr:
852     return instCombineRDFFR(IC, II);
853   case Intrinsic::aarch64_sve_lasta:
854   case Intrinsic::aarch64_sve_lastb:
855     return instCombineSVELast(IC, II);
856   case Intrinsic::aarch64_sve_cntd:
857     return instCombineSVECntElts(IC, II, 2);
858   case Intrinsic::aarch64_sve_cntw:
859     return instCombineSVECntElts(IC, II, 4);
860   case Intrinsic::aarch64_sve_cnth:
861     return instCombineSVECntElts(IC, II, 8);
862   case Intrinsic::aarch64_sve_cntb:
863     return instCombineSVECntElts(IC, II, 16);
864   case Intrinsic::aarch64_sve_ptest_any:
865   case Intrinsic::aarch64_sve_ptest_first:
866   case Intrinsic::aarch64_sve_ptest_last:
867     return instCombineSVEPTest(IC, II);
868   case Intrinsic::aarch64_sve_mul:
869   case Intrinsic::aarch64_sve_fmul:
870     return instCombineSVEVectorMul(IC, II);
871   case Intrinsic::aarch64_sve_tbl:
872     return instCombineSVETBL(IC, II);
873   case Intrinsic::aarch64_sve_uunpkhi:
874   case Intrinsic::aarch64_sve_uunpklo:
875   case Intrinsic::aarch64_sve_sunpkhi:
876   case Intrinsic::aarch64_sve_sunpklo:
877     return instCombineSVEUnpack(IC, II);
878   }
879 
880   return None;
881 }
882 
883 bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
884                                            ArrayRef<const Value *> Args) {
885 
886   // A helper that returns a vector type from the given type. The number of
887   // elements in type Ty determine the vector width.
888   auto toVectorTy = [&](Type *ArgTy) {
889     return VectorType::get(ArgTy->getScalarType(),
890                            cast<VectorType>(DstTy)->getElementCount());
891   };
892 
893   // Exit early if DstTy is not a vector type whose elements are at least
894   // 16-bits wide.
895   if (!DstTy->isVectorTy() || DstTy->getScalarSizeInBits() < 16)
896     return false;
897 
898   // Determine if the operation has a widening variant. We consider both the
899   // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
900   // instructions.
901   //
902   // TODO: Add additional widening operations (e.g., mul, shl, etc.) once we
903   //       verify that their extending operands are eliminated during code
904   //       generation.
905   switch (Opcode) {
906   case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
907   case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
908     break;
909   default:
910     return false;
911   }
912 
913   // To be a widening instruction (either the "wide" or "long" versions), the
914   // second operand must be a sign- or zero extend having a single user. We
915   // only consider extends having a single user because they may otherwise not
916   // be eliminated.
917   if (Args.size() != 2 ||
918       (!isa<SExtInst>(Args[1]) && !isa<ZExtInst>(Args[1])) ||
919       !Args[1]->hasOneUse())
920     return false;
921   auto *Extend = cast<CastInst>(Args[1]);
922 
923   // Legalize the destination type and ensure it can be used in a widening
924   // operation.
925   auto DstTyL = TLI->getTypeLegalizationCost(DL, DstTy);
926   unsigned DstElTySize = DstTyL.second.getScalarSizeInBits();
927   if (!DstTyL.second.isVector() || DstElTySize != DstTy->getScalarSizeInBits())
928     return false;
929 
930   // Legalize the source type and ensure it can be used in a widening
931   // operation.
932   auto *SrcTy = toVectorTy(Extend->getSrcTy());
933   auto SrcTyL = TLI->getTypeLegalizationCost(DL, SrcTy);
934   unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
935   if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
936     return false;
937 
938   // Get the total number of vector elements in the legalized types.
939   InstructionCost NumDstEls =
940       DstTyL.first * DstTyL.second.getVectorMinNumElements();
941   InstructionCost NumSrcEls =
942       SrcTyL.first * SrcTyL.second.getVectorMinNumElements();
943 
944   // Return true if the legalized types have the same number of vector elements
945   // and the destination element type size is twice that of the source type.
946   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize;
947 }
948 
// Return the cost of casting from Src to Dst. A cast whose single user is a
// widening add/sub (see isWideningInstruction) may be free; otherwise the
// cost is looked up in a target conversion table before deferring to the
// generic implementation.
InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
                                                 Type *Src,
                                                 TTI::CastContextHint CCH,
                                                 TTI::TargetCostKind CostKind,
                                                 const Instruction *I) {
  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  assert(ISD && "Invalid opcode");

  // If the cast is observable, and it is used by a widening instruction (e.g.,
  // uaddl, saddw, etc.), it may be free.
  if (I && I->hasOneUse()) {
    auto *SingleUser = cast<Instruction>(*I->user_begin());
    SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
    if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands)) {
      // If the cast is the second operand, it is free. We will generate either
      // a "wide" or "long" version of the widening instruction.
      if (I == SingleUser->getOperand(1))
        return 0;
      // If the cast is not the second operand, it will be free if it looks the
      // same as the second operand. In this case, we will generate a "long"
      // version of the widening instruction.
      if (auto *Cast = dyn_cast<CastInst>(SingleUser->getOperand(1)))
        if (I->getOpcode() == unsigned(Cast->getOpcode()) &&
            cast<CastInst>(I)->getSrcTy() == Cast->getSrcTy())
          return 0;
    }
  }

  // TODO: Allow non-throughput costs that aren't binary.
  // For non-throughput cost kinds, collapse any non-zero table cost to 1.
  auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
    if (CostKind != TTI::TCK_RecipThroughput)
      return Cost == 0 ? 0 : 1;
    return Cost;
  };

  EVT SrcTy = TLI->getValueType(DL, Src);
  EVT DstTy = TLI->getValueType(DL, Dst);

  // The conversion table below is indexed by simple value types only.
  if (!SrcTy.isSimple() || !DstTy.isSimple())
    return AdjustCost(
        BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));

  // Conversion costs, keyed by (ISD opcode, destination VT, source VT).
  static const TypeConversionCostTblEntry
  ConversionTbl[] = {
    { ISD::TRUNCATE, MVT::v4i16, MVT::v4i32,  1 },
    { ISD::TRUNCATE, MVT::v4i32, MVT::v4i64,  0 },
    { ISD::TRUNCATE, MVT::v8i8,  MVT::v8i32,  3 },
    { ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 6 },

    // Truncations on nxvmiN
    { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 1 },
    { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 1 },
    { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 1 },
    { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 1 },
    { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 1 },
    { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 2 },
    { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 1 },
    { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 3 },
    { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 5 },
    { ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 1 },
    { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 1 },
    { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 1 },
    { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 1 },
    { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 2 },
    { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 3 },
    { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 6 },

    // The number of shll instructions for the extension.
    { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
    { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
    { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
    { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
    { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
    { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
    { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
    { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
    { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
    { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
    { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
    { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
    { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
    { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
    { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
    { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6 },

    // LowerVectorINT_TO_FP:
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },

    // Complex: to v2f32
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
    { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
    { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },

    // Complex: to v4f32
    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8,  4 },
    { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8,  3 },
    { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },

    // Complex: to v8f32
    { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
    { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
    { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
    { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },

    // Complex: to v16f32
    { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
    { ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },

    // Complex: to v2f64
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
    { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
    { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },


    // LowerVectorFP_TO_INT
    { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1 },
    { ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1 },
    { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1 },
    { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1 },
    { ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1 },
    { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1 },

    // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext).
    { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2 },
    { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1 },
    { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f32, 1 },
    { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2 },
    { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1 },
    { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f32, 1 },

    // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2
    { ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2 },
    { ISD::FP_TO_SINT, MVT::v4i8,  MVT::v4f32, 2 },
    { ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2 },
    { ISD::FP_TO_UINT, MVT::v4i8,  MVT::v4f32, 2 },

    // Complex, from nxv2f32.
    { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },

    // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
    { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2 },
    { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2 },
    { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f64, 2 },
    { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2 },
    { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2 },
    { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f64, 2 },

    // Complex, from nxv2f64.
    { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },

    // Complex, from nxv4f32.
    { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
    { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
    { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },

    // Complex, from nxv8f64. Illegal -> illegal conversions not required.
    { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
    { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
    { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
    { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },

    // Complex, from nxv4f64. Illegal -> illegal conversions not required.
    { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
    { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
    { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
    { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
    { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
    { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },

    // Complex, from nxv8f32. Illegal -> illegal conversions not required.
    { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
    { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
    { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
    { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },

    // Complex, from nxv8f16.
    { ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
    { ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
    { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
    { ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
    { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },

    // Complex, from nxv4f16.
    { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
    { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
    { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },

    // Complex, from nxv2f16.
    { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
    { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
    { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },

    // Truncate from nxvmf32 to nxvmf16.
    { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1 },
    { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1 },
    { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3 },

    // Truncate from nxvmf64 to nxvmf16.
    { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1 },
    { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3 },
    { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7 },

    // Truncate from nxvmf64 to nxvmf32.
    { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 },
    { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3 },
    { ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6 },

    // Extend from nxvmf16 to nxvmf32.
    { ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
    { ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
    { ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},

    // Extend from nxvmf16 to nxvmf64.
    { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
    { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
    { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},

    // Extend from nxvmf32 to nxvmf64.
    { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
    { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
    { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},

  };

  if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
                                                 DstTy.getSimpleVT(),
                                                 SrcTy.getSimpleVT()))
    return AdjustCost(Entry->Cost);

  // No table entry; fall back to the generic cost model.
  return AdjustCost(
      BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
}
1225 
1226 InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode,
1227                                                          Type *Dst,
1228                                                          VectorType *VecTy,
1229                                                          unsigned Index) {
1230 
1231   // Make sure we were given a valid extend opcode.
1232   assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
1233          "Invalid opcode");
1234 
1235   // We are extending an element we extract from a vector, so the source type
1236   // of the extend is the element type of the vector.
1237   auto *Src = VecTy->getElementType();
1238 
1239   // Sign- and zero-extends are for integer types only.
1240   assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
1241 
1242   // Get the cost for the extract. We compute the cost (if any) for the extend
1243   // below.
1244   InstructionCost Cost =
1245       getVectorInstrCost(Instruction::ExtractElement, VecTy, Index);
1246 
1247   // Legalize the types.
1248   auto VecLT = TLI->getTypeLegalizationCost(DL, VecTy);
1249   auto DstVT = TLI->getValueType(DL, Dst);
1250   auto SrcVT = TLI->getValueType(DL, Src);
1251   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1252 
1253   // If the resulting type is still a vector and the destination type is legal,
1254   // we may get the extension for free. If not, get the default cost for the
1255   // extend.
1256   if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
1257     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1258                                    CostKind);
1259 
1260   // The destination type should be larger than the element type. If not, get
1261   // the default cost for the extend.
1262   if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits())
1263     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1264                                    CostKind);
1265 
1266   switch (Opcode) {
1267   default:
1268     llvm_unreachable("Opcode should be either SExt or ZExt");
1269 
1270   // For sign-extends, we only need a smov, which performs the extension
1271   // automatically.
1272   case Instruction::SExt:
1273     return Cost;
1274 
1275   // For zero-extends, the extend is performed automatically by a umov unless
1276   // the destination type is i64 and the element type is i8 or i16.
1277   case Instruction::ZExt:
1278     if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
1279       return Cost;
1280   }
1281 
1282   // If we are unable to perform the extend for free, get the default cost.
1283   return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
1284                                  CostKind);
1285 }
1286 
1287 InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
1288                                                TTI::TargetCostKind CostKind,
1289                                                const Instruction *I) {
1290   if (CostKind != TTI::TCK_RecipThroughput)
1291     return Opcode == Instruction::PHI ? 0 : 1;
1292   assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
1293   // Branches are assumed to be predicted.
1294   return 0;
1295 }
1296 
1297 InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
1298                                                    unsigned Index) {
1299   assert(Val->isVectorTy() && "This must be a vector type");
1300 
1301   if (Index != -1U) {
1302     // Legalize the type.
1303     std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Val);
1304 
1305     // This type is legalized to a scalar type.
1306     if (!LT.second.isVector())
1307       return 0;
1308 
1309     // The type may be split. Normalize the index to the new type.
1310     unsigned Width = LT.second.getVectorNumElements();
1311     Index = Index % Width;
1312 
1313     // The element at index zero is already inside the vector.
1314     if (Index == 0)
1315       return 0;
1316   }
1317 
1318   // All other insert/extracts cost this much.
1319   return ST->getVectorInsertExtractBaseCost();
1320 }
1321 
// Return the throughput cost of an arithmetic instruction on type Ty.
// Handles AArch64 specifics: widening add/sub overhead, division by
// constants, scalarized vector division, the missing v2i64 multiply, and
// opcodes marked 'custom' that are nonetheless cheap.
InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
    unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
    TTI::OperandValueKind Opd1Info, TTI::OperandValueKind Opd2Info,
    TTI::OperandValueProperties Opd1PropInfo,
    TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
    const Instruction *CxtI) {
  // TODO: Handle more cost kinds.
  if (CostKind != TTI::TCK_RecipThroughput)
    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
                                         Opd2Info, Opd1PropInfo,
                                         Opd2PropInfo, Args, CxtI);

  // Legalize the type.
  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);

  // If the instruction is a widening instruction (e.g., uaddl, saddw, etc.),
  // add in the widening overhead specified by the sub-target. Since the
  // extends feeding widening instructions are performed automatically, they
  // aren't present in the generated code and have a zero cost. By adding a
  // widening overhead here, we attach the total cost of the combined operation
  // to the widening instruction.
  InstructionCost Cost = 0;
  if (isWideningInstruction(Ty, Opcode, Args))
    Cost += ST->getWideningBaseCost();

  int ISD = TLI->InstructionOpcodeToISD(Opcode);

  switch (ISD) {
  default:
    // Opcodes without special handling fall back to the generic model (plus
    // any widening overhead accumulated above).
    return Cost + BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
                                                Opd2Info,
                                                Opd1PropInfo, Opd2PropInfo);
  case ISD::SDIV:
    if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue &&
        Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) {
      // On AArch64, scalar signed division by constants power-of-two are
      // normally expanded to the sequence ADD + CMP + SELECT + SRA.
      // The OperandValue properties may not be same as that of previous
      // operation; conservatively assume OP_None.
      Cost += getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
                                     Opd1Info, Opd2Info,
                                     TargetTransformInfo::OP_None,
                                     TargetTransformInfo::OP_None);
      Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
                                     Opd1Info, Opd2Info,
                                     TargetTransformInfo::OP_None,
                                     TargetTransformInfo::OP_None);
      Cost += getArithmeticInstrCost(Instruction::Select, Ty, CostKind,
                                     Opd1Info, Opd2Info,
                                     TargetTransformInfo::OP_None,
                                     TargetTransformInfo::OP_None);
      Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
                                     Opd1Info, Opd2Info,
                                     TargetTransformInfo::OP_None,
                                     TargetTransformInfo::OP_None);
      return Cost;
    }
    LLVM_FALLTHROUGH;
  case ISD::UDIV:
    if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue) {
      auto VT = TLI->getValueType(DL, Ty);
      if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
        // Vector signed division by constant are expanded to the
        // sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division
        // to MULHS + SUB + SRL + ADD + SRL.
        InstructionCost MulCost = getArithmeticInstrCost(
            Instruction::Mul, Ty, CostKind, Opd1Info, Opd2Info,
            TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
        InstructionCost AddCost = getArithmeticInstrCost(
            Instruction::Add, Ty, CostKind, Opd1Info, Opd2Info,
            TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
        InstructionCost ShrCost = getArithmeticInstrCost(
            Instruction::AShr, Ty, CostKind, Opd1Info, Opd2Info,
            TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
        return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
      }
    }

    // Division by a non-constant: start from the generic cost.
    Cost += BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
                                          Opd2Info,
                                          Opd1PropInfo, Opd2PropInfo);
    if (Ty->isVectorTy()) {
      // On AArch64, vector divisions are not supported natively and are
      // expanded into scalar divisions of each pair of elements.
      Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty, CostKind,
                                     Opd1Info, Opd2Info, Opd1PropInfo,
                                     Opd2PropInfo);
      Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind,
                                     Opd1Info, Opd2Info, Opd1PropInfo,
                                     Opd2PropInfo);
      // TODO: if one of the arguments is scalar, then it's not necessary to
      // double the cost of handling the vector elements.
      Cost += Cost;
    }
    return Cost;

  case ISD::MUL:
    if (LT.second != MVT::v2i64)
      return (Cost + 1) * LT.first;
    // Since we do not have a MUL.2d instruction, a mul <2 x i64> is expensive
    // as elements are extracted from the vectors and the muls scalarized.
    // As getScalarizationOverhead is a bit too pessimistic, we estimate the
    // cost for a i64 vector directly here, which is:
    // - four i64 extracts,
    // - two i64 inserts, and
    // - two muls.
    // So, for a v2i64 with LT.First = 1 the cost is 8, and for a v4i64 with
    // LT.first = 2 the cost is 16.
    return LT.first * 8;
  case ISD::ADD:
  case ISD::XOR:
  case ISD::OR:
  case ISD::AND:
    // These nodes are marked as 'custom' for combining purposes only.
    // We know that they are legal. See LowerAdd in ISelLowering.
    return (Cost + 1) * LT.first;

  case ISD::FADD:
    // These nodes are marked as 'custom' just to lower them to SVE.
    // We know said lowering will incur no additional cost.
    if (isa<FixedVectorType>(Ty) && !Ty->getScalarType()->isFP128Ty())
      return (Cost + 2) * LT.first;

    return Cost + BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
                                                Opd2Info,
                                                Opd1PropInfo, Opd2PropInfo);
  }
}
1450 
1451 InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
1452                                                           ScalarEvolution *SE,
1453                                                           const SCEV *Ptr) {
1454   // Address computations in vectorized code with non-consecutive addresses will
1455   // likely result in more instructions compared to scalar code where the
1456   // computation can more often be merged into the index mode. The resulting
1457   // extra micro-ops can significantly decrease throughput.
1458   unsigned NumVectorInstToHideOverhead = 10;
1459   int MaxMergeDistance = 64;
1460 
1461   if (Ty->isVectorTy() && SE &&
1462       !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
1463     return NumVectorInstToHideOverhead;
1464 
1465   // In many cases the address computation is not merged into the instruction
1466   // addressing mode.
1467   return 1;
1468 }
1469 
/// Cost model for compare and select instructions. Only the
/// reciprocal-throughput cost kind is modeled here; all other cost kinds
/// fall back to the target-independent implementation.
InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
                                                   Type *CondTy,
                                                   CmpInst::Predicate VecPred,
                                                   TTI::TargetCostKind CostKind,
                                                   const Instruction *I) {
  // TODO: Handle other cost kinds.
  if (CostKind != TTI::TCK_RecipThroughput)
    return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
                                     I);

  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  // We don't lower some vector selects well that are wider than the register
  // width.
  if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
    // We would need this many instructions to hide the scalarization happening.
    const int AmortizationCost = 20;

    // If VecPred is not set, check if we can get a predicate from the context
    // instruction, if its type matches the requested ValTy.
    if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
      CmpInst::Predicate CurrentPred;
      if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
                            m_Value())))
        VecPred = CurrentPred;
    }
    // Check if we have a compare/select chain that can be lowered using CMxx &
    // BFI pair.
    if (CmpInst::isIntPredicate(VecPred)) {
      // Legal NEON types for which the compare/select pair costs a single
      // legalization step.
      static const auto ValidMinMaxTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
                                          MVT::v8i16, MVT::v2i32, MVT::v4i32,
                                          MVT::v2i64};
      auto LT = TLI->getTypeLegalizationCost(DL, ValTy);
      if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
        return LT.first;
    }

    // Fixed costs for selects that legalize poorly; the AmortizationCost
    // entries model cases expensive enough that vectorization should only
    // proceed when plenty of other work is vectorized alongside.
    static const TypeConversionCostTblEntry
    VectorSelectTbl[] = {
      { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
      { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
      { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
      { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
      { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
      { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
    };

    EVT SelCondTy = TLI->getValueType(DL, CondTy);
    EVT SelValTy = TLI->getValueType(DL, ValTy);
    if (SelCondTy.isSimple() && SelValTy.isSimple()) {
      if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
                                                     SelCondTy.getSimpleVT(),
                                                     SelValTy.getSimpleVT()))
        return Entry->Cost;
    }
  }
  // The base case handles scalable vectors fine for now, since it treats the
  // cost as 1 * legalization cost.
  return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
}
1529 
1530 AArch64TTIImpl::TTI::MemCmpExpansionOptions
1531 AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
1532   TTI::MemCmpExpansionOptions Options;
1533   if (ST->requiresStrictAlign()) {
1534     // TODO: Add cost modeling for strict align. Misaligned loads expand to
1535     // a bunch of instructions when strict align is enabled.
1536     return Options;
1537   }
1538   Options.AllowOverlappingLoads = true;
1539   Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
1540   Options.NumLoadsPerBlock = Options.MaxNumLoads;
1541   // TODO: Though vector loads usually perform well on AArch64, in some targets
1542   // they may wake up the FP unit, which raises the power consumption.  Perhaps
1543   // they could be used with no holds barred (-O3).
1544   Options.LoadSizes = {8, 4, 2, 1};
1545   return Options;
1546 }
1547 
1548 InstructionCost
1549 AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
1550                                       Align Alignment, unsigned AddressSpace,
1551                                       TTI::TargetCostKind CostKind) {
1552   if (useNeonVector(Src))
1553     return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
1554                                         CostKind);
1555   auto LT = TLI->getTypeLegalizationCost(DL, Src);
1556   if (!LT.first.isValid())
1557     return InstructionCost::getInvalid();
1558 
1559   // The code-generator is currently not able to handle scalable vectors
1560   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
1561   // it. This change will be removed when code-generation for these types is
1562   // sufficiently reliable.
1563   if (cast<VectorType>(Src)->getElementCount() == ElementCount::getScalable(1))
1564     return InstructionCost::getInvalid();
1565 
1566   return LT.first * 2;
1567 }
1568 
1569 InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
1570     unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
1571     Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
1572   if (useNeonVector(DataTy))
1573     return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
1574                                          Alignment, CostKind, I);
1575   auto *VT = cast<VectorType>(DataTy);
1576   auto LT = TLI->getTypeLegalizationCost(DL, DataTy);
1577   if (!LT.first.isValid())
1578     return InstructionCost::getInvalid();
1579 
1580   // The code-generator is currently not able to handle scalable vectors
1581   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
1582   // it. This change will be removed when code-generation for these types is
1583   // sufficiently reliable.
1584   if (cast<VectorType>(DataTy)->getElementCount() ==
1585       ElementCount::getScalable(1))
1586     return InstructionCost::getInvalid();
1587 
1588   ElementCount LegalVF = LT.second.getVectorElementCount();
1589   InstructionCost MemOpCost =
1590       getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind, I);
1591   return LT.first * MemOpCost * getMaxNumElements(LegalVF, I->getFunction());
1592 }
1593 
1594 bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
1595   return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
1596 }
1597 
/// Cost model for plain loads and stores. Special-cases slow misaligned
/// 128-bit stores and NEON truncating-store / extending-load patterns that
/// must scalarize.
InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
                                                MaybeAlign Alignment,
                                                unsigned AddressSpace,
                                                TTI::TargetCostKind CostKind,
                                                const Instruction *I) {
  EVT VT = TLI->getValueType(DL, Ty, true);
  // Type legalization can't handle structs
  if (VT == MVT::Other)
    return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
                                  CostKind);

  auto LT = TLI->getTypeLegalizationCost(DL, Ty);
  if (!LT.first.isValid())
    return InstructionCost::getInvalid();

  // The code-generator is currently not able to handle scalable vectors
  // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
  // it. This change will be removed when code-generation for these types is
  // sufficiently reliable.
  if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
    if (VTy->getElementCount() == ElementCount::getScalable(1))
      return InstructionCost::getInvalid();

  // TODO: consider latency as well for TCK_SizeAndLatency.
  if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
    return LT.first;

  // Any remaining non-throughput cost kind gets a flat cost of 1.
  if (CostKind != TTI::TCK_RecipThroughput)
    return 1;

  if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
      LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) {
    // Unaligned stores are extremely inefficient. We don't split all
    // unaligned 128-bit stores because the negative impact that has shown in
    // practice on inlined block copy code.
    // We make such stores expensive so that we will only vectorize if there
    // are 6 other instructions getting vectorized.
    const int AmortizationCost = 6;

    return LT.first * 2 * AmortizationCost;
  }

  // Check truncating stores and extending loads.
  if (useNeonVector(Ty) &&
      Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
    // v4i8 types are lowered to scalar a load/store and sshll/xtn.
    if (VT == MVT::v4i8)
      return 2;
    // Otherwise we need to scalarize.
    return cast<FixedVectorType>(Ty)->getNumElements() * 2;
  }

  // Default: one unit per legal part the type splits into.
  return LT.first;
}
1652 
1653 InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
1654     unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
1655     Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
1656     bool UseMaskForCond, bool UseMaskForGaps) {
1657   assert(Factor >= 2 && "Invalid interleave factor");
1658   auto *VecVTy = cast<FixedVectorType>(VecTy);
1659 
1660   if (!UseMaskForCond && !UseMaskForGaps &&
1661       Factor <= TLI->getMaxSupportedInterleaveFactor()) {
1662     unsigned NumElts = VecVTy->getNumElements();
1663     auto *SubVecTy =
1664         FixedVectorType::get(VecTy->getScalarType(), NumElts / Factor);
1665 
1666     // ldN/stN only support legal vector types of size 64 or 128 in bits.
1667     // Accesses having vector types that are a multiple of 128 bits can be
1668     // matched to more than one ldN/stN instruction.
1669     if (NumElts % Factor == 0 &&
1670         TLI->isLegalInterleavedAccessType(SubVecTy, DL))
1671       return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL);
1672   }
1673 
1674   return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
1675                                            Alignment, AddressSpace, CostKind,
1676                                            UseMaskForCond, UseMaskForGaps);
1677 }
1678 
1679 InstructionCost
1680 AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
1681   InstructionCost Cost = 0;
1682   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1683   for (auto *I : Tys) {
1684     if (!I->isVectorTy())
1685       continue;
1686     if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() ==
1687         128)
1688       Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) +
1689               getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind);
1690   }
1691   return Cost;
1692 }
1693 
// The maximum interleave factor comes straight from the subtarget; the
// vectorization factor does not affect it on AArch64.
unsigned AArch64TTIImpl::getMaxInterleaveFactor(unsigned VF) {
  return ST->getMaxInterleaveFactor();
}
1697 
1698 // For Falkor, we want to avoid having too many strided loads in a loop since
1699 // that can exhaust the HW prefetcher resources.  We adjust the unroller
1700 // MaxCount preference below to attempt to ensure unrolling doesn't create too
1701 // many strided loads.
1702 static void
1703 getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
1704                               TargetTransformInfo::UnrollingPreferences &UP) {
1705   enum { MaxStridedLoads = 7 };
1706   auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) {
1707     int StridedLoads = 0;
1708     // FIXME? We could make this more precise by looking at the CFG and
1709     // e.g. not counting loads in each side of an if-then-else diamond.
1710     for (const auto BB : L->blocks()) {
1711       for (auto &I : *BB) {
1712         LoadInst *LMemI = dyn_cast<LoadInst>(&I);
1713         if (!LMemI)
1714           continue;
1715 
1716         Value *PtrValue = LMemI->getPointerOperand();
1717         if (L->isLoopInvariant(PtrValue))
1718           continue;
1719 
1720         const SCEV *LSCEV = SE.getSCEV(PtrValue);
1721         const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
1722         if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
1723           continue;
1724 
1725         // FIXME? We could take pairing of unrolled load copies into account
1726         // by looking at the AddRec, but we would probably have to limit this
1727         // to loops with no stores or other memory optimization barriers.
1728         ++StridedLoads;
1729         // We've seen enough strided loads that seeing more won't make a
1730         // difference.
1731         if (StridedLoads > MaxStridedLoads / 2)
1732           return StridedLoads;
1733       }
1734     }
1735     return StridedLoads;
1736   };
1737 
1738   int StridedLoads = countStridedLoads(L, SE);
1739   LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads
1740                     << " strided loads\n");
1741   // Pick the largest power of 2 unroll count that won't result in too many
1742   // strided loads.
1743   if (StridedLoads) {
1744     UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads);
1745     LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to "
1746                       << UP.MaxCount << '\n');
1747   }
1748 }
1749 
/// Set AArch64-specific loop-unrolling preferences on top of the
/// target-independent defaults.
void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
                                             TTI::UnrollingPreferences &UP,
                                             OptimizationRemarkEmitter *ORE) {
  // Enable partial unrolling and runtime unrolling.
  BaseT::getUnrollingPreferences(L, SE, UP, ORE);

  // Allow unrolling based on the known upper bound of the trip count.
  UP.UpperBound = true;

  // For inner loop, it is more likely to be a hot one, and the runtime check
  // can be promoted out from LICM pass, so the overhead is less, let's try
  // a larger threshold to unroll more loops.
  if (L->getLoopDepth() > 1)
    UP.PartialThreshold *= 2;

  // Disable partial & runtime unrolling on -Os.
  UP.PartialOptSizeThreshold = 0;

  // Falkor needs its MaxCount capped to protect HW prefetcher resources.
  if (ST->getProcFamily() == AArch64Subtarget::Falkor &&
      EnableFalkorHWPFUnrollFix)
    getFalkorUnrollingPreferences(L, SE, UP);

  // Scan the loop: don't unroll loops with calls as this could prevent
  // inlining. Don't unroll vector loops either, as they don't benefit much from
  // unrolling.
  for (auto *BB : L->getBlocks()) {
    for (auto &I : *BB) {
      // Don't unroll vectorised loop.
      if (I.getType()->isVectorTy())
        return;

      if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
        // Calls that are not lowered to real calls (e.g. some intrinsics) are
        // fine to unroll across; anything else blocks unrolling.
        if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
          if (!isLoweredToCall(F))
            continue;
        }
        return;
      }
    }
  }

  // Enable runtime unrolling for in-order models
  // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so by
  // checking for that case, we can ensure that the default behaviour is
  // unchanged
  if (ST->getProcFamily() != AArch64Subtarget::Others &&
      !ST->getSchedModel().isOutOfOrder()) {
    UP.Runtime = true;
    UP.Partial = true;
    UP.UnrollRemainder = true;
    UP.DefaultUnrollRuntimeCount = 4;

    UP.UnrollAndJam = true;
    UP.UnrollAndJamInnerLoopThreshold = 60;
  }
}
1805 
// No AArch64-specific peeling preferences; use the target-independent
// defaults unchanged.
void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
                                           TTI::PeelingPreferences &PP) {
  BaseT::getPeelingPreferences(L, SE, PP);
}
1810 
1811 Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
1812                                                          Type *ExpectedType) {
1813   switch (Inst->getIntrinsicID()) {
1814   default:
1815     return nullptr;
1816   case Intrinsic::aarch64_neon_st2:
1817   case Intrinsic::aarch64_neon_st3:
1818   case Intrinsic::aarch64_neon_st4: {
1819     // Create a struct type
1820     StructType *ST = dyn_cast<StructType>(ExpectedType);
1821     if (!ST)
1822       return nullptr;
1823     unsigned NumElts = Inst->getNumArgOperands() - 1;
1824     if (ST->getNumElements() != NumElts)
1825       return nullptr;
1826     for (unsigned i = 0, e = NumElts; i != e; ++i) {
1827       if (Inst->getArgOperand(i)->getType() != ST->getElementType(i))
1828         return nullptr;
1829     }
1830     Value *Res = UndefValue::get(ExpectedType);
1831     IRBuilder<> Builder(Inst);
1832     for (unsigned i = 0, e = NumElts; i != e; ++i) {
1833       Value *L = Inst->getArgOperand(i);
1834       Res = Builder.CreateInsertValue(Res, L, i);
1835     }
1836     return Res;
1837   }
1838   case Intrinsic::aarch64_neon_ld2:
1839   case Intrinsic::aarch64_neon_ld3:
1840   case Intrinsic::aarch64_neon_ld4:
1841     if (Inst->getType() == ExpectedType)
1842       return Inst;
1843     return nullptr;
1844   }
1845 }
1846 
1847 bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
1848                                         MemIntrinsicInfo &Info) {
1849   switch (Inst->getIntrinsicID()) {
1850   default:
1851     break;
1852   case Intrinsic::aarch64_neon_ld2:
1853   case Intrinsic::aarch64_neon_ld3:
1854   case Intrinsic::aarch64_neon_ld4:
1855     Info.ReadMem = true;
1856     Info.WriteMem = false;
1857     Info.PtrVal = Inst->getArgOperand(0);
1858     break;
1859   case Intrinsic::aarch64_neon_st2:
1860   case Intrinsic::aarch64_neon_st3:
1861   case Intrinsic::aarch64_neon_st4:
1862     Info.ReadMem = false;
1863     Info.WriteMem = true;
1864     Info.PtrVal = Inst->getArgOperand(Inst->getNumArgOperands() - 1);
1865     break;
1866   }
1867 
1868   switch (Inst->getIntrinsicID()) {
1869   default:
1870     return false;
1871   case Intrinsic::aarch64_neon_ld2:
1872   case Intrinsic::aarch64_neon_st2:
1873     Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
1874     break;
1875   case Intrinsic::aarch64_neon_ld3:
1876   case Intrinsic::aarch64_neon_st3:
1877     Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
1878     break;
1879   case Intrinsic::aarch64_neon_ld4:
1880   case Intrinsic::aarch64_neon_st4:
1881     Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
1882     break;
1883   }
1884   return true;
1885 }
1886 
1887 /// See if \p I should be considered for address type promotion. We check if \p
1888 /// I is a sext with right type and used in memory accesses. If it used in a
1889 /// "complex" getelementptr, we allow it to be promoted without finding other
1890 /// sext instructions that sign extended the same initial value. A getelementptr
1891 /// is considered as "complex" if it has more than 2 operands.
1892 bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
1893     const Instruction &I, bool &AllowPromotionWithoutCommonHeader) {
1894   bool Considerable = false;
1895   AllowPromotionWithoutCommonHeader = false;
1896   if (!isa<SExtInst>(&I))
1897     return false;
1898   Type *ConsideredSExtType =
1899       Type::getInt64Ty(I.getParent()->getParent()->getContext());
1900   if (I.getType() != ConsideredSExtType)
1901     return false;
1902   // See if the sext is the one with the right type and used in at least one
1903   // GetElementPtrInst.
1904   for (const User *U : I.users()) {
1905     if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
1906       Considerable = true;
1907       // A getelementptr is considered as "complex" if it has more than 2
1908       // operands. We will promote a SExt used in such complex GEP as we
1909       // expect some computation to be merged if they are done on 64 bits.
1910       if (GEPInst->getNumOperands() > 2) {
1911         AllowPromotionWithoutCommonHeader = true;
1912         break;
1913       }
1914     }
1915   }
1916   return Considerable;
1917 }
1918 
1919 bool AArch64TTIImpl::isLegalToVectorizeReduction(
1920     const RecurrenceDescriptor &RdxDesc, ElementCount VF) const {
1921   if (!VF.isScalable())
1922     return true;
1923 
1924   Type *Ty = RdxDesc.getRecurrenceType();
1925   if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty))
1926     return false;
1927 
1928   switch (RdxDesc.getRecurrenceKind()) {
1929   case RecurKind::Add:
1930   case RecurKind::FAdd:
1931   case RecurKind::And:
1932   case RecurKind::Or:
1933   case RecurKind::Xor:
1934   case RecurKind::SMin:
1935   case RecurKind::SMax:
1936   case RecurKind::UMin:
1937   case RecurKind::UMax:
1938   case RecurKind::FMin:
1939   case RecurKind::FMax:
1940     return true;
1941   default:
1942     return false;
1943   }
1944 }
1945 
1946 InstructionCost
1947 AArch64TTIImpl::getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
1948                                        bool IsUnsigned,
1949                                        TTI::TargetCostKind CostKind) {
1950   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
1951 
1952   if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
1953     return BaseT::getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
1954 
1955   assert((isa<ScalableVectorType>(Ty) == isa<ScalableVectorType>(CondTy)) &&
1956          "Both vector needs to be equally scalable");
1957 
1958   InstructionCost LegalizationCost = 0;
1959   if (LT.first > 1) {
1960     Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
1961     unsigned MinMaxOpcode =
1962         Ty->isFPOrFPVectorTy()
1963             ? Intrinsic::maxnum
1964             : (IsUnsigned ? Intrinsic::umin : Intrinsic::smin);
1965     IntrinsicCostAttributes Attrs(MinMaxOpcode, LegalVTy, {LegalVTy, LegalVTy});
1966     LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1);
1967   }
1968 
1969   return LegalizationCost + /*Cost of horizontal reduction*/ 2;
1970 }
1971 
1972 InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
1973     unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) {
1974   std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
1975   InstructionCost LegalizationCost = 0;
1976   if (LT.first > 1) {
1977     Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
1978     LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
1979     LegalizationCost *= LT.first - 1;
1980   }
1981 
1982   int ISD = TLI->InstructionOpcodeToISD(Opcode);
1983   assert(ISD && "Invalid opcode");
1984   // Add the final reduction cost for the legal horizontal reduction
1985   switch (ISD) {
1986   case ISD::ADD:
1987   case ISD::AND:
1988   case ISD::OR:
1989   case ISD::XOR:
1990   case ISD::FADD:
1991     return LegalizationCost + 2;
1992   default:
1993     return InstructionCost::getInvalid();
1994   }
1995 }
1996 
/// Cost model for vector arithmetic reductions. Ordered (strict FP)
/// reductions, scalable-vector reductions, and table-driven fixed-width
/// reductions are each handled separately.
InstructionCost
AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
                                           Optional<FastMathFlags> FMF,
                                           TTI::TargetCostKind CostKind) {
  if (TTI::requiresOrderedReduction(FMF)) {
    if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) {
      InstructionCost BaseCost =
          BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
      // Add on extra cost to reflect the extra overhead on some CPUs. We still
      // end up vectorizing for more computationally intensive loops.
      return BaseCost + FixedVTy->getNumElements();
    }

    // Scalable ordered reductions are only supported for FAdd.
    if (Opcode != Instruction::FAdd)
      return InstructionCost::getInvalid();

    // Model a strict scalable FAdd reduction as one scalar op per element of
    // the largest possible vector.
    auto *VTy = cast<ScalableVectorType>(ValTy);
    InstructionCost Cost =
        getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind);
    Cost *= getMaxNumElements(VTy->getElementCount());
    return Cost;
  }

  if (isa<ScalableVectorType>(ValTy))
    return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind);

  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
  MVT MTy = LT.second;
  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  assert(ISD && "Invalid opcode");

  // Horizontal adds can use the 'addv' instruction. We model the cost of these
  // instructions as twice a normal vector add, plus 1 for each legalization
  // step (LT.first). This is the only arithmetic vector reduction operation for
  // which we have an instruction.
  // OR, XOR and AND costs should match the codegen from:
  // OR: llvm/test/CodeGen/AArch64/reduce-or.ll
  // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll
  // AND: llvm/test/CodeGen/AArch64/reduce-and.ll
  static const CostTblEntry CostTblNoPairwise[]{
      {ISD::ADD, MVT::v8i8,   2},
      {ISD::ADD, MVT::v16i8,  2},
      {ISD::ADD, MVT::v4i16,  2},
      {ISD::ADD, MVT::v8i16,  2},
      {ISD::ADD, MVT::v4i32,  2},
      {ISD::OR,  MVT::v8i8,  15},
      {ISD::OR,  MVT::v16i8, 17},
      {ISD::OR,  MVT::v4i16,  7},
      {ISD::OR,  MVT::v8i16,  9},
      {ISD::OR,  MVT::v2i32,  3},
      {ISD::OR,  MVT::v4i32,  5},
      {ISD::OR,  MVT::v2i64,  3},
      {ISD::XOR, MVT::v8i8,  15},
      {ISD::XOR, MVT::v16i8, 17},
      {ISD::XOR, MVT::v4i16,  7},
      {ISD::XOR, MVT::v8i16,  9},
      {ISD::XOR, MVT::v2i32,  3},
      {ISD::XOR, MVT::v4i32,  5},
      {ISD::XOR, MVT::v2i64,  3},
      {ISD::AND, MVT::v8i8,  15},
      {ISD::AND, MVT::v16i8, 17},
      {ISD::AND, MVT::v4i16,  7},
      {ISD::AND, MVT::v8i16,  9},
      {ISD::AND, MVT::v2i32,  3},
      {ISD::AND, MVT::v4i32,  5},
      {ISD::AND, MVT::v2i64,  3},
  };
  switch (ISD) {
  default:
    break;
  case ISD::ADD:
    if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
      return (LT.first - 1) + Entry->Cost;
    break;
  case ISD::XOR:
  case ISD::AND:
  case ISD::OR:
    const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy);
    if (!Entry)
      break;
    // The table costs only apply to non-i1, power-of-two element counts that
    // did not grow during legalization.
    auto *ValVTy = cast<FixedVectorType>(ValTy);
    if (!ValVTy->getElementType()->isIntegerTy(1) &&
        MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
        isPowerOf2_32(ValVTy->getNumElements())) {
      InstructionCost ExtraCost = 0;
      if (LT.first != 1) {
        // Type needs to be split, so there is an extra cost of LT.first - 1
        // arithmetic ops.
        auto *Ty = FixedVectorType::get(ValTy->getElementType(),
                                        MTy.getVectorNumElements());
        ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
        ExtraCost *= LT.first - 1;
      }
      return Entry->Cost + ExtraCost;
    }
    break;
  }
  return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
}
2096 
/// Cost of a vector splice shuffle. Predicate (i1) splices are costed on
/// their promoted type; a negative \p Index additionally requires a
/// compare/select pair to build the splice mask.
InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
  // Scalable types for which splice lowers to a single instruction.
  static const CostTblEntry ShuffleTbl[] = {
      { TTI::SK_Splice, MVT::nxv16i8,  1 },
      { TTI::SK_Splice, MVT::nxv8i16,  1 },
      { TTI::SK_Splice, MVT::nxv4i32,  1 },
      { TTI::SK_Splice, MVT::nxv2i64,  1 },
      { TTI::SK_Splice, MVT::nxv2f16,  1 },
      { TTI::SK_Splice, MVT::nxv4f16,  1 },
      { TTI::SK_Splice, MVT::nxv8f16,  1 },
      { TTI::SK_Splice, MVT::nxv2bf16, 1 },
      { TTI::SK_Splice, MVT::nxv4bf16, 1 },
      { TTI::SK_Splice, MVT::nxv8bf16, 1 },
      { TTI::SK_Splice, MVT::nxv2f32,  1 },
      { TTI::SK_Splice, MVT::nxv4f32,  1 },
      { TTI::SK_Splice, MVT::nxv2f64,  1 },
  };

  std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
  Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext());
  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  EVT PromotedVT = LT.second.getScalarType() == MVT::i1
                       ? TLI->getPromotedVTForPredicate(EVT(LT.second))
                       : LT.second;
  Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext());
  InstructionCost LegalizationCost = 0;
  // A negative index needs a compare + select to form the splice mask.
  if (Index < 0) {
    LegalizationCost =
        getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy,
                           CmpInst::BAD_ICMP_PREDICATE, CostKind) +
        getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy,
                           CmpInst::BAD_ICMP_PREDICATE, CostKind);
  }

  // Predicated splice are promoted when lowering. See AArch64ISelLowering.cpp
  // Cost performed on a promoted type.
  if (LT.second.getScalarType() == MVT::i1) {
    LegalizationCost +=
        getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy,
                         TTI::CastContextHint::None, CostKind) +
        getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy,
                         TTI::CastContextHint::None, CostKind);
  }
  const auto *Entry =
      CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT());
  assert(Entry && "Illegal Type for Splice");
  LegalizationCost += Entry->Cost;
  // Scale by the number of legal parts the type splits into.
  return LegalizationCost * LT.first;
}
2145 
2146 InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
2147                                                VectorType *Tp,
2148                                                ArrayRef<int> Mask, int Index,
                                               VectorType *SubTp) {
  // If a concrete mask is available it may prove the shuffle is really a
  // cheaper kind (e.g. a permute mask that is actually a broadcast or a
  // reverse); refine Kind before consulting the table.
  Kind = improveShuffleKindFromMask(Kind, Mask);
  if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose ||
      Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc ||
      Kind == TTI::SK_Reverse) {
    // Per-(kind, legal MVT) costs. A cost of 1 corresponds to a single
    // instruction for that shuffle kind on the legalized type; the comments
    // on each entry name the expected lowering.
    static const CostTblEntry ShuffleTbl[] = {
      // Broadcast shuffle kinds can be performed with 'dup'.
      { TTI::SK_Broadcast, MVT::v8i8,  1 },
      { TTI::SK_Broadcast, MVT::v16i8, 1 },
      { TTI::SK_Broadcast, MVT::v4i16, 1 },
      { TTI::SK_Broadcast, MVT::v8i16, 1 },
      { TTI::SK_Broadcast, MVT::v2i32, 1 },
      { TTI::SK_Broadcast, MVT::v4i32, 1 },
      { TTI::SK_Broadcast, MVT::v2i64, 1 },
      { TTI::SK_Broadcast, MVT::v2f32, 1 },
      { TTI::SK_Broadcast, MVT::v4f32, 1 },
      { TTI::SK_Broadcast, MVT::v2f64, 1 },
      // Transpose shuffle kinds can be performed with 'trn1/trn2' and
      // 'zip1/zip2' instructions.
      { TTI::SK_Transpose, MVT::v8i8,  1 },
      { TTI::SK_Transpose, MVT::v16i8, 1 },
      { TTI::SK_Transpose, MVT::v4i16, 1 },
      { TTI::SK_Transpose, MVT::v8i16, 1 },
      { TTI::SK_Transpose, MVT::v2i32, 1 },
      { TTI::SK_Transpose, MVT::v4i32, 1 },
      { TTI::SK_Transpose, MVT::v2i64, 1 },
      { TTI::SK_Transpose, MVT::v2f32, 1 },
      { TTI::SK_Transpose, MVT::v4f32, 1 },
      { TTI::SK_Transpose, MVT::v2f64, 1 },
      // Select shuffle kinds.
      // TODO: handle vXi8/vXi16.
      { TTI::SK_Select, MVT::v2i32, 1 }, // mov.
      { TTI::SK_Select, MVT::v4i32, 2 }, // rev+trn (or similar).
      { TTI::SK_Select, MVT::v2i64, 1 }, // mov.
      { TTI::SK_Select, MVT::v2f32, 1 }, // mov.
      { TTI::SK_Select, MVT::v4f32, 2 }, // rev+trn (or similar).
      { TTI::SK_Select, MVT::v2f64, 1 }, // mov.
      // PermuteSingleSrc shuffle kinds.
      { TTI::SK_PermuteSingleSrc, MVT::v2i32, 1 }, // mov.
      { TTI::SK_PermuteSingleSrc, MVT::v4i32, 3 }, // perfectshuffle worst case.
      { TTI::SK_PermuteSingleSrc, MVT::v2i64, 1 }, // mov.
      { TTI::SK_PermuteSingleSrc, MVT::v2f32, 1 }, // mov.
      { TTI::SK_PermuteSingleSrc, MVT::v4f32, 3 }, // perfectshuffle worst case.
      { TTI::SK_PermuteSingleSrc, MVT::v2f64, 1 }, // mov.
      { TTI::SK_PermuteSingleSrc, MVT::v4i16, 3 }, // perfectshuffle worst case.
      { TTI::SK_PermuteSingleSrc, MVT::v4f16, 3 }, // perfectshuffle worst case.
      { TTI::SK_PermuteSingleSrc, MVT::v4bf16, 3 }, // perfectshuffle worst case.
      { TTI::SK_PermuteSingleSrc, MVT::v8i16, 8 }, // constpool + load + tbl
      { TTI::SK_PermuteSingleSrc, MVT::v8f16, 8 }, // constpool + load + tbl
      { TTI::SK_PermuteSingleSrc, MVT::v8bf16, 8 }, // constpool + load + tbl
      { TTI::SK_PermuteSingleSrc, MVT::v8i8, 8 }, // constpool + load + tbl
      { TTI::SK_PermuteSingleSrc, MVT::v16i8, 8 }, // constpool + load + tbl
      // Reverse can be lowered with `rev`.
      { TTI::SK_Reverse, MVT::v2i32, 1 }, // mov.
      { TTI::SK_Reverse, MVT::v4i32, 2 }, // REV64; EXT
      { TTI::SK_Reverse, MVT::v2i64, 1 }, // mov.
      { TTI::SK_Reverse, MVT::v2f32, 1 }, // mov.
      { TTI::SK_Reverse, MVT::v4f32, 2 }, // REV64; EXT
      { TTI::SK_Reverse, MVT::v2f64, 1 }, // mov.
      // Broadcast shuffle kinds for scalable vectors
      { TTI::SK_Broadcast, MVT::nxv16i8,  1 },
      { TTI::SK_Broadcast, MVT::nxv8i16,  1 },
      { TTI::SK_Broadcast, MVT::nxv4i32,  1 },
      { TTI::SK_Broadcast, MVT::nxv2i64,  1 },
      { TTI::SK_Broadcast, MVT::nxv2f16,  1 },
      { TTI::SK_Broadcast, MVT::nxv4f16,  1 },
      { TTI::SK_Broadcast, MVT::nxv8f16,  1 },
      { TTI::SK_Broadcast, MVT::nxv2bf16, 1 },
      { TTI::SK_Broadcast, MVT::nxv4bf16, 1 },
      { TTI::SK_Broadcast, MVT::nxv8bf16, 1 },
      { TTI::SK_Broadcast, MVT::nxv2f32,  1 },
      { TTI::SK_Broadcast, MVT::nxv4f32,  1 },
      { TTI::SK_Broadcast, MVT::nxv2f64,  1 },
      { TTI::SK_Broadcast, MVT::nxv16i1,  1 },
      { TTI::SK_Broadcast, MVT::nxv8i1,   1 },
      { TTI::SK_Broadcast, MVT::nxv4i1,   1 },
      { TTI::SK_Broadcast, MVT::nxv2i1,   1 },
      // Handle the cases for vector.reverse with scalable vectors
      { TTI::SK_Reverse, MVT::nxv16i8,  1 },
      { TTI::SK_Reverse, MVT::nxv8i16,  1 },
      { TTI::SK_Reverse, MVT::nxv4i32,  1 },
      { TTI::SK_Reverse, MVT::nxv2i64,  1 },
      { TTI::SK_Reverse, MVT::nxv2f16,  1 },
      { TTI::SK_Reverse, MVT::nxv4f16,  1 },
      { TTI::SK_Reverse, MVT::nxv8f16,  1 },
      { TTI::SK_Reverse, MVT::nxv2bf16, 1 },
      { TTI::SK_Reverse, MVT::nxv4bf16, 1 },
      { TTI::SK_Reverse, MVT::nxv8bf16, 1 },
      { TTI::SK_Reverse, MVT::nxv2f32,  1 },
      { TTI::SK_Reverse, MVT::nxv4f32,  1 },
      { TTI::SK_Reverse, MVT::nxv2f64,  1 },
      { TTI::SK_Reverse, MVT::nxv16i1,  1 },
      { TTI::SK_Reverse, MVT::nxv8i1,   1 },
      { TTI::SK_Reverse, MVT::nxv4i1,   1 },
      { TTI::SK_Reverse, MVT::nxv2i1,   1 },
    };
    // Cost the shuffle on the LEGALIZED type: LT.second is the legal MVT and
    // LT.first is the legalization factor (number of legal-type parts), so
    // each split part pays the per-type table cost.
    std::pair<InstructionCost, MVT> LT = TLI->getTypeLegalizationCost(DL, Tp);
    if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second))
      return LT.first * Entry->Cost;
  }
  // Splices of scalable vectors are costed by a dedicated helper rather than
  // the table above.
  if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(Tp))
    return getSpliceCost(Tp, Index);
  // No AArch64-specific entry applies: defer to the generic cost model.
  return BaseT::getShuffleCost(Kind, Tp, Mask, Index, SubTp);
}
2253