1 //===----------- VectorUtils.cpp - Vectorizer utility functions -----------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file defines vectorizer utilities. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/ADT/EquivalenceClasses.h" 15 #include "llvm/Analysis/DemandedBits.h" 16 #include "llvm/Analysis/LoopInfo.h" 17 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 18 #include "llvm/Analysis/ScalarEvolution.h" 19 #include "llvm/Analysis/TargetTransformInfo.h" 20 #include "llvm/Analysis/VectorUtils.h" 21 #include "llvm/IR/GetElementPtrTypeIterator.h" 22 #include "llvm/IR/PatternMatch.h" 23 #include "llvm/IR/Value.h" 24 #include "llvm/IR/Constants.h" 25 26 using namespace llvm; 27 using namespace llvm::PatternMatch; 28 29 /// \brief Identify if the intrinsic is trivially vectorizable. 30 /// This method returns true if the intrinsic's argument types are all 31 /// scalars for the scalar form of the intrinsic and all vectors for 32 /// the vector form of the intrinsic. 33 bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { 34 switch (ID) { 35 case Intrinsic::sqrt: 36 case Intrinsic::sin: 37 case Intrinsic::cos: 38 case Intrinsic::exp: 39 case Intrinsic::exp2: 40 case Intrinsic::log: 41 case Intrinsic::log10: 42 case Intrinsic::log2: 43 case Intrinsic::fabs: 44 case Intrinsic::minnum: 45 case Intrinsic::maxnum: 46 case Intrinsic::copysign: 47 case Intrinsic::floor: 48 case Intrinsic::ceil: 49 case Intrinsic::trunc: 50 case Intrinsic::rint: 51 case Intrinsic::nearbyint: 52 case Intrinsic::round: 53 case Intrinsic::bswap: 54 case Intrinsic::ctpop: 55 case Intrinsic::pow: 56 case Intrinsic::fma: 57 case Intrinsic::fmuladd: 58 case Intrinsic::ctlz: 59 case Intrinsic::cttz: 60 case Intrinsic::powi: 61 return true; 62 default: 63 return false; 64 } 65 } 66 67 /// \brief Identifies if the intrinsic has a scalar operand. It check for 68 /// ctlz,cttz and powi special intrinsics whose argument is scalar. 69 bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, 70 unsigned ScalarOpdIdx) { 71 switch (ID) { 72 case Intrinsic::ctlz: 73 case Intrinsic::cttz: 74 case Intrinsic::powi: 75 return (ScalarOpdIdx == 1); 76 default: 77 return false; 78 } 79 } 80 81 /// \brief Check call has a unary float signature 82 /// It checks following: 83 /// a) call should have a single argument 84 /// b) argument type should be floating point type 85 /// c) call instruction type and argument type should be same 86 /// d) call should only reads memory. 87 /// If all these condition is met then return ValidIntrinsicID 88 /// else return not_intrinsic. 89 Intrinsic::ID 90 llvm::checkUnaryFloatSignature(const CallInst &I, 91 Intrinsic::ID ValidIntrinsicID) { 92 if (I.getNumArgOperands() != 1 || 93 !I.getArgOperand(0)->getType()->isFloatingPointTy() || 94 I.getType() != I.getArgOperand(0)->getType() || !I.onlyReadsMemory()) 95 return Intrinsic::not_intrinsic; 96 97 return ValidIntrinsicID; 98 } 99 100 /// \brief Check call has a binary float signature 101 /// It checks following: 102 /// a) call should have 2 arguments. 103 /// b) arguments type should be floating point type 104 /// c) call instruction type and arguments type should be same 105 /// d) call should only reads memory. 106 /// If all these condition is met then return ValidIntrinsicID 107 /// else return not_intrinsic. 108 Intrinsic::ID 109 llvm::checkBinaryFloatSignature(const CallInst &I, 110 Intrinsic::ID ValidIntrinsicID) { 111 if (I.getNumArgOperands() != 2 || 112 !I.getArgOperand(0)->getType()->isFloatingPointTy() || 113 !I.getArgOperand(1)->getType()->isFloatingPointTy() || 114 I.getType() != I.getArgOperand(0)->getType() || 115 I.getType() != I.getArgOperand(1)->getType() || !I.onlyReadsMemory()) 116 return Intrinsic::not_intrinsic; 117 118 return ValidIntrinsicID; 119 } 120 121 /// \brief Returns intrinsic ID for call. 122 /// For the input call instruction it finds mapping intrinsic and returns 123 /// its ID, in case it does not found it return not_intrinsic. 124 Intrinsic::ID llvm::getIntrinsicIDForCall(CallInst *CI, 125 const TargetLibraryInfo *TLI) { 126 // If we have an intrinsic call, check if it is trivially vectorizable. 127 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { 128 Intrinsic::ID ID = II->getIntrinsicID(); 129 if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start || 130 ID == Intrinsic::lifetime_end || ID == Intrinsic::assume) 131 return ID; 132 return Intrinsic::not_intrinsic; 133 } 134 135 if (!TLI) 136 return Intrinsic::not_intrinsic; 137 138 LibFunc::Func Func; 139 Function *F = CI->getCalledFunction(); 140 // We're going to make assumptions on the semantics of the functions, check 141 // that the target knows that it's available in this environment and it does 142 // not have local linkage. 143 if (!F || F->hasLocalLinkage() || !TLI->getLibFunc(F->getName(), Func)) 144 return Intrinsic::not_intrinsic; 145 146 // Otherwise check if we have a call to a function that can be turned into a 147 // vector intrinsic. 148 switch (Func) { 149 default: 150 break; 151 case LibFunc::sin: 152 case LibFunc::sinf: 153 case LibFunc::sinl: 154 return checkUnaryFloatSignature(*CI, Intrinsic::sin); 155 case LibFunc::cos: 156 case LibFunc::cosf: 157 case LibFunc::cosl: 158 return checkUnaryFloatSignature(*CI, Intrinsic::cos); 159 case LibFunc::exp: 160 case LibFunc::expf: 161 case LibFunc::expl: 162 return checkUnaryFloatSignature(*CI, Intrinsic::exp); 163 case LibFunc::exp2: 164 case LibFunc::exp2f: 165 case LibFunc::exp2l: 166 return checkUnaryFloatSignature(*CI, Intrinsic::exp2); 167 case LibFunc::log: 168 case LibFunc::logf: 169 case LibFunc::logl: 170 return checkUnaryFloatSignature(*CI, Intrinsic::log); 171 case LibFunc::log10: 172 case LibFunc::log10f: 173 case LibFunc::log10l: 174 return checkUnaryFloatSignature(*CI, Intrinsic::log10); 175 case LibFunc::log2: 176 case LibFunc::log2f: 177 case LibFunc::log2l: 178 return checkUnaryFloatSignature(*CI, Intrinsic::log2); 179 case LibFunc::fabs: 180 case LibFunc::fabsf: 181 case LibFunc::fabsl: 182 return checkUnaryFloatSignature(*CI, Intrinsic::fabs); 183 case LibFunc::fmin: 184 case LibFunc::fminf: 185 case LibFunc::fminl: 186 return checkBinaryFloatSignature(*CI, Intrinsic::minnum); 187 case LibFunc::fmax: 188 case LibFunc::fmaxf: 189 case LibFunc::fmaxl: 190 return checkBinaryFloatSignature(*CI, Intrinsic::maxnum); 191 case LibFunc::copysign: 192 case LibFunc::copysignf: 193 case LibFunc::copysignl: 194 return checkBinaryFloatSignature(*CI, Intrinsic::copysign); 195 case LibFunc::floor: 196 case LibFunc::floorf: 197 case LibFunc::floorl: 198 return checkUnaryFloatSignature(*CI, Intrinsic::floor); 199 case LibFunc::ceil: 200 case LibFunc::ceilf: 201 case LibFunc::ceill: 202 return checkUnaryFloatSignature(*CI, Intrinsic::ceil); 203 case LibFunc::trunc: 204 case LibFunc::truncf: 205 case LibFunc::truncl: 206 return checkUnaryFloatSignature(*CI, Intrinsic::trunc); 207 case LibFunc::rint: 208 case LibFunc::rintf: 209 case LibFunc::rintl: 210 return checkUnaryFloatSignature(*CI, Intrinsic::rint); 211 case LibFunc::nearbyint: 212 case LibFunc::nearbyintf: 213 case LibFunc::nearbyintl: 214 return checkUnaryFloatSignature(*CI, Intrinsic::nearbyint); 215 case LibFunc::round: 216 case LibFunc::roundf: 217 case LibFunc::roundl: 218 return checkUnaryFloatSignature(*CI, Intrinsic::round); 219 case LibFunc::pow: 220 case LibFunc::powf: 221 case LibFunc::powl: 222 return checkBinaryFloatSignature(*CI, Intrinsic::pow); 223 } 224 225 return Intrinsic::not_intrinsic; 226 } 227 228 /// \brief Find the operand of the GEP that should be checked for consecutive 229 /// stores. This ignores trailing indices that have no effect on the final 230 /// pointer. 231 unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) { 232 const DataLayout &DL = Gep->getModule()->getDataLayout(); 233 unsigned LastOperand = Gep->getNumOperands() - 1; 234 unsigned GEPAllocSize = DL.getTypeAllocSize( 235 cast<PointerType>(Gep->getType()->getScalarType())->getElementType()); 236 237 // Walk backwards and try to peel off zeros. 238 while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) { 239 // Find the type we're currently indexing into. 240 gep_type_iterator GEPTI = gep_type_begin(Gep); 241 std::advance(GEPTI, LastOperand - 1); 242 243 // If it's a type with the same allocation size as the result of the GEP we 244 // can peel off the zero index. 245 if (DL.getTypeAllocSize(*GEPTI) != GEPAllocSize) 246 break; 247 --LastOperand; 248 } 249 250 return LastOperand; 251 } 252 253 /// \brief If the argument is a GEP, then returns the operand identified by 254 /// getGEPInductionOperand. However, if there is some other non-loop-invariant 255 /// operand, it returns that instead. 256 Value *llvm::stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { 257 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 258 if (!GEP) 259 return Ptr; 260 261 unsigned InductionOperand = getGEPInductionOperand(GEP); 262 263 // Check that all of the gep indices are uniform except for our induction 264 // operand. 265 for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i) 266 if (i != InductionOperand && 267 !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp)) 268 return Ptr; 269 return GEP->getOperand(InductionOperand); 270 } 271 272 /// \brief If a value has only one user that is a CastInst, return it. 273 Value *llvm::getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) { 274 Value *UniqueCast = nullptr; 275 for (User *U : Ptr->users()) { 276 CastInst *CI = dyn_cast<CastInst>(U); 277 if (CI && CI->getType() == Ty) { 278 if (!UniqueCast) 279 UniqueCast = CI; 280 else 281 return nullptr; 282 } 283 } 284 return UniqueCast; 285 } 286 287 /// \brief Get the stride of a pointer access in a loop. Looks for symbolic 288 /// strides "a[i*stride]". Returns the symbolic stride, or null otherwise. 289 Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { 290 auto *PtrTy = dyn_cast<PointerType>(Ptr->getType()); 291 if (!PtrTy || PtrTy->isAggregateType()) 292 return nullptr; 293 294 // Try to remove a gep instruction to make the pointer (actually index at this 295 // point) easier analyzable. If OrigPtr is equal to Ptr we are analzying the 296 // pointer, otherwise, we are analyzing the index. 297 Value *OrigPtr = Ptr; 298 299 // The size of the pointer access. 300 int64_t PtrAccessSize = 1; 301 302 Ptr = stripGetElementPtr(Ptr, SE, Lp); 303 const SCEV *V = SE->getSCEV(Ptr); 304 305 if (Ptr != OrigPtr) 306 // Strip off casts. 307 while (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) 308 V = C->getOperand(); 309 310 const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V); 311 if (!S) 312 return nullptr; 313 314 V = S->getStepRecurrence(*SE); 315 if (!V) 316 return nullptr; 317 318 // Strip off the size of access multiplication if we are still analyzing the 319 // pointer. 320 if (OrigPtr == Ptr) { 321 const DataLayout &DL = Lp->getHeader()->getModule()->getDataLayout(); 322 DL.getTypeAllocSize(PtrTy->getElementType()); 323 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) { 324 if (M->getOperand(0)->getSCEVType() != scConstant) 325 return nullptr; 326 327 const APInt &APStepVal = 328 cast<SCEVConstant>(M->getOperand(0))->getValue()->getValue(); 329 330 // Huge step value - give up. 331 if (APStepVal.getBitWidth() > 64) 332 return nullptr; 333 334 int64_t StepVal = APStepVal.getSExtValue(); 335 if (PtrAccessSize != StepVal) 336 return nullptr; 337 V = M->getOperand(1); 338 } 339 } 340 341 // Strip off casts. 342 Type *StripedOffRecurrenceCast = nullptr; 343 if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) { 344 StripedOffRecurrenceCast = C->getType(); 345 V = C->getOperand(); 346 } 347 348 // Look for the loop invariant symbolic value. 349 const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V); 350 if (!U) 351 return nullptr; 352 353 Value *Stride = U->getValue(); 354 if (!Lp->isLoopInvariant(Stride)) 355 return nullptr; 356 357 // If we have stripped off the recurrence cast we have to make sure that we 358 // return the value that is used in this loop so that we can replace it later. 359 if (StripedOffRecurrenceCast) 360 Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast); 361 362 return Stride; 363 } 364 365 /// \brief Given a vector and an element number, see if the scalar value is 366 /// already around as a register, for example if it were inserted then extracted 367 /// from the vector. 368 Value *llvm::findScalarElement(Value *V, unsigned EltNo) { 369 assert(V->getType()->isVectorTy() && "Not looking at a vector?"); 370 VectorType *VTy = cast<VectorType>(V->getType()); 371 unsigned Width = VTy->getNumElements(); 372 if (EltNo >= Width) // Out of range access. 373 return UndefValue::get(VTy->getElementType()); 374 375 if (Constant *C = dyn_cast<Constant>(V)) 376 return C->getAggregateElement(EltNo); 377 378 if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) { 379 // If this is an insert to a variable element, we don't know what it is. 380 if (!isa<ConstantInt>(III->getOperand(2))) 381 return nullptr; 382 unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue(); 383 384 // If this is an insert to the element we are looking for, return the 385 // inserted value. 386 if (EltNo == IIElt) 387 return III->getOperand(1); 388 389 // Otherwise, the insertelement doesn't modify the value, recurse on its 390 // vector input. 391 return findScalarElement(III->getOperand(0), EltNo); 392 } 393 394 if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) { 395 unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements(); 396 int InEl = SVI->getMaskValue(EltNo); 397 if (InEl < 0) 398 return UndefValue::get(VTy->getElementType()); 399 if (InEl < (int)LHSWidth) 400 return findScalarElement(SVI->getOperand(0), InEl); 401 return findScalarElement(SVI->getOperand(1), InEl - LHSWidth); 402 } 403 404 // Extract a value from a vector add operation with a constant zero. 405 Value *Val = nullptr; Constant *Con = nullptr; 406 if (match(V, m_Add(m_Value(Val), m_Constant(Con)))) 407 if (Constant *Elt = Con->getAggregateElement(EltNo)) 408 if (Elt->isNullValue()) 409 return findScalarElement(Val, EltNo); 410 411 // Otherwise, we don't know. 412 return nullptr; 413 } 414 415 /// \brief Get splat value if the input is a splat vector or return nullptr. 416 /// This function is not fully general. It checks only 2 cases: 417 /// the input value is (1) a splat constants vector or (2) a sequence 418 /// of instructions that broadcast a single value into a vector. 419 /// 420 llvm::Value *llvm::getSplatValue(Value *V) { 421 if (auto *CV = dyn_cast<ConstantDataVector>(V)) 422 return CV->getSplatValue(); 423 424 auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V); 425 if (!ShuffleInst) 426 return nullptr; 427 // All-zero (or undef) shuffle mask elements. 428 for (int MaskElt : ShuffleInst->getShuffleMask()) 429 if (MaskElt != 0 && MaskElt != -1) 430 return nullptr; 431 // The first shuffle source is 'insertelement' with index 0. 432 auto *InsertEltInst = 433 dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0)); 434 if (!InsertEltInst || !isa<ConstantInt>(InsertEltInst->getOperand(2)) || 435 !cast<ConstantInt>(InsertEltInst->getOperand(2))->isNullValue()) 436 return nullptr; 437 438 return InsertEltInst->getOperand(1); 439 } 440 441 DenseMap<Instruction *, uint64_t> 442 llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB, 443 const TargetTransformInfo *TTI) { 444 445 // DemandedBits will give us every value's live-out bits. But we want 446 // to ensure no extra casts would need to be inserted, so every DAG 447 // of connected values must have the same minimum bitwidth. 448 EquivalenceClasses<Value *> ECs; 449 SmallVector<Value *, 16> Worklist; 450 SmallPtrSet<Value *, 4> Roots; 451 SmallPtrSet<Value *, 16> Visited; 452 DenseMap<Value *, uint64_t> DBits; 453 SmallPtrSet<Instruction *, 4> InstructionSet; 454 DenseMap<Instruction *, uint64_t> MinBWs; 455 456 // Determine the roots. We work bottom-up, from truncs or icmps. 457 bool SeenExtFromIllegalType = false; 458 for (auto *BB : Blocks) 459 for (auto &I : *BB) { 460 InstructionSet.insert(&I); 461 462 if (TTI && (isa<ZExtInst>(&I) || isa<SExtInst>(&I)) && 463 !TTI->isTypeLegal(I.getOperand(0)->getType())) 464 SeenExtFromIllegalType = true; 465 466 // Only deal with non-vector integers up to 64-bits wide. 467 if ((isa<TruncInst>(&I) || isa<ICmpInst>(&I)) && 468 !I.getType()->isVectorTy() && 469 I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) { 470 // Don't make work for ourselves. If we know the loaded type is legal, 471 // don't add it to the worklist. 472 if (TTI && isa<TruncInst>(&I) && TTI->isTypeLegal(I.getType())) 473 continue; 474 475 Worklist.push_back(&I); 476 Roots.insert(&I); 477 } 478 } 479 // Early exit. 480 if (Worklist.empty() || (TTI && !SeenExtFromIllegalType)) 481 return MinBWs; 482 483 // Now proceed breadth-first, unioning values together. 484 while (!Worklist.empty()) { 485 Value *Val = Worklist.pop_back_val(); 486 Value *Leader = ECs.getOrInsertLeaderValue(Val); 487 488 if (Visited.count(Val)) 489 continue; 490 Visited.insert(Val); 491 492 // Non-instructions terminate a chain successfully. 493 if (!isa<Instruction>(Val)) 494 continue; 495 Instruction *I = cast<Instruction>(Val); 496 497 // If we encounter a type that is larger than 64 bits, we can't represent 498 // it so bail out. 499 if (DB.getDemandedBits(I).getBitWidth() > 64) 500 return DenseMap<Instruction *, uint64_t>(); 501 502 uint64_t V = DB.getDemandedBits(I).getZExtValue(); 503 DBits[Leader] |= V; 504 505 // Casts, loads and instructions outside of our range terminate a chain 506 // successfully. 507 if (isa<SExtInst>(I) || isa<ZExtInst>(I) || isa<LoadInst>(I) || 508 !InstructionSet.count(I)) 509 continue; 510 511 // Unsafe casts terminate a chain unsuccessfully. We can't do anything 512 // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to 513 // transform anything that relies on them. 514 if (isa<BitCastInst>(I) || isa<PtrToIntInst>(I) || isa<IntToPtrInst>(I) || 515 !I->getType()->isIntegerTy()) { 516 DBits[Leader] |= ~0ULL; 517 continue; 518 } 519 520 // We don't modify the types of PHIs. Reductions will already have been 521 // truncated if possible, and inductions' sizes will have been chosen by 522 // indvars. 523 if (isa<PHINode>(I)) 524 continue; 525 526 if (DBits[Leader] == ~0ULL) 527 // All bits demanded, no point continuing. 528 continue; 529 530 for (Value *O : cast<User>(I)->operands()) { 531 ECs.unionSets(Leader, O); 532 Worklist.push_back(O); 533 } 534 } 535 536 // Now we've discovered all values, walk them to see if there are 537 // any users we didn't see. If there are, we can't optimize that 538 // chain. 539 for (auto &I : DBits) 540 for (auto *U : I.first->users()) 541 if (U->getType()->isIntegerTy() && DBits.count(U) == 0) 542 DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL; 543 544 for (auto I = ECs.begin(), E = ECs.end(); I != E; ++I) { 545 uint64_t LeaderDemandedBits = 0; 546 for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) 547 LeaderDemandedBits |= DBits[*MI]; 548 549 uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) - 550 llvm::countLeadingZeros(LeaderDemandedBits); 551 // Round up to a power of 2 552 if (!isPowerOf2_64((uint64_t)MinBW)) 553 MinBW = NextPowerOf2(MinBW); 554 for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) { 555 if (!isa<Instruction>(*MI)) 556 continue; 557 Type *Ty = (*MI)->getType(); 558 if (Roots.count(*MI)) 559 Ty = cast<Instruction>(*MI)->getOperand(0)->getType(); 560 if (MinBW < Ty->getScalarSizeInBits()) 561 MinBWs[cast<Instruction>(*MI)] = MinBW; 562 } 563 } 564 565 return MinBWs; 566 } 567