//===----------- VectorUtils.cpp - Vectorizer utility functions -----------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file defines vectorizer utilities.
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/VectorUtils.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Analysis/DemandedBits.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"

using namespace llvm;
using namespace llvm::PatternMatch;

/// \brief Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all
/// scalars for the scalar form of the intrinsic and all vectors for
/// the vector form of the intrinsic.
bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
  switch (ID) {
  case Intrinsic::sqrt:
  case Intrinsic::sin:
  case Intrinsic::cos:
  case Intrinsic::exp:
  case Intrinsic::exp2:
  case Intrinsic::log:
  case Intrinsic::log10:
  case Intrinsic::log2:
  case Intrinsic::fabs:
  case Intrinsic::minnum:
  case Intrinsic::maxnum:
  case Intrinsic::copysign:
  case Intrinsic::floor:
  case Intrinsic::ceil:
  case Intrinsic::trunc:
  case Intrinsic::rint:
  case Intrinsic::nearbyint:
  case Intrinsic::round:
  case Intrinsic::bswap:
  case Intrinsic::bitreverse:
  case Intrinsic::ctpop:
  case Intrinsic::pow:
  case Intrinsic::fma:
  case Intrinsic::fmuladd:
  case Intrinsic::ctlz:
  case Intrinsic::cttz:
  case Intrinsic::powi:
    return true;
  default:
    return false;
  }
}

/// \brief Identifies if the intrinsic has a scalar operand. It checks for
/// ctlz, cttz and powi, special intrinsics whose argument is scalar.
bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID,
                                        unsigned ScalarOpdIdx) {
  switch (ID) {
  case Intrinsic::ctlz:
  case Intrinsic::cttz:
  case Intrinsic::powi:
    return (ScalarOpdIdx == 1);
  default:
    return false;
  }
}

/// \brief Returns the intrinsic ID for a call.
/// For the given call instruction it finds the corresponding intrinsic and
/// returns its ID; if no such mapping is found, it returns not_intrinsic.
Intrinsic::ID llvm::getVectorIntrinsicIDForCall(const CallInst *CI,
                                                const TargetLibraryInfo *TLI) {
  Intrinsic::ID ID = getIntrinsicForCallSite(CI, TLI);
  if (ID == Intrinsic::not_intrinsic)
    return Intrinsic::not_intrinsic;

  if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start ||
      ID == Intrinsic::lifetime_end || ID == Intrinsic::assume ||
      ID == Intrinsic::sideeffect)
    return ID;
  return Intrinsic::not_intrinsic;
}

/// \brief Find the operand of the GEP that should be checked for consecutive
/// stores. This ignores trailing indices that have no effect on the final
/// pointer.
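/// For example, for a hypothetical GEP such as
///   %g = getelementptr { i32 }, { i32 }* %p, i64 %i, i32 0
/// the trailing zero index is peeled off (it does not change the final
/// address), and the operand index of %i, i.e. 1, is returned.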
unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) {
  const DataLayout &DL = Gep->getModule()->getDataLayout();
  unsigned LastOperand = Gep->getNumOperands() - 1;
  unsigned GEPAllocSize = DL.getTypeAllocSize(Gep->getResultElementType());

  // Walk backwards and try to peel off zeros.
  while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) {
    // Find the type we're currently indexing into.
    gep_type_iterator GEPTI = gep_type_begin(Gep);
    std::advance(GEPTI, LastOperand - 2);

    // If it's a type with the same allocation size as the result of the GEP we
    // can peel off the zero index.
    if (DL.getTypeAllocSize(GEPTI.getIndexedType()) != GEPAllocSize)
      break;
    --LastOperand;
  }

  return LastOperand;
}

/// \brief If the argument is a GEP, then returns the operand identified by
/// getGEPInductionOperand. However, if there is some other non-loop-invariant
/// operand, it returns that instead.
Value *llvm::stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return Ptr;

  unsigned InductionOperand = getGEPInductionOperand(GEP);

  // Check that all of the gep indices are uniform except for our induction
  // operand.
  for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i)
    if (i != InductionOperand &&
        !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp))
      return Ptr;
  return GEP->getOperand(InductionOperand);
}

/// \brief If a value has only one user that is a CastInst, return it.
Value *llvm::getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) {
  Value *UniqueCast = nullptr;
  for (User *U : Ptr->users()) {
    CastInst *CI = dyn_cast<CastInst>(U);
    if (CI && CI->getType() == Ty) {
      if (!UniqueCast)
        UniqueCast = CI;
      else
        return nullptr;
    }
  }
  return UniqueCast;
}

/// \brief Get the stride of a pointer access in a loop. Looks for symbolic
/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
  auto *PtrTy = dyn_cast<PointerType>(Ptr->getType());
  if (!PtrTy || PtrTy->isAggregateType())
    return nullptr;

  // Try to remove a gep instruction to make the pointer (actually the index
  // at this point) easier to analyze. If OrigPtr is equal to Ptr we are
  // analyzing the pointer; otherwise we are analyzing the index.
  Value *OrigPtr = Ptr;

  // The size of the pointer access.
  int64_t PtrAccessSize = 1;

  Ptr = stripGetElementPtr(Ptr, SE, Lp);
  const SCEV *V = SE->getSCEV(Ptr);

  if (Ptr != OrigPtr)
    // Strip off casts.
    while (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V))
      V = C->getOperand();

  const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V);
  if (!S)
    return nullptr;

  V = S->getStepRecurrence(*SE);
  if (!V)
    return nullptr;

  // Strip off the size of access multiplication if we are still analyzing the
  // pointer.
  if (OrigPtr == Ptr) {
    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) {
      if (M->getOperand(0)->getSCEVType() != scConstant)
        return nullptr;

      const APInt &APStepVal = cast<SCEVConstant>(M->getOperand(0))->getAPInt();

      // Huge step value - give up.
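      // (A value wider than 64 bits cannot be represented by the int64_t
      // StepVal below.)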
      if (APStepVal.getBitWidth() > 64)
        return nullptr;

      int64_t StepVal = APStepVal.getSExtValue();
      if (PtrAccessSize != StepVal)
        return nullptr;
      V = M->getOperand(1);
    }
  }

  // Strip off casts.
  Type *StripedOffRecurrenceCast = nullptr;
  if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) {
    StripedOffRecurrenceCast = C->getType();
    V = C->getOperand();
  }

  // Look for the loop invariant symbolic value.
  const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V);
  if (!U)
    return nullptr;

  Value *Stride = U->getValue();
  if (!Lp->isLoopInvariant(Stride))
    return nullptr;

  // If we have stripped off the recurrence cast we have to make sure that we
  // return the value that is used in this loop so that we can replace it
  // later.
  if (StripedOffRecurrenceCast)
    Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast);

  return Stride;
}

/// \brief Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then
/// extracted from the vector.
Value *llvm::findScalarElement(Value *V, unsigned EltNo) {
  assert(V->getType()->isVectorTy() && "Not looking at a vector?");
  VectorType *VTy = cast<VectorType>(V->getType());
  unsigned Width = VTy->getNumElements();
  if (EltNo >= Width) // Out of range access.
    return UndefValue::get(VTy->getElementType());

  if (Constant *C = dyn_cast<Constant>(V))
    return C->getAggregateElement(EltNo);

  if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) {
    // If this is an insert to a variable element, we don't know what it is.
    if (!isa<ConstantInt>(III->getOperand(2)))
      return nullptr;
    unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue();

    // If this is an insert to the element we are looking for, return the
    // inserted value.
    if (EltNo == IIElt)
      return III->getOperand(1);

    // Otherwise, the insertelement doesn't modify the value, recurse on its
    // vector input.
    return findScalarElement(III->getOperand(0), EltNo);
  }

  if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) {
    unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements();
    int InEl = SVI->getMaskValue(EltNo);
    if (InEl < 0)
      return UndefValue::get(VTy->getElementType());
    if (InEl < (int)LHSWidth)
      return findScalarElement(SVI->getOperand(0), InEl);
    return findScalarElement(SVI->getOperand(1), InEl - LHSWidth);
  }

  // Extract a value from a vector add operation with a constant zero.
  Value *Val = nullptr; Constant *Con = nullptr;
  if (match(V, m_Add(m_Value(Val), m_Constant(Con))))
    if (Constant *Elt = Con->getAggregateElement(EltNo))
      if (Elt->isNullValue())
        return findScalarElement(Val, EltNo);

  // Otherwise, we don't know.
  return nullptr;
}

/// \brief Get splat value if the input is a splat vector or return nullptr.
/// This function is not fully general. It checks only 2 cases:
/// the input value is (1) a splat constant vector or (2) a sequence
/// of instructions that broadcast a single value into a vector.
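/// For example, a hypothetical constant <4 x i32> <i32 3, i32 3, i32 3, i32 3>
/// yields the i32 3, and so does the common broadcast idiom of an
/// insertelement at index 0 followed by an all-zero-mask shufflevector.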
///
const llvm::Value *llvm::getSplatValue(const Value *V) {

  if (auto *C = dyn_cast<Constant>(V))
    if (isa<VectorType>(V->getType()))
      return C->getSplatValue();

  auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V);
  if (!ShuffleInst)
    return nullptr;
  // All-zero (or undef) shuffle mask elements.
  for (int MaskElt : ShuffleInst->getShuffleMask())
    if (MaskElt != 0 && MaskElt != -1)
      return nullptr;
  // The first shuffle source is 'insertelement' with index 0.
  auto *InsertEltInst =
      dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0));
  if (!InsertEltInst || !isa<ConstantInt>(InsertEltInst->getOperand(2)) ||
      !cast<ConstantInt>(InsertEltInst->getOperand(2))->isZero())
    return nullptr;

  return InsertEltInst->getOperand(1);
}

MapVector<Instruction *, uint64_t>
llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
                               const TargetTransformInfo *TTI) {

  // DemandedBits will give us every value's live-out bits. But we want
  // to ensure no extra casts would need to be inserted, so every DAG
  // of connected values must have the same minimum bitwidth.
  EquivalenceClasses<Value *> ECs;
  SmallVector<Value *, 16> Worklist;
  SmallPtrSet<Value *, 4> Roots;
  SmallPtrSet<Value *, 16> Visited;
  DenseMap<Value *, uint64_t> DBits;
  SmallPtrSet<Instruction *, 4> InstructionSet;
  MapVector<Instruction *, uint64_t> MinBWs;

  // Determine the roots. We work bottom-up, from truncs or icmps.
  bool SeenExtFromIllegalType = false;
  for (auto *BB : Blocks)
    for (auto &I : *BB) {
      InstructionSet.insert(&I);

      if (TTI && (isa<ZExtInst>(&I) || isa<SExtInst>(&I)) &&
          !TTI->isTypeLegal(I.getOperand(0)->getType()))
        SeenExtFromIllegalType = true;

      // Only deal with non-vector integers up to 64-bits wide.
      if ((isa<TruncInst>(&I) || isa<ICmpInst>(&I)) &&
          !I.getType()->isVectorTy() &&
          I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) {
        // Don't make work for ourselves. If we know the loaded type is legal,
        // don't add it to the worklist.
        if (TTI && isa<TruncInst>(&I) && TTI->isTypeLegal(I.getType()))
          continue;

        Worklist.push_back(&I);
        Roots.insert(&I);
      }
    }
  // Early exit.
  if (Worklist.empty() || (TTI && !SeenExtFromIllegalType))
    return MinBWs;

  // Now proceed breadth-first, unioning values together.
  while (!Worklist.empty()) {
    Value *Val = Worklist.pop_back_val();
    Value *Leader = ECs.getOrInsertLeaderValue(Val);

    if (Visited.count(Val))
      continue;
    Visited.insert(Val);

    // Non-instructions terminate a chain successfully.
    if (!isa<Instruction>(Val))
      continue;
    Instruction *I = cast<Instruction>(Val);

    // If we encounter a type that is larger than 64 bits, we can't represent
    // it so bail out.
    if (DB.getDemandedBits(I).getBitWidth() > 64)
      return MapVector<Instruction *, uint64_t>();

    uint64_t V = DB.getDemandedBits(I).getZExtValue();
    DBits[Leader] |= V;
    DBits[I] = V;

    // Casts, loads and instructions outside of our range terminate a chain
    // successfully.
    if (isa<SExtInst>(I) || isa<ZExtInst>(I) || isa<LoadInst>(I) ||
        !InstructionSet.count(I))
      continue;

    // Unsafe casts terminate a chain unsuccessfully. We can't do anything
    // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to
    // transform anything that relies on them.
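    // Setting all bits for the leader below ensures this equivalence class is
    // never shrunk.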
    if (isa<BitCastInst>(I) || isa<PtrToIntInst>(I) || isa<IntToPtrInst>(I) ||
        !I->getType()->isIntegerTy()) {
      DBits[Leader] |= ~0ULL;
      continue;
    }

    // We don't modify the types of PHIs. Reductions will already have been
    // truncated if possible, and inductions' sizes will have been chosen by
    // indvars.
    if (isa<PHINode>(I))
      continue;

    if (DBits[Leader] == ~0ULL)
      // All bits demanded, no point continuing.
      continue;

    for (Value *O : cast<User>(I)->operands()) {
      ECs.unionSets(Leader, O);
      Worklist.push_back(O);
    }
  }

  // Now we've discovered all values, walk them to see if there are
  // any users we didn't see. If there are, we can't optimize that
  // chain.
  for (auto &I : DBits)
    for (auto *U : I.first->users())
      if (U->getType()->isIntegerTy() && DBits.count(U) == 0)
        DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL;

  for (auto I = ECs.begin(), E = ECs.end(); I != E; ++I) {
    uint64_t LeaderDemandedBits = 0;
    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI)
      LeaderDemandedBits |= DBits[*MI];

    uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) -
                     llvm::countLeadingZeros(LeaderDemandedBits);
    // Round up to a power of 2.
    if (!isPowerOf2_64((uint64_t)MinBW))
      MinBW = NextPowerOf2(MinBW);

    // We don't modify the types of PHIs. Reductions will already have been
    // truncated if possible, and inductions' sizes will have been chosen by
    // indvars.
    // If we are required to shrink a PHI, abandon this entire equivalence
    // class.
    bool Abort = false;
    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI)
      if (isa<PHINode>(*MI) && MinBW < (*MI)->getType()->getScalarSizeInBits()) {
        Abort = true;
        break;
      }
    if (Abort)
      continue;

    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) {
      if (!isa<Instruction>(*MI))
        continue;
      Type *Ty = (*MI)->getType();
      if (Roots.count(*MI))
        Ty = cast<Instruction>(*MI)->getOperand(0)->getType();
      if (MinBW < Ty->getScalarSizeInBits())
        MinBWs[cast<Instruction>(*MI)] = MinBW;
    }
  }

  return MinBWs;
}

/// \returns \p I after propagating metadata from \p VL.
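/// The tbaa, alias.scope and fpmath kinds are merged to their most generic
/// form; the noalias, nontemporal and invariant.load kinds are intersected
/// across \p VL, so they survive only if present on every instruction.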
Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) {
  Instruction *I0 = cast<Instruction>(VL[0]);
  SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata;
  I0->getAllMetadataOtherThanDebugLoc(Metadata);

  for (auto Kind :
       {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
        LLVMContext::MD_noalias, LLVMContext::MD_fpmath,
        LLVMContext::MD_nontemporal, LLVMContext::MD_invariant_load}) {
    MDNode *MD = I0->getMetadata(Kind);

    for (int J = 1, E = VL.size(); MD && J != E; ++J) {
      const Instruction *IJ = cast<Instruction>(VL[J]);
      MDNode *IMD = IJ->getMetadata(Kind);
      switch (Kind) {
      case LLVMContext::MD_tbaa:
        MD = MDNode::getMostGenericTBAA(MD, IMD);
        break;
      case LLVMContext::MD_alias_scope:
        MD = MDNode::getMostGenericAliasScope(MD, IMD);
        break;
      case LLVMContext::MD_fpmath:
        MD = MDNode::getMostGenericFPMath(MD, IMD);
        break;
      case LLVMContext::MD_noalias:
      case LLVMContext::MD_nontemporal:
      case LLVMContext::MD_invariant_load:
        MD = MDNode::intersect(MD, IMD);
        break;
      default:
        llvm_unreachable("unhandled metadata");
      }
    }

    Inst->setMetadata(Kind, MD);
  }

  return Inst;
}

/// For example, the interleave mask for VF = 4 and NumVecs = 2 is
/// <0, 4, 1, 5, 2, 6, 3, 7>.
Constant *llvm::createInterleaveMask(IRBuilder<> &Builder, unsigned VF,
                                     unsigned NumVecs) {
  SmallVector<Constant *, 16> Mask;
  for (unsigned i = 0; i < VF; i++)
    for (unsigned j = 0; j < NumVecs; j++)
      Mask.push_back(Builder.getInt32(j * VF + i));

  return ConstantVector::get(Mask);
}

/// For example, the mask for Start = 0, Stride = 2 and VF = 4 is
/// <0, 2, 4, 6>.
Constant *llvm::createStrideMask(IRBuilder<> &Builder, unsigned Start,
                                 unsigned Stride, unsigned VF) {
  SmallVector<Constant *, 16> Mask;
  for (unsigned i = 0; i < VF; i++)
    Mask.push_back(Builder.getInt32(Start + i * Stride));

  return ConstantVector::get(Mask);
}

/// For example, the mask for Start = 0, NumInts = 4 and NumUndefs = 4 is
/// <0, 1, 2, 3, undef, undef, undef, undef>.
Constant *llvm::createSequentialMask(IRBuilder<> &Builder, unsigned Start,
                                     unsigned NumInts, unsigned NumUndefs) {
  SmallVector<Constant *, 16> Mask;
  for (unsigned i = 0; i < NumInts; i++)
    Mask.push_back(Builder.getInt32(Start + i));

  Constant *Undef = UndefValue::get(Builder.getInt32Ty());
  for (unsigned i = 0; i < NumUndefs; i++)
    Mask.push_back(Undef);

  return ConstantVector::get(Mask);
}

/// A helper function for concatenating vectors. This function concatenates two
/// vectors having the same element type. If the second vector has fewer
/// elements than the first, it is padded with undefs.
static Value *concatenateTwoVectors(IRBuilder<> &Builder, Value *V1,
                                    Value *V2) {
  VectorType *VecTy1 = dyn_cast<VectorType>(V1->getType());
  VectorType *VecTy2 = dyn_cast<VectorType>(V2->getType());
  assert(VecTy1 && VecTy2 &&
         VecTy1->getScalarType() == VecTy2->getScalarType() &&
         "Expect two vectors with the same element type");

  unsigned NumElts1 = VecTy1->getNumElements();
  unsigned NumElts2 = VecTy2->getNumElements();
  assert(NumElts1 >= NumElts2 &&
         "Unexpected: the first vector has fewer elements than the second");

  if (NumElts1 > NumElts2) {
    // Extend with UNDEFs.
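    // For instance, widening a <2 x ty> V2 to match a <4 x ty> V1 uses the
    // shuffle mask <0, 1, undef, undef>.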
    Constant *ExtMask =
        createSequentialMask(Builder, 0, NumElts2, NumElts1 - NumElts2);
    V2 = Builder.CreateShuffleVector(V2, UndefValue::get(VecTy2), ExtMask);
  }

  Constant *Mask = createSequentialMask(Builder, 0, NumElts1 + NumElts2, 0);
  return Builder.CreateShuffleVector(V1, V2, Mask);
}

Value *llvm::concatenateVectors(IRBuilder<> &Builder, ArrayRef<Value *> Vecs) {
  unsigned NumVecs = Vecs.size();
  assert(NumVecs > 1 && "Should be at least two vectors");

  SmallVector<Value *, 8> ResList;
  ResList.append(Vecs.begin(), Vecs.end());
  // Concatenate the vectors pairwise until a single wide vector remains;
  // e.g. five inputs become three after one round, then two, then one.
  do {
    SmallVector<Value *, 8> TmpList;
    for (unsigned i = 0; i < NumVecs - 1; i += 2) {
      Value *V0 = ResList[i], *V1 = ResList[i + 1];
      assert((V0->getType() == V1->getType() || i == NumVecs - 2) &&
             "Only the last vector may have a different type");

      TmpList.push_back(concatenateTwoVectors(Builder, V0, V1));
    }

    // Push the last vector if the total number of vectors is odd.
    if (NumVecs % 2 != 0)
      TmpList.push_back(ResList[NumVecs - 1]);

    ResList = TmpList;
    NumVecs = ResList.size();
  } while (NumVecs > 1);

  return ResList[0];
}