1 //===----------- VectorUtils.cpp - Vectorizer utility functions -----------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file defines vectorizer utilities. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/ADT/EquivalenceClasses.h" 15 #include "llvm/Analysis/DemandedBits.h" 16 #include "llvm/Analysis/LoopInfo.h" 17 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 18 #include "llvm/Analysis/ScalarEvolution.h" 19 #include "llvm/Analysis/TargetTransformInfo.h" 20 #include "llvm/Analysis/VectorUtils.h" 21 #include "llvm/IR/GetElementPtrTypeIterator.h" 22 #include "llvm/IR/PatternMatch.h" 23 #include "llvm/IR/Value.h" 24 #include "llvm/IR/Constants.h" 25 26 using namespace llvm; 27 using namespace llvm::PatternMatch; 28 29 /// \brief Identify if the intrinsic is trivially vectorizable. 30 /// This method returns true if the intrinsic's argument types are all 31 /// scalars for the scalar form of the intrinsic and all vectors for 32 /// the vector form of the intrinsic. 33 bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { 34 switch (ID) { 35 case Intrinsic::sqrt: 36 case Intrinsic::sin: 37 case Intrinsic::cos: 38 case Intrinsic::exp: 39 case Intrinsic::exp2: 40 case Intrinsic::log: 41 case Intrinsic::log10: 42 case Intrinsic::log2: 43 case Intrinsic::fabs: 44 case Intrinsic::minnum: 45 case Intrinsic::maxnum: 46 case Intrinsic::copysign: 47 case Intrinsic::floor: 48 case Intrinsic::ceil: 49 case Intrinsic::trunc: 50 case Intrinsic::rint: 51 case Intrinsic::nearbyint: 52 case Intrinsic::round: 53 case Intrinsic::bswap: 54 case Intrinsic::ctpop: 55 case Intrinsic::pow: 56 case Intrinsic::fma: 57 case Intrinsic::fmuladd: 58 case Intrinsic::ctlz: 59 case Intrinsic::cttz: 60 case Intrinsic::powi: 61 return true; 62 default: 63 return false; 64 } 65 } 66 67 /// \brief Identifies if the intrinsic has a scalar operand. It check for 68 /// ctlz,cttz and powi special intrinsics whose argument is scalar. 69 bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, 70 unsigned ScalarOpdIdx) { 71 switch (ID) { 72 case Intrinsic::ctlz: 73 case Intrinsic::cttz: 74 case Intrinsic::powi: 75 return (ScalarOpdIdx == 1); 76 default: 77 return false; 78 } 79 } 80 81 /// \brief Check call has a unary float signature 82 /// It checks following: 83 /// a) call should have a single argument 84 /// b) argument type should be floating point type 85 /// c) call instruction type and argument type should be same 86 /// d) call should only reads memory. 87 /// If all these condition is met then return ValidIntrinsicID 88 /// else return not_intrinsic. 89 Intrinsic::ID 90 llvm::checkUnaryFloatSignature(const CallInst &I, 91 Intrinsic::ID ValidIntrinsicID) { 92 if (I.getNumArgOperands() != 1 || 93 !I.getArgOperand(0)->getType()->isFloatingPointTy() || 94 I.getType() != I.getArgOperand(0)->getType() || !I.onlyReadsMemory()) 95 return Intrinsic::not_intrinsic; 96 97 return ValidIntrinsicID; 98 } 99 100 /// \brief Check call has a binary float signature 101 /// It checks following: 102 /// a) call should have 2 arguments. 103 /// b) arguments type should be floating point type 104 /// c) call instruction type and arguments type should be same 105 /// d) call should only reads memory. 106 /// If all these condition is met then return ValidIntrinsicID 107 /// else return not_intrinsic. 108 Intrinsic::ID 109 llvm::checkBinaryFloatSignature(const CallInst &I, 110 Intrinsic::ID ValidIntrinsicID) { 111 if (I.getNumArgOperands() != 2 || 112 !I.getArgOperand(0)->getType()->isFloatingPointTy() || 113 !I.getArgOperand(1)->getType()->isFloatingPointTy() || 114 I.getType() != I.getArgOperand(0)->getType() || 115 I.getType() != I.getArgOperand(1)->getType() || !I.onlyReadsMemory()) 116 return Intrinsic::not_intrinsic; 117 118 return ValidIntrinsicID; 119 } 120 121 /// \brief Returns intrinsic ID for call. 122 /// For the input call instruction it finds mapping intrinsic and returns 123 /// its ID, in case it does not found it return not_intrinsic. 124 Intrinsic::ID llvm::getIntrinsicIDForCall(CallInst *CI, 125 const TargetLibraryInfo *TLI) { 126 // If we have an intrinsic call, check if it is trivially vectorizable. 127 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { 128 Intrinsic::ID ID = II->getIntrinsicID(); 129 if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start || 130 ID == Intrinsic::lifetime_end || ID == Intrinsic::assume) 131 return ID; 132 return Intrinsic::not_intrinsic; 133 } 134 135 if (!TLI) 136 return Intrinsic::not_intrinsic; 137 138 LibFunc::Func Func; 139 Function *F = CI->getCalledFunction(); 140 // We're going to make assumptions on the semantics of the functions, check 141 // that the target knows that it's available in this environment and it does 142 // not have local linkage. 143 if (!F || F->hasLocalLinkage() || !TLI->getLibFunc(F->getName(), Func)) 144 return Intrinsic::not_intrinsic; 145 146 // Otherwise check if we have a call to a function that can be turned into a 147 // vector intrinsic. 148 switch (Func) { 149 default: 150 break; 151 case LibFunc::sin: 152 case LibFunc::sinf: 153 case LibFunc::sinl: 154 return checkUnaryFloatSignature(*CI, Intrinsic::sin); 155 case LibFunc::cos: 156 case LibFunc::cosf: 157 case LibFunc::cosl: 158 return checkUnaryFloatSignature(*CI, Intrinsic::cos); 159 case LibFunc::exp: 160 case LibFunc::expf: 161 case LibFunc::expl: 162 return checkUnaryFloatSignature(*CI, Intrinsic::exp); 163 case LibFunc::exp2: 164 case LibFunc::exp2f: 165 case LibFunc::exp2l: 166 return checkUnaryFloatSignature(*CI, Intrinsic::exp2); 167 case LibFunc::log: 168 case LibFunc::logf: 169 case LibFunc::logl: 170 return checkUnaryFloatSignature(*CI, Intrinsic::log); 171 case LibFunc::log10: 172 case LibFunc::log10f: 173 case LibFunc::log10l: 174 return checkUnaryFloatSignature(*CI, Intrinsic::log10); 175 case LibFunc::log2: 176 case LibFunc::log2f: 177 case LibFunc::log2l: 178 return checkUnaryFloatSignature(*CI, Intrinsic::log2); 179 case LibFunc::fabs: 180 case LibFunc::fabsf: 181 case LibFunc::fabsl: 182 return checkUnaryFloatSignature(*CI, Intrinsic::fabs); 183 case LibFunc::fmin: 184 case LibFunc::fminf: 185 case LibFunc::fminl: 186 return checkBinaryFloatSignature(*CI, Intrinsic::minnum); 187 case LibFunc::fmax: 188 case LibFunc::fmaxf: 189 case LibFunc::fmaxl: 190 return checkBinaryFloatSignature(*CI, Intrinsic::maxnum); 191 case LibFunc::copysign: 192 case LibFunc::copysignf: 193 case LibFunc::copysignl: 194 return checkBinaryFloatSignature(*CI, Intrinsic::copysign); 195 case LibFunc::floor: 196 case LibFunc::floorf: 197 case LibFunc::floorl: 198 return checkUnaryFloatSignature(*CI, Intrinsic::floor); 199 case LibFunc::ceil: 200 case LibFunc::ceilf: 201 case LibFunc::ceill: 202 return checkUnaryFloatSignature(*CI, Intrinsic::ceil); 203 case LibFunc::trunc: 204 case LibFunc::truncf: 205 case LibFunc::truncl: 206 return checkUnaryFloatSignature(*CI, Intrinsic::trunc); 207 case LibFunc::rint: 208 case LibFunc::rintf: 209 case LibFunc::rintl: 210 return checkUnaryFloatSignature(*CI, Intrinsic::rint); 211 case LibFunc::nearbyint: 212 case LibFunc::nearbyintf: 213 case LibFunc::nearbyintl: 214 return checkUnaryFloatSignature(*CI, Intrinsic::nearbyint); 215 case LibFunc::round: 216 case LibFunc::roundf: 217 case LibFunc::roundl: 218 return checkUnaryFloatSignature(*CI, Intrinsic::round); 219 case LibFunc::pow: 220 case LibFunc::powf: 221 case LibFunc::powl: 222 return checkBinaryFloatSignature(*CI, Intrinsic::pow); 223 } 224 225 return Intrinsic::not_intrinsic; 226 } 227 228 /// \brief Find the operand of the GEP that should be checked for consecutive 229 /// stores. This ignores trailing indices that have no effect on the final 230 /// pointer. 231 unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) { 232 const DataLayout &DL = Gep->getModule()->getDataLayout(); 233 unsigned LastOperand = Gep->getNumOperands() - 1; 234 unsigned GEPAllocSize = DL.getTypeAllocSize(Gep->getResultElementType()); 235 236 // Walk backwards and try to peel off zeros. 237 while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) { 238 // Find the type we're currently indexing into. 239 gep_type_iterator GEPTI = gep_type_begin(Gep); 240 std::advance(GEPTI, LastOperand - 1); 241 242 // If it's a type with the same allocation size as the result of the GEP we 243 // can peel off the zero index. 244 if (DL.getTypeAllocSize(*GEPTI) != GEPAllocSize) 245 break; 246 --LastOperand; 247 } 248 249 return LastOperand; 250 } 251 252 /// \brief If the argument is a GEP, then returns the operand identified by 253 /// getGEPInductionOperand. However, if there is some other non-loop-invariant 254 /// operand, it returns that instead. 255 Value *llvm::stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { 256 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 257 if (!GEP) 258 return Ptr; 259 260 unsigned InductionOperand = getGEPInductionOperand(GEP); 261 262 // Check that all of the gep indices are uniform except for our induction 263 // operand. 264 for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i) 265 if (i != InductionOperand && 266 !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp)) 267 return Ptr; 268 return GEP->getOperand(InductionOperand); 269 } 270 271 /// \brief If a value has only one user that is a CastInst, return it. 272 Value *llvm::getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) { 273 Value *UniqueCast = nullptr; 274 for (User *U : Ptr->users()) { 275 CastInst *CI = dyn_cast<CastInst>(U); 276 if (CI && CI->getType() == Ty) { 277 if (!UniqueCast) 278 UniqueCast = CI; 279 else 280 return nullptr; 281 } 282 } 283 return UniqueCast; 284 } 285 286 /// \brief Get the stride of a pointer access in a loop. Looks for symbolic 287 /// strides "a[i*stride]". Returns the symbolic stride, or null otherwise. 288 Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { 289 auto *PtrTy = dyn_cast<PointerType>(Ptr->getType()); 290 if (!PtrTy || PtrTy->isAggregateType()) 291 return nullptr; 292 293 // Try to remove a gep instruction to make the pointer (actually index at this 294 // point) easier analyzable. If OrigPtr is equal to Ptr we are analzying the 295 // pointer, otherwise, we are analyzing the index. 296 Value *OrigPtr = Ptr; 297 298 // The size of the pointer access. 299 int64_t PtrAccessSize = 1; 300 301 Ptr = stripGetElementPtr(Ptr, SE, Lp); 302 const SCEV *V = SE->getSCEV(Ptr); 303 304 if (Ptr != OrigPtr) 305 // Strip off casts. 306 while (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) 307 V = C->getOperand(); 308 309 const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V); 310 if (!S) 311 return nullptr; 312 313 V = S->getStepRecurrence(*SE); 314 if (!V) 315 return nullptr; 316 317 // Strip off the size of access multiplication if we are still analyzing the 318 // pointer. 319 if (OrigPtr == Ptr) { 320 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) { 321 if (M->getOperand(0)->getSCEVType() != scConstant) 322 return nullptr; 323 324 const APInt &APStepVal = cast<SCEVConstant>(M->getOperand(0))->getAPInt(); 325 326 // Huge step value - give up. 327 if (APStepVal.getBitWidth() > 64) 328 return nullptr; 329 330 int64_t StepVal = APStepVal.getSExtValue(); 331 if (PtrAccessSize != StepVal) 332 return nullptr; 333 V = M->getOperand(1); 334 } 335 } 336 337 // Strip off casts. 338 Type *StripedOffRecurrenceCast = nullptr; 339 if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) { 340 StripedOffRecurrenceCast = C->getType(); 341 V = C->getOperand(); 342 } 343 344 // Look for the loop invariant symbolic value. 345 const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V); 346 if (!U) 347 return nullptr; 348 349 Value *Stride = U->getValue(); 350 if (!Lp->isLoopInvariant(Stride)) 351 return nullptr; 352 353 // If we have stripped off the recurrence cast we have to make sure that we 354 // return the value that is used in this loop so that we can replace it later. 355 if (StripedOffRecurrenceCast) 356 Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast); 357 358 return Stride; 359 } 360 361 /// \brief Given a vector and an element number, see if the scalar value is 362 /// already around as a register, for example if it were inserted then extracted 363 /// from the vector. 364 Value *llvm::findScalarElement(Value *V, unsigned EltNo) { 365 assert(V->getType()->isVectorTy() && "Not looking at a vector?"); 366 VectorType *VTy = cast<VectorType>(V->getType()); 367 unsigned Width = VTy->getNumElements(); 368 if (EltNo >= Width) // Out of range access. 369 return UndefValue::get(VTy->getElementType()); 370 371 if (Constant *C = dyn_cast<Constant>(V)) 372 return C->getAggregateElement(EltNo); 373 374 if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) { 375 // If this is an insert to a variable element, we don't know what it is. 376 if (!isa<ConstantInt>(III->getOperand(2))) 377 return nullptr; 378 unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue(); 379 380 // If this is an insert to the element we are looking for, return the 381 // inserted value. 382 if (EltNo == IIElt) 383 return III->getOperand(1); 384 385 // Otherwise, the insertelement doesn't modify the value, recurse on its 386 // vector input. 387 return findScalarElement(III->getOperand(0), EltNo); 388 } 389 390 if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) { 391 unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements(); 392 int InEl = SVI->getMaskValue(EltNo); 393 if (InEl < 0) 394 return UndefValue::get(VTy->getElementType()); 395 if (InEl < (int)LHSWidth) 396 return findScalarElement(SVI->getOperand(0), InEl); 397 return findScalarElement(SVI->getOperand(1), InEl - LHSWidth); 398 } 399 400 // Extract a value from a vector add operation with a constant zero. 401 Value *Val = nullptr; Constant *Con = nullptr; 402 if (match(V, m_Add(m_Value(Val), m_Constant(Con)))) 403 if (Constant *Elt = Con->getAggregateElement(EltNo)) 404 if (Elt->isNullValue()) 405 return findScalarElement(Val, EltNo); 406 407 // Otherwise, we don't know. 408 return nullptr; 409 } 410 411 /// \brief Get splat value if the input is a splat vector or return nullptr. 412 /// This function is not fully general. It checks only 2 cases: 413 /// the input value is (1) a splat constants vector or (2) a sequence 414 /// of instructions that broadcast a single value into a vector. 415 /// 416 const llvm::Value *llvm::getSplatValue(const Value *V) { 417 418 if (auto *C = dyn_cast<Constant>(V)) 419 if (isa<VectorType>(V->getType())) 420 return C->getSplatValue(); 421 422 auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V); 423 if (!ShuffleInst) 424 return nullptr; 425 // All-zero (or undef) shuffle mask elements. 426 for (int MaskElt : ShuffleInst->getShuffleMask()) 427 if (MaskElt != 0 && MaskElt != -1) 428 return nullptr; 429 // The first shuffle source is 'insertelement' with index 0. 430 auto *InsertEltInst = 431 dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0)); 432 if (!InsertEltInst || !isa<ConstantInt>(InsertEltInst->getOperand(2)) || 433 !cast<ConstantInt>(InsertEltInst->getOperand(2))->isNullValue()) 434 return nullptr; 435 436 return InsertEltInst->getOperand(1); 437 } 438 439 MapVector<Instruction *, uint64_t> 440 llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB, 441 const TargetTransformInfo *TTI) { 442 443 // DemandedBits will give us every value's live-out bits. But we want 444 // to ensure no extra casts would need to be inserted, so every DAG 445 // of connected values must have the same minimum bitwidth. 446 EquivalenceClasses<Value *> ECs; 447 SmallVector<Value *, 16> Worklist; 448 SmallPtrSet<Value *, 4> Roots; 449 SmallPtrSet<Value *, 16> Visited; 450 DenseMap<Value *, uint64_t> DBits; 451 SmallPtrSet<Instruction *, 4> InstructionSet; 452 MapVector<Instruction *, uint64_t> MinBWs; 453 454 // Determine the roots. We work bottom-up, from truncs or icmps. 455 bool SeenExtFromIllegalType = false; 456 for (auto *BB : Blocks) 457 for (auto &I : *BB) { 458 InstructionSet.insert(&I); 459 460 if (TTI && (isa<ZExtInst>(&I) || isa<SExtInst>(&I)) && 461 !TTI->isTypeLegal(I.getOperand(0)->getType())) 462 SeenExtFromIllegalType = true; 463 464 // Only deal with non-vector integers up to 64-bits wide. 465 if ((isa<TruncInst>(&I) || isa<ICmpInst>(&I)) && 466 !I.getType()->isVectorTy() && 467 I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) { 468 // Don't make work for ourselves. If we know the loaded type is legal, 469 // don't add it to the worklist. 470 if (TTI && isa<TruncInst>(&I) && TTI->isTypeLegal(I.getType())) 471 continue; 472 473 Worklist.push_back(&I); 474 Roots.insert(&I); 475 } 476 } 477 // Early exit. 478 if (Worklist.empty() || (TTI && !SeenExtFromIllegalType)) 479 return MinBWs; 480 481 // Now proceed breadth-first, unioning values together. 482 while (!Worklist.empty()) { 483 Value *Val = Worklist.pop_back_val(); 484 Value *Leader = ECs.getOrInsertLeaderValue(Val); 485 486 if (Visited.count(Val)) 487 continue; 488 Visited.insert(Val); 489 490 // Non-instructions terminate a chain successfully. 491 if (!isa<Instruction>(Val)) 492 continue; 493 Instruction *I = cast<Instruction>(Val); 494 495 // If we encounter a type that is larger than 64 bits, we can't represent 496 // it so bail out. 497 if (DB.getDemandedBits(I).getBitWidth() > 64) 498 return MapVector<Instruction *, uint64_t>(); 499 500 uint64_t V = DB.getDemandedBits(I).getZExtValue(); 501 DBits[Leader] |= V; 502 503 // Casts, loads and instructions outside of our range terminate a chain 504 // successfully. 505 if (isa<SExtInst>(I) || isa<ZExtInst>(I) || isa<LoadInst>(I) || 506 !InstructionSet.count(I)) 507 continue; 508 509 // Unsafe casts terminate a chain unsuccessfully. We can't do anything 510 // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to 511 // transform anything that relies on them. 512 if (isa<BitCastInst>(I) || isa<PtrToIntInst>(I) || isa<IntToPtrInst>(I) || 513 !I->getType()->isIntegerTy()) { 514 DBits[Leader] |= ~0ULL; 515 continue; 516 } 517 518 // We don't modify the types of PHIs. Reductions will already have been 519 // truncated if possible, and inductions' sizes will have been chosen by 520 // indvars. 521 if (isa<PHINode>(I)) 522 continue; 523 524 if (DBits[Leader] == ~0ULL) 525 // All bits demanded, no point continuing. 526 continue; 527 528 for (Value *O : cast<User>(I)->operands()) { 529 ECs.unionSets(Leader, O); 530 Worklist.push_back(O); 531 } 532 } 533 534 // Now we've discovered all values, walk them to see if there are 535 // any users we didn't see. If there are, we can't optimize that 536 // chain. 537 for (auto &I : DBits) 538 for (auto *U : I.first->users()) 539 if (U->getType()->isIntegerTy() && DBits.count(U) == 0) 540 DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL; 541 542 for (auto I = ECs.begin(), E = ECs.end(); I != E; ++I) { 543 uint64_t LeaderDemandedBits = 0; 544 for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) 545 LeaderDemandedBits |= DBits[*MI]; 546 547 uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) - 548 llvm::countLeadingZeros(LeaderDemandedBits); 549 // Round up to a power of 2 550 if (!isPowerOf2_64((uint64_t)MinBW)) 551 MinBW = NextPowerOf2(MinBW); 552 for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) { 553 if (!isa<Instruction>(*MI)) 554 continue; 555 Type *Ty = (*MI)->getType(); 556 if (Roots.count(*MI)) 557 Ty = cast<Instruction>(*MI)->getOperand(0)->getType(); 558 if (MinBW < Ty->getScalarSizeInBits()) 559 MinBWs[cast<Instruction>(*MI)] = MinBW; 560 } 561 } 562 563 return MinBWs; 564 } 565