//===----------- VectorUtils.cpp - Vectorizer utility functions -----------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file defines vectorizer utilities.
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/VectorUtils.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Analysis/DemandedBits.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"

using namespace llvm;
using namespace llvm::PatternMatch;

/// \brief Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all
/// scalars for the scalar form of the intrinsic and all vectors for
/// the vector form of the intrinsic.
bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
  switch (ID) {
  case Intrinsic::sqrt:
  case Intrinsic::sin:
  case Intrinsic::cos:
  case Intrinsic::exp:
  case Intrinsic::exp2:
  case Intrinsic::log:
  case Intrinsic::log10:
  case Intrinsic::log2:
  case Intrinsic::fabs:
  case Intrinsic::minnum:
  case Intrinsic::maxnum:
  case Intrinsic::copysign:
  case Intrinsic::floor:
  case Intrinsic::ceil:
  case Intrinsic::trunc:
  case Intrinsic::rint:
  case Intrinsic::nearbyint:
  case Intrinsic::round:
  case Intrinsic::bswap:
  case Intrinsic::bitreverse:
  case Intrinsic::ctpop:
  case Intrinsic::pow:
  case Intrinsic::fma:
  case Intrinsic::fmuladd:
  case Intrinsic::ctlz:
  case Intrinsic::cttz:
  case Intrinsic::powi:
    return true;
  default:
    return false;
  }
}

/// \brief Identifies if the intrinsic has a scalar operand. It checks for the
/// ctlz, cttz and powi special intrinsics, whose second argument is scalar.
bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID,
                                        unsigned ScalarOpdIdx) {
  switch (ID) {
  case Intrinsic::ctlz:
  case Intrinsic::cttz:
  case Intrinsic::powi:
    return (ScalarOpdIdx == 1);
  default:
    return false;
  }
}

/// \brief Returns the intrinsic ID for a call.
/// For the given call instruction it finds the corresponding intrinsic and
/// returns its ID; if no such mapping exists, not_intrinsic is returned.
Intrinsic::ID llvm::getVectorIntrinsicIDForCall(const CallInst *CI,
                                                const TargetLibraryInfo *TLI) {
  Intrinsic::ID ID = getIntrinsicForCallSite(CI, TLI);
  if (ID == Intrinsic::not_intrinsic)
    return Intrinsic::not_intrinsic;

  if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start ||
      ID == Intrinsic::lifetime_end || ID == Intrinsic::assume)
    return ID;
  return Intrinsic::not_intrinsic;
}

/// \brief Find the operand of the GEP that should be checked for consecutive
/// stores. This ignores trailing indices that have no effect on the final
/// pointer.
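/// Note that the value returned is the *index* of that operand within the GEP,
/// not the operand itself; callers typically pass it to Gep->getOperand().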
unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) {
  const DataLayout &DL = Gep->getModule()->getDataLayout();
  unsigned LastOperand = Gep->getNumOperands() - 1;
  unsigned GEPAllocSize = DL.getTypeAllocSize(Gep->getResultElementType());

  // Walk backwards and try to peel off zeros.
  while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) {
    // Find the type we're currently indexing into.
    gep_type_iterator GEPTI = gep_type_begin(Gep);
    std::advance(GEPTI, LastOperand - 2);

    // If it's a type with the same allocation size as the result of the GEP we
    // can peel off the zero index.
    if (DL.getTypeAllocSize(GEPTI.getIndexedType()) != GEPAllocSize)
      break;
    --LastOperand;
  }

  return LastOperand;
}

/// \brief If the argument is a GEP, then returns the operand identified by
/// getGEPInductionOperand. However, if there is some other non-loop-invariant
/// operand, it returns that instead.
Value *llvm::stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return Ptr;

  unsigned InductionOperand = getGEPInductionOperand(GEP);

  // Check that all of the gep indices are uniform except for our induction
  // operand.
  for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i)
    if (i != InductionOperand &&
        !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp))
      return Ptr;
  return GEP->getOperand(InductionOperand);
}

/// \brief If a value has only one user that is a CastInst, return it.
Value *llvm::getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) {
  Value *UniqueCast = nullptr;
  for (User *U : Ptr->users()) {
    CastInst *CI = dyn_cast<CastInst>(U);
    if (CI && CI->getType() == Ty) {
      if (!UniqueCast)
        UniqueCast = CI;
      else
        return nullptr;
    }
  }
  return UniqueCast;
}

/// \brief Get the stride of a pointer access in a loop. Looks for symbolic
/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
  auto *PtrTy = dyn_cast<PointerType>(Ptr->getType());
  if (!PtrTy || PtrTy->isAggregateType())
    return nullptr;

  // Try to remove a gep instruction to make the pointer (actually the index
  // at this point) easier to analyze. If OrigPtr is equal to Ptr we are
  // analyzing the pointer, otherwise we are analyzing the index.
  Value *OrigPtr = Ptr;

  // The size of the pointer access.
  int64_t PtrAccessSize = 1;

  Ptr = stripGetElementPtr(Ptr, SE, Lp);
  const SCEV *V = SE->getSCEV(Ptr);

  if (Ptr != OrigPtr)
    // Strip off casts.
    while (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V))
      V = C->getOperand();

  const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V);
  if (!S)
    return nullptr;

  V = S->getStepRecurrence(*SE);
  if (!V)
    return nullptr;

  // Strip off the size of access multiplication if we are still analyzing the
  // pointer.
  if (OrigPtr == Ptr) {
    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) {
      if (M->getOperand(0)->getSCEVType() != scConstant)
        return nullptr;

      const APInt &APStepVal = cast<SCEVConstant>(M->getOperand(0))->getAPInt();

      // Huge step value - give up.
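      // (A constant wider than 64 bits cannot be converted to the int64_t that
      // is compared against PtrAccessSize below.)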
      if (APStepVal.getBitWidth() > 64)
        return nullptr;

      int64_t StepVal = APStepVal.getSExtValue();
      if (PtrAccessSize != StepVal)
        return nullptr;
      V = M->getOperand(1);
    }
  }

  // Strip off casts.
  Type *StripedOffRecurrenceCast = nullptr;
  if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) {
    StripedOffRecurrenceCast = C->getType();
    V = C->getOperand();
  }

  // Look for the loop invariant symbolic value.
  const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V);
  if (!U)
    return nullptr;

  Value *Stride = U->getValue();
  if (!Lp->isLoopInvariant(Stride))
    return nullptr;

  // If we have stripped off the recurrence cast we have to make sure that we
  // return the value that is used in this loop so that we can replace it
  // later.
  if (StripedOffRecurrenceCast)
    Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast);

  return Stride;
}

/// \brief Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it was inserted and then
/// extracted from the vector.
Value *llvm::findScalarElement(Value *V, unsigned EltNo) {
  assert(V->getType()->isVectorTy() && "Not looking at a vector?");
  VectorType *VTy = cast<VectorType>(V->getType());
  unsigned Width = VTy->getNumElements();
  if (EltNo >= Width)  // Out of range access.
    return UndefValue::get(VTy->getElementType());

  if (Constant *C = dyn_cast<Constant>(V))
    return C->getAggregateElement(EltNo);

  if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) {
    // If this is an insert to a variable element, we don't know what it is.
    if (!isa<ConstantInt>(III->getOperand(2)))
      return nullptr;
    unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue();

    // If this is an insert to the element we are looking for, return the
    // inserted value.
    if (EltNo == IIElt)
      return III->getOperand(1);

    // Otherwise, the insertelement doesn't modify the value, recurse on its
    // vector input.
    return findScalarElement(III->getOperand(0), EltNo);
  }

  if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) {
    unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements();
    int InEl = SVI->getMaskValue(EltNo);
    if (InEl < 0)
      return UndefValue::get(VTy->getElementType());
    if (InEl < (int)LHSWidth)
      return findScalarElement(SVI->getOperand(0), InEl);
    return findScalarElement(SVI->getOperand(1), InEl - LHSWidth);
  }

  // Extract a value from a vector add operation with a constant zero.
  Value *Val = nullptr; Constant *Con = nullptr;
  if (match(V, m_Add(m_Value(Val), m_Constant(Con))))
    if (Constant *Elt = Con->getAggregateElement(EltNo))
      if (Elt->isNullValue())
        return findScalarElement(Val, EltNo);

  // Otherwise, we don't know.
  return nullptr;
}

/// \brief Get splat value if the input is a splat vector or return nullptr.
/// This function is not fully general. It checks only 2 cases:
/// the input value is (1) a splat constant vector or (2) a sequence
/// of instructions that broadcast a single value into a vector.
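/// For case (2) the pattern recognized is a shufflevector whose mask elements
/// are all zero (or undef) and whose first operand is an insertelement into
/// lane 0; the inserted scalar is returned.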
///
const llvm::Value *llvm::getSplatValue(const Value *V) {

  if (auto *C = dyn_cast<Constant>(V))
    if (isa<VectorType>(V->getType()))
      return C->getSplatValue();

  auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V);
  if (!ShuffleInst)
    return nullptr;
  // All-zero (or undef) shuffle mask elements.
  for (int MaskElt : ShuffleInst->getShuffleMask())
    if (MaskElt != 0 && MaskElt != -1)
      return nullptr;
  // The first shuffle source is 'insertelement' with index 0.
  auto *InsertEltInst =
      dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0));
  if (!InsertEltInst || !isa<ConstantInt>(InsertEltInst->getOperand(2)) ||
      !cast<ConstantInt>(InsertEltInst->getOperand(2))->isZero())
    return nullptr;

  return InsertEltInst->getOperand(1);
}

MapVector<Instruction *, uint64_t>
llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
                               const TargetTransformInfo *TTI) {

  // DemandedBits will give us every value's live-out bits. But we want
  // to ensure no extra casts would need to be inserted, so every DAG
  // of connected values must have the same minimum bitwidth.
  EquivalenceClasses<Value *> ECs;
  SmallVector<Value *, 16> Worklist;
  SmallPtrSet<Value *, 4> Roots;
  SmallPtrSet<Value *, 16> Visited;
  DenseMap<Value *, uint64_t> DBits;
  SmallPtrSet<Instruction *, 4> InstructionSet;
  MapVector<Instruction *, uint64_t> MinBWs;

  // Determine the roots. We work bottom-up, from truncs or icmps.
  bool SeenExtFromIllegalType = false;
  for (auto *BB : Blocks)
    for (auto &I : *BB) {
      InstructionSet.insert(&I);

      if (TTI && (isa<ZExtInst>(&I) || isa<SExtInst>(&I)) &&
          !TTI->isTypeLegal(I.getOperand(0)->getType()))
        SeenExtFromIllegalType = true;

      // Only deal with non-vector integers up to 64-bits wide.
      if ((isa<TruncInst>(&I) || isa<ICmpInst>(&I)) &&
          !I.getType()->isVectorTy() &&
          I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) {
        // Don't make work for ourselves. If we know the loaded type is legal,
        // don't add it to the worklist.
        if (TTI && isa<TruncInst>(&I) && TTI->isTypeLegal(I.getType()))
          continue;

        Worklist.push_back(&I);
        Roots.insert(&I);
      }
    }
  // Early exit.
  if (Worklist.empty() || (TTI && !SeenExtFromIllegalType))
    return MinBWs;

  // Now proceed breadth-first, unioning values together.
  while (!Worklist.empty()) {
    Value *Val = Worklist.pop_back_val();
    Value *Leader = ECs.getOrInsertLeaderValue(Val);

    if (Visited.count(Val))
      continue;
    Visited.insert(Val);

    // Non-instructions terminate a chain successfully.
    if (!isa<Instruction>(Val))
      continue;
    Instruction *I = cast<Instruction>(Val);

    // If we encounter a type that is larger than 64 bits, we can't represent
    // it so bail out.
    if (DB.getDemandedBits(I).getBitWidth() > 64)
      return MapVector<Instruction *, uint64_t>();

    uint64_t V = DB.getDemandedBits(I).getZExtValue();
    DBits[Leader] |= V;
    DBits[I] = V;

    // Casts, loads and instructions outside of our range terminate a chain
    // successfully.
    if (isa<SExtInst>(I) || isa<ZExtInst>(I) || isa<LoadInst>(I) ||
        !InstructionSet.count(I))
      continue;

    // Unsafe casts terminate a chain unsuccessfully. We can't do anything
    // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to
    // transform anything that relies on them.
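    // Setting all demanded bits for the leader poisons the whole equivalence
    // class, so none of its members is assigned a reduced width below.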
    if (isa<BitCastInst>(I) || isa<PtrToIntInst>(I) || isa<IntToPtrInst>(I) ||
        !I->getType()->isIntegerTy()) {
      DBits[Leader] |= ~0ULL;
      continue;
    }

    // We don't modify the types of PHIs. Reductions will already have been
    // truncated if possible, and inductions' sizes will have been chosen by
    // indvars.
    if (isa<PHINode>(I))
      continue;

    if (DBits[Leader] == ~0ULL)
      // All bits demanded, no point continuing.
      continue;

    for (Value *O : cast<User>(I)->operands()) {
      ECs.unionSets(Leader, O);
      Worklist.push_back(O);
    }
  }

  // Now we've discovered all values, walk them to see if there are
  // any users we didn't see. If there are, we can't optimize that
  // chain.
  for (auto &I : DBits)
    for (auto *U : I.first->users())
      if (U->getType()->isIntegerTy() && DBits.count(U) == 0)
        DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL;

  for (auto I = ECs.begin(), E = ECs.end(); I != E; ++I) {
    uint64_t LeaderDemandedBits = 0;
    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI)
      LeaderDemandedBits |= DBits[*MI];

    uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) -
                     llvm::countLeadingZeros(LeaderDemandedBits);
    // Round up to a power of 2.
    if (!isPowerOf2_64((uint64_t)MinBW))
      MinBW = NextPowerOf2(MinBW);

    // We don't modify the types of PHIs. Reductions will already have been
    // truncated if possible, and inductions' sizes will have been chosen by
    // indvars.
    // If we are required to shrink a PHI, abandon this entire equivalence
    // class.
    bool Abort = false;
    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI)
      if (isa<PHINode>(*MI) &&
          MinBW < (*MI)->getType()->getScalarSizeInBits()) {
        Abort = true;
        break;
      }
    if (Abort)
      continue;

    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) {
      if (!isa<Instruction>(*MI))
        continue;
      Type *Ty = (*MI)->getType();
      if (Roots.count(*MI))
        Ty = cast<Instruction>(*MI)->getOperand(0)->getType();
      if (MinBW < Ty->getScalarSizeInBits())
        MinBWs[cast<Instruction>(*MI)] = MinBW;
    }
  }

  return MinBWs;
}

/// \returns \p Inst after propagating metadata from \p VL.
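/// For each metadata kind that is known to be safe to merge (tbaa,
/// alias.scope, noalias, fpmath, nontemporal, invariant.load), the nodes of
/// all instructions in \p VL are combined conservatively and the result is
/// attached to \p Inst; a kind missing from any instruction in \p VL ends up
/// dropped.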
Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) {
  Instruction *I0 = cast<Instruction>(VL[0]);
  SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata;
  I0->getAllMetadataOtherThanDebugLoc(Metadata);

  for (auto Kind :
       {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
        LLVMContext::MD_noalias, LLVMContext::MD_fpmath,
        LLVMContext::MD_nontemporal, LLVMContext::MD_invariant_load}) {
    MDNode *MD = I0->getMetadata(Kind);

    for (int J = 1, E = VL.size(); MD && J != E; ++J) {
      const Instruction *IJ = cast<Instruction>(VL[J]);
      MDNode *IMD = IJ->getMetadata(Kind);
      switch (Kind) {
      case LLVMContext::MD_tbaa:
        MD = MDNode::getMostGenericTBAA(MD, IMD);
        break;
      case LLVMContext::MD_alias_scope:
        MD = MDNode::getMostGenericAliasScope(MD, IMD);
        break;
      case LLVMContext::MD_fpmath:
        MD = MDNode::getMostGenericFPMath(MD, IMD);
        break;
      case LLVMContext::MD_noalias:
      case LLVMContext::MD_nontemporal:
      case LLVMContext::MD_invariant_load:
        MD = MDNode::intersect(MD, IMD);
        break;
      default:
        llvm_unreachable("unhandled metadata");
      }
    }

    Inst->setMetadata(Kind, MD);
  }

  return Inst;
}

Constant *llvm::createInterleaveMask(IRBuilder<> &Builder, unsigned VF,
                                     unsigned NumVecs) {
  SmallVector<Constant *, 16> Mask;
  for (unsigned i = 0; i < VF; i++)
    for (unsigned j = 0; j < NumVecs; j++)
      Mask.push_back(Builder.getInt32(j * VF + i));

  return ConstantVector::get(Mask);
}

Constant *llvm::createStrideMask(IRBuilder<> &Builder, unsigned Start,
                                 unsigned Stride, unsigned VF) {
  SmallVector<Constant *, 16> Mask;
  for (unsigned i = 0; i < VF; i++)
    Mask.push_back(Builder.getInt32(Start + i * Stride));

  return ConstantVector::get(Mask);
}

Constant *llvm::createSequentialMask(IRBuilder<> &Builder, unsigned Start,
                                     unsigned NumInts, unsigned NumUndefs) {
  SmallVector<Constant *, 16> Mask;
  for (unsigned i = 0; i < NumInts; i++)
    Mask.push_back(Builder.getInt32(Start + i));

  Constant *Undef = UndefValue::get(Builder.getInt32Ty());
  for (unsigned i = 0; i < NumUndefs; i++)
    Mask.push_back(Undef);

  return ConstantVector::get(Mask);
}

/// A helper function for concatenating vectors. This function concatenates two
/// vectors having the same element type. If the second vector has fewer
/// elements than the first, it is padded with undefs.
static Value *concatenateTwoVectors(IRBuilder<> &Builder, Value *V1,
                                    Value *V2) {
  VectorType *VecTy1 = dyn_cast<VectorType>(V1->getType());
  VectorType *VecTy2 = dyn_cast<VectorType>(V2->getType());
  assert(VecTy1 && VecTy2 &&
         VecTy1->getScalarType() == VecTy2->getScalarType() &&
         "Expect two vectors with the same element type");

  unsigned NumElts1 = VecTy1->getNumElements();
  unsigned NumElts2 = VecTy2->getNumElements();
  assert(NumElts1 >= NumElts2 &&
         "Unexpected: the first vector has fewer elements than the second");

  if (NumElts1 > NumElts2) {
    // Extend with UNDEFs.
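    // Widen V2 to NumElts1 lanes by shuffling it with an undef vector: its
    // original NumElts2 elements are kept and the tail is filled with undefs.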
    Constant *ExtMask =
        createSequentialMask(Builder, 0, NumElts2, NumElts1 - NumElts2);
    V2 = Builder.CreateShuffleVector(V2, UndefValue::get(VecTy2), ExtMask);
  }

  Constant *Mask = createSequentialMask(Builder, 0, NumElts1 + NumElts2, 0);
  return Builder.CreateShuffleVector(V1, V2, Mask);
}

Value *llvm::concatenateVectors(IRBuilder<> &Builder, ArrayRef<Value *> Vecs) {
  unsigned NumVecs = Vecs.size();
  assert(NumVecs > 1 && "Should be at least two vectors");

  SmallVector<Value *, 8> ResList;
  ResList.append(Vecs.begin(), Vecs.end());
  do {
    SmallVector<Value *, 8> TmpList;
    for (unsigned i = 0; i < NumVecs - 1; i += 2) {
      Value *V0 = ResList[i], *V1 = ResList[i + 1];
      assert((V0->getType() == V1->getType() || i == NumVecs - 2) &&
             "Only the last vector may have a different type");

      TmpList.push_back(concatenateTwoVectors(Builder, V0, V1));
    }

    // Push the last vector if the total number of vectors is odd.
    if (NumVecs % 2 != 0)
      TmpList.push_back(ResList[NumVecs - 1]);

    ResList = TmpList;
    NumVecs = ResList.size();
  } while (NumVecs > 1);

  return ResList[0];
}