1 //===----------- VectorUtils.cpp - Vectorizer utility functions -----------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file defines vectorizer utilities. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/ADT/EquivalenceClasses.h" 15 #include "llvm/Analysis/DemandedBits.h" 16 #include "llvm/Analysis/LoopInfo.h" 17 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 18 #include "llvm/Analysis/ScalarEvolution.h" 19 #include "llvm/Analysis/TargetTransformInfo.h" 20 #include "llvm/Analysis/VectorUtils.h" 21 #include "llvm/IR/GetElementPtrTypeIterator.h" 22 #include "llvm/IR/PatternMatch.h" 23 #include "llvm/IR/Value.h" 24 #include "llvm/IR/Constants.h" 25 26 using namespace llvm; 27 using namespace llvm::PatternMatch; 28 29 /// \brief Identify if the intrinsic is trivially vectorizable. 30 /// This method returns true if the intrinsic's argument types are all 31 /// scalars for the scalar form of the intrinsic and all vectors for 32 /// the vector form of the intrinsic. 33 bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { 34 switch (ID) { 35 case Intrinsic::sqrt: 36 case Intrinsic::sin: 37 case Intrinsic::cos: 38 case Intrinsic::exp: 39 case Intrinsic::exp2: 40 case Intrinsic::log: 41 case Intrinsic::log10: 42 case Intrinsic::log2: 43 case Intrinsic::fabs: 44 case Intrinsic::minnum: 45 case Intrinsic::maxnum: 46 case Intrinsic::copysign: 47 case Intrinsic::floor: 48 case Intrinsic::ceil: 49 case Intrinsic::trunc: 50 case Intrinsic::rint: 51 case Intrinsic::nearbyint: 52 case Intrinsic::round: 53 case Intrinsic::bswap: 54 case Intrinsic::ctpop: 55 case Intrinsic::pow: 56 case Intrinsic::fma: 57 case Intrinsic::fmuladd: 58 case Intrinsic::ctlz: 59 case Intrinsic::cttz: 60 case Intrinsic::powi: 61 return true; 62 default: 63 return false; 64 } 65 } 66 67 /// \brief Identifies if the intrinsic has a scalar operand. It check for 68 /// ctlz,cttz and powi special intrinsics whose argument is scalar. 69 bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, 70 unsigned ScalarOpdIdx) { 71 switch (ID) { 72 case Intrinsic::ctlz: 73 case Intrinsic::cttz: 74 case Intrinsic::powi: 75 return (ScalarOpdIdx == 1); 76 default: 77 return false; 78 } 79 } 80 81 /// \brief Check call has a unary float signature 82 /// It checks following: 83 /// a) call should have a single argument 84 /// b) argument type should be floating point type 85 /// c) call instruction type and argument type should be same 86 /// d) call should only reads memory. 87 /// If all these condition is met then return ValidIntrinsicID 88 /// else return not_intrinsic. 89 Intrinsic::ID 90 llvm::checkUnaryFloatSignature(const CallInst &I, 91 Intrinsic::ID ValidIntrinsicID) { 92 if (I.getNumArgOperands() != 1 || 93 !I.getArgOperand(0)->getType()->isFloatingPointTy() || 94 I.getType() != I.getArgOperand(0)->getType() || !I.onlyReadsMemory()) 95 return Intrinsic::not_intrinsic; 96 97 return ValidIntrinsicID; 98 } 99 100 /// \brief Check call has a binary float signature 101 /// It checks following: 102 /// a) call should have 2 arguments. 103 /// b) arguments type should be floating point type 104 /// c) call instruction type and arguments type should be same 105 /// d) call should only reads memory. 106 /// If all these condition is met then return ValidIntrinsicID 107 /// else return not_intrinsic. 108 Intrinsic::ID 109 llvm::checkBinaryFloatSignature(const CallInst &I, 110 Intrinsic::ID ValidIntrinsicID) { 111 if (I.getNumArgOperands() != 2 || 112 !I.getArgOperand(0)->getType()->isFloatingPointTy() || 113 !I.getArgOperand(1)->getType()->isFloatingPointTy() || 114 I.getType() != I.getArgOperand(0)->getType() || 115 I.getType() != I.getArgOperand(1)->getType() || !I.onlyReadsMemory()) 116 return Intrinsic::not_intrinsic; 117 118 return ValidIntrinsicID; 119 } 120 121 /// \brief Returns intrinsic ID for call. 122 /// For the input call instruction it finds mapping intrinsic and returns 123 /// its ID, in case it does not found it return not_intrinsic. 124 Intrinsic::ID llvm::getIntrinsicIDForCall(const CallInst *CI, 125 const TargetLibraryInfo *TLI) { 126 // If we have an intrinsic call, check if it is trivially vectorizable. 127 if (const auto *II = dyn_cast<IntrinsicInst>(CI)) { 128 Intrinsic::ID ID = II->getIntrinsicID(); 129 if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start || 130 ID == Intrinsic::lifetime_end || ID == Intrinsic::assume) 131 return ID; 132 return Intrinsic::not_intrinsic; 133 } 134 135 if (!TLI) 136 return Intrinsic::not_intrinsic; 137 138 LibFunc::Func Func; 139 Function *F = CI->getCalledFunction(); 140 // We're going to make assumptions on the semantics of the functions, check 141 // that the target knows that it's available in this environment and it does 142 // not have local linkage. 143 if (!F || F->hasLocalLinkage() || !TLI->getLibFunc(F->getName(), Func)) 144 return Intrinsic::not_intrinsic; 145 146 // Otherwise check if we have a call to a function that can be turned into a 147 // vector intrinsic. 148 switch (Func) { 149 default: 150 break; 151 case LibFunc::sin: 152 case LibFunc::sinf: 153 case LibFunc::sinl: 154 return checkUnaryFloatSignature(*CI, Intrinsic::sin); 155 case LibFunc::cos: 156 case LibFunc::cosf: 157 case LibFunc::cosl: 158 return checkUnaryFloatSignature(*CI, Intrinsic::cos); 159 case LibFunc::exp: 160 case LibFunc::expf: 161 case LibFunc::expl: 162 return checkUnaryFloatSignature(*CI, Intrinsic::exp); 163 case LibFunc::exp2: 164 case LibFunc::exp2f: 165 case LibFunc::exp2l: 166 return checkUnaryFloatSignature(*CI, Intrinsic::exp2); 167 case LibFunc::log: 168 case LibFunc::logf: 169 case LibFunc::logl: 170 return checkUnaryFloatSignature(*CI, Intrinsic::log); 171 case LibFunc::log10: 172 case LibFunc::log10f: 173 case LibFunc::log10l: 174 return checkUnaryFloatSignature(*CI, Intrinsic::log10); 175 case LibFunc::log2: 176 case LibFunc::log2f: 177 case LibFunc::log2l: 178 return checkUnaryFloatSignature(*CI, Intrinsic::log2); 179 case LibFunc::fabs: 180 case LibFunc::fabsf: 181 case LibFunc::fabsl: 182 return checkUnaryFloatSignature(*CI, Intrinsic::fabs); 183 case LibFunc::fmin: 184 case LibFunc::fminf: 185 case LibFunc::fminl: 186 return checkBinaryFloatSignature(*CI, Intrinsic::minnum); 187 case LibFunc::fmax: 188 case LibFunc::fmaxf: 189 case LibFunc::fmaxl: 190 return checkBinaryFloatSignature(*CI, Intrinsic::maxnum); 191 case LibFunc::copysign: 192 case LibFunc::copysignf: 193 case LibFunc::copysignl: 194 return checkBinaryFloatSignature(*CI, Intrinsic::copysign); 195 case LibFunc::floor: 196 case LibFunc::floorf: 197 case LibFunc::floorl: 198 return checkUnaryFloatSignature(*CI, Intrinsic::floor); 199 case LibFunc::ceil: 200 case LibFunc::ceilf: 201 case LibFunc::ceill: 202 return checkUnaryFloatSignature(*CI, Intrinsic::ceil); 203 case LibFunc::trunc: 204 case LibFunc::truncf: 205 case LibFunc::truncl: 206 return checkUnaryFloatSignature(*CI, Intrinsic::trunc); 207 case LibFunc::rint: 208 case LibFunc::rintf: 209 case LibFunc::rintl: 210 return checkUnaryFloatSignature(*CI, Intrinsic::rint); 211 case LibFunc::nearbyint: 212 case LibFunc::nearbyintf: 213 case LibFunc::nearbyintl: 214 return checkUnaryFloatSignature(*CI, Intrinsic::nearbyint); 215 case LibFunc::round: 216 case LibFunc::roundf: 217 case LibFunc::roundl: 218 return checkUnaryFloatSignature(*CI, Intrinsic::round); 219 case LibFunc::pow: 220 case LibFunc::powf: 221 case LibFunc::powl: 222 return checkBinaryFloatSignature(*CI, Intrinsic::pow); 223 case LibFunc::sqrt: 224 case LibFunc::sqrtf: 225 case LibFunc::sqrtl: 226 if (CI->hasNoNaNs()) 227 return checkUnaryFloatSignature(*CI, Intrinsic::sqrt); 228 return Intrinsic::not_intrinsic; 229 } 230 231 return Intrinsic::not_intrinsic; 232 } 233 234 /// \brief Find the operand of the GEP that should be checked for consecutive 235 /// stores. This ignores trailing indices that have no effect on the final 236 /// pointer. 237 unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) { 238 const DataLayout &DL = Gep->getModule()->getDataLayout(); 239 unsigned LastOperand = Gep->getNumOperands() - 1; 240 unsigned GEPAllocSize = DL.getTypeAllocSize(Gep->getResultElementType()); 241 242 // Walk backwards and try to peel off zeros. 243 while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) { 244 // Find the type we're currently indexing into. 245 gep_type_iterator GEPTI = gep_type_begin(Gep); 246 std::advance(GEPTI, LastOperand - 1); 247 248 // If it's a type with the same allocation size as the result of the GEP we 249 // can peel off the zero index. 250 if (DL.getTypeAllocSize(*GEPTI) != GEPAllocSize) 251 break; 252 --LastOperand; 253 } 254 255 return LastOperand; 256 } 257 258 /// \brief If the argument is a GEP, then returns the operand identified by 259 /// getGEPInductionOperand. However, if there is some other non-loop-invariant 260 /// operand, it returns that instead. 261 Value *llvm::stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { 262 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 263 if (!GEP) 264 return Ptr; 265 266 unsigned InductionOperand = getGEPInductionOperand(GEP); 267 268 // Check that all of the gep indices are uniform except for our induction 269 // operand. 270 for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i) 271 if (i != InductionOperand && 272 !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp)) 273 return Ptr; 274 return GEP->getOperand(InductionOperand); 275 } 276 277 /// \brief If a value has only one user that is a CastInst, return it. 278 Value *llvm::getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) { 279 Value *UniqueCast = nullptr; 280 for (User *U : Ptr->users()) { 281 CastInst *CI = dyn_cast<CastInst>(U); 282 if (CI && CI->getType() == Ty) { 283 if (!UniqueCast) 284 UniqueCast = CI; 285 else 286 return nullptr; 287 } 288 } 289 return UniqueCast; 290 } 291 292 /// \brief Get the stride of a pointer access in a loop. Looks for symbolic 293 /// strides "a[i*stride]". Returns the symbolic stride, or null otherwise. 294 Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { 295 auto *PtrTy = dyn_cast<PointerType>(Ptr->getType()); 296 if (!PtrTy || PtrTy->isAggregateType()) 297 return nullptr; 298 299 // Try to remove a gep instruction to make the pointer (actually index at this 300 // point) easier analyzable. If OrigPtr is equal to Ptr we are analzying the 301 // pointer, otherwise, we are analyzing the index. 302 Value *OrigPtr = Ptr; 303 304 // The size of the pointer access. 305 int64_t PtrAccessSize = 1; 306 307 Ptr = stripGetElementPtr(Ptr, SE, Lp); 308 const SCEV *V = SE->getSCEV(Ptr); 309 310 if (Ptr != OrigPtr) 311 // Strip off casts. 312 while (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) 313 V = C->getOperand(); 314 315 const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V); 316 if (!S) 317 return nullptr; 318 319 V = S->getStepRecurrence(*SE); 320 if (!V) 321 return nullptr; 322 323 // Strip off the size of access multiplication if we are still analyzing the 324 // pointer. 325 if (OrigPtr == Ptr) { 326 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) { 327 if (M->getOperand(0)->getSCEVType() != scConstant) 328 return nullptr; 329 330 const APInt &APStepVal = cast<SCEVConstant>(M->getOperand(0))->getAPInt(); 331 332 // Huge step value - give up. 333 if (APStepVal.getBitWidth() > 64) 334 return nullptr; 335 336 int64_t StepVal = APStepVal.getSExtValue(); 337 if (PtrAccessSize != StepVal) 338 return nullptr; 339 V = M->getOperand(1); 340 } 341 } 342 343 // Strip off casts. 344 Type *StripedOffRecurrenceCast = nullptr; 345 if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) { 346 StripedOffRecurrenceCast = C->getType(); 347 V = C->getOperand(); 348 } 349 350 // Look for the loop invariant symbolic value. 351 const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V); 352 if (!U) 353 return nullptr; 354 355 Value *Stride = U->getValue(); 356 if (!Lp->isLoopInvariant(Stride)) 357 return nullptr; 358 359 // If we have stripped off the recurrence cast we have to make sure that we 360 // return the value that is used in this loop so that we can replace it later. 361 if (StripedOffRecurrenceCast) 362 Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast); 363 364 return Stride; 365 } 366 367 /// \brief Given a vector and an element number, see if the scalar value is 368 /// already around as a register, for example if it were inserted then extracted 369 /// from the vector. 370 Value *llvm::findScalarElement(Value *V, unsigned EltNo) { 371 assert(V->getType()->isVectorTy() && "Not looking at a vector?"); 372 VectorType *VTy = cast<VectorType>(V->getType()); 373 unsigned Width = VTy->getNumElements(); 374 if (EltNo >= Width) // Out of range access. 375 return UndefValue::get(VTy->getElementType()); 376 377 if (Constant *C = dyn_cast<Constant>(V)) 378 return C->getAggregateElement(EltNo); 379 380 if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) { 381 // If this is an insert to a variable element, we don't know what it is. 382 if (!isa<ConstantInt>(III->getOperand(2))) 383 return nullptr; 384 unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue(); 385 386 // If this is an insert to the element we are looking for, return the 387 // inserted value. 388 if (EltNo == IIElt) 389 return III->getOperand(1); 390 391 // Otherwise, the insertelement doesn't modify the value, recurse on its 392 // vector input. 393 return findScalarElement(III->getOperand(0), EltNo); 394 } 395 396 if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) { 397 unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements(); 398 int InEl = SVI->getMaskValue(EltNo); 399 if (InEl < 0) 400 return UndefValue::get(VTy->getElementType()); 401 if (InEl < (int)LHSWidth) 402 return findScalarElement(SVI->getOperand(0), InEl); 403 return findScalarElement(SVI->getOperand(1), InEl - LHSWidth); 404 } 405 406 // Extract a value from a vector add operation with a constant zero. 407 Value *Val = nullptr; Constant *Con = nullptr; 408 if (match(V, m_Add(m_Value(Val), m_Constant(Con)))) 409 if (Constant *Elt = Con->getAggregateElement(EltNo)) 410 if (Elt->isNullValue()) 411 return findScalarElement(Val, EltNo); 412 413 // Otherwise, we don't know. 414 return nullptr; 415 } 416 417 /// \brief Get splat value if the input is a splat vector or return nullptr. 418 /// This function is not fully general. It checks only 2 cases: 419 /// the input value is (1) a splat constants vector or (2) a sequence 420 /// of instructions that broadcast a single value into a vector. 421 /// 422 const llvm::Value *llvm::getSplatValue(const Value *V) { 423 424 if (auto *C = dyn_cast<Constant>(V)) 425 if (isa<VectorType>(V->getType())) 426 return C->getSplatValue(); 427 428 auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V); 429 if (!ShuffleInst) 430 return nullptr; 431 // All-zero (or undef) shuffle mask elements. 432 for (int MaskElt : ShuffleInst->getShuffleMask()) 433 if (MaskElt != 0 && MaskElt != -1) 434 return nullptr; 435 // The first shuffle source is 'insertelement' with index 0. 436 auto *InsertEltInst = 437 dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0)); 438 if (!InsertEltInst || !isa<ConstantInt>(InsertEltInst->getOperand(2)) || 439 !cast<ConstantInt>(InsertEltInst->getOperand(2))->isNullValue()) 440 return nullptr; 441 442 return InsertEltInst->getOperand(1); 443 } 444 445 MapVector<Instruction *, uint64_t> 446 llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB, 447 const TargetTransformInfo *TTI) { 448 449 // DemandedBits will give us every value's live-out bits. But we want 450 // to ensure no extra casts would need to be inserted, so every DAG 451 // of connected values must have the same minimum bitwidth. 452 EquivalenceClasses<Value *> ECs; 453 SmallVector<Value *, 16> Worklist; 454 SmallPtrSet<Value *, 4> Roots; 455 SmallPtrSet<Value *, 16> Visited; 456 DenseMap<Value *, uint64_t> DBits; 457 SmallPtrSet<Instruction *, 4> InstructionSet; 458 MapVector<Instruction *, uint64_t> MinBWs; 459 460 // Determine the roots. We work bottom-up, from truncs or icmps. 461 bool SeenExtFromIllegalType = false; 462 for (auto *BB : Blocks) 463 for (auto &I : *BB) { 464 InstructionSet.insert(&I); 465 466 if (TTI && (isa<ZExtInst>(&I) || isa<SExtInst>(&I)) && 467 !TTI->isTypeLegal(I.getOperand(0)->getType())) 468 SeenExtFromIllegalType = true; 469 470 // Only deal with non-vector integers up to 64-bits wide. 471 if ((isa<TruncInst>(&I) || isa<ICmpInst>(&I)) && 472 !I.getType()->isVectorTy() && 473 I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) { 474 // Don't make work for ourselves. If we know the loaded type is legal, 475 // don't add it to the worklist. 476 if (TTI && isa<TruncInst>(&I) && TTI->isTypeLegal(I.getType())) 477 continue; 478 479 Worklist.push_back(&I); 480 Roots.insert(&I); 481 } 482 } 483 // Early exit. 484 if (Worklist.empty() || (TTI && !SeenExtFromIllegalType)) 485 return MinBWs; 486 487 // Now proceed breadth-first, unioning values together. 488 while (!Worklist.empty()) { 489 Value *Val = Worklist.pop_back_val(); 490 Value *Leader = ECs.getOrInsertLeaderValue(Val); 491 492 if (Visited.count(Val)) 493 continue; 494 Visited.insert(Val); 495 496 // Non-instructions terminate a chain successfully. 497 if (!isa<Instruction>(Val)) 498 continue; 499 Instruction *I = cast<Instruction>(Val); 500 501 // If we encounter a type that is larger than 64 bits, we can't represent 502 // it so bail out. 503 if (DB.getDemandedBits(I).getBitWidth() > 64) 504 return MapVector<Instruction *, uint64_t>(); 505 506 uint64_t V = DB.getDemandedBits(I).getZExtValue(); 507 DBits[Leader] |= V; 508 DBits[I] = V; 509 510 // Casts, loads and instructions outside of our range terminate a chain 511 // successfully. 512 if (isa<SExtInst>(I) || isa<ZExtInst>(I) || isa<LoadInst>(I) || 513 !InstructionSet.count(I)) 514 continue; 515 516 // Unsafe casts terminate a chain unsuccessfully. We can't do anything 517 // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to 518 // transform anything that relies on them. 519 if (isa<BitCastInst>(I) || isa<PtrToIntInst>(I) || isa<IntToPtrInst>(I) || 520 !I->getType()->isIntegerTy()) { 521 DBits[Leader] |= ~0ULL; 522 continue; 523 } 524 525 // We don't modify the types of PHIs. Reductions will already have been 526 // truncated if possible, and inductions' sizes will have been chosen by 527 // indvars. 528 if (isa<PHINode>(I)) 529 continue; 530 531 if (DBits[Leader] == ~0ULL) 532 // All bits demanded, no point continuing. 533 continue; 534 535 for (Value *O : cast<User>(I)->operands()) { 536 ECs.unionSets(Leader, O); 537 Worklist.push_back(O); 538 } 539 } 540 541 // Now we've discovered all values, walk them to see if there are 542 // any users we didn't see. If there are, we can't optimize that 543 // chain. 544 for (auto &I : DBits) 545 for (auto *U : I.first->users()) 546 if (U->getType()->isIntegerTy() && DBits.count(U) == 0) 547 DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL; 548 549 for (auto I = ECs.begin(), E = ECs.end(); I != E; ++I) { 550 uint64_t LeaderDemandedBits = 0; 551 for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) 552 LeaderDemandedBits |= DBits[*MI]; 553 554 uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) - 555 llvm::countLeadingZeros(LeaderDemandedBits); 556 // Round up to a power of 2 557 if (!isPowerOf2_64((uint64_t)MinBW)) 558 MinBW = NextPowerOf2(MinBW); 559 560 // We don't modify the types of PHIs. Reductions will already have been 561 // truncated if possible, and inductions' sizes will have been chosen by 562 // indvars. 563 // If we are required to shrink a PHI, abandon this entire equivalence class. 564 bool Abort = false; 565 for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) 566 if (isa<PHINode>(*MI) && MinBW < (*MI)->getType()->getScalarSizeInBits()) { 567 Abort = true; 568 break; 569 } 570 if (Abort) 571 continue; 572 573 for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) { 574 if (!isa<Instruction>(*MI)) 575 continue; 576 Type *Ty = (*MI)->getType(); 577 if (Roots.count(*MI)) 578 Ty = cast<Instruction>(*MI)->getOperand(0)->getType(); 579 if (MinBW < Ty->getScalarSizeInBits()) 580 MinBWs[cast<Instruction>(*MI)] = MinBW; 581 } 582 } 583 584 return MinBWs; 585 } 586