//===----------- VectorUtils.cpp - Vectorizer utility functions -----------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file defines vectorizer utilities.
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"

using namespace llvm;
using namespace llvm::PatternMatch;

/// \brief Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all
/// scalars for the scalar form of the intrinsic and all vectors for
/// the vector form of the intrinsic.
bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
  switch (ID) {
  case Intrinsic::sqrt:
  case Intrinsic::sin:
  case Intrinsic::cos:
  case Intrinsic::exp:
  case Intrinsic::exp2:
  case Intrinsic::log:
  case Intrinsic::log10:
  case Intrinsic::log2:
  case Intrinsic::fabs:
  case Intrinsic::minnum:
  case Intrinsic::maxnum:
  case Intrinsic::copysign:
  case Intrinsic::floor:
  case Intrinsic::ceil:
  case Intrinsic::trunc:
  case Intrinsic::rint:
  case Intrinsic::nearbyint:
  case Intrinsic::round:
  case Intrinsic::bswap:
  case Intrinsic::ctpop:
  case Intrinsic::pow:
  case Intrinsic::fma:
  case Intrinsic::fmuladd:
  case Intrinsic::ctlz:
  case Intrinsic::cttz:
  case Intrinsic::powi:
    return true;
  default:
    return false;
  }
}

/// \brief Identifies if the intrinsic has a scalar operand. It checks for
/// the ctlz, cttz and powi special intrinsics, whose second argument is
/// scalar.
bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID,
                                        unsigned ScalarOpdIdx) {
  switch (ID) {
  case Intrinsic::ctlz:
  case Intrinsic::cttz:
  case Intrinsic::powi:
    return (ScalarOpdIdx == 1);
  default:
    return false;
  }
}

/// \brief Check whether the call has a unary float signature.
/// It checks the following:
/// a) the call has a single argument,
/// b) the argument type is a floating point type,
/// c) the call instruction type and the argument type are the same,
/// d) the call only reads memory.
/// If all these conditions are met, it returns ValidIntrinsicID;
/// otherwise it returns not_intrinsic.
Intrinsic::ID
llvm::checkUnaryFloatSignature(const CallInst &I,
                               Intrinsic::ID ValidIntrinsicID) {
  if (I.getNumArgOperands() != 1 ||
      !I.getArgOperand(0)->getType()->isFloatingPointTy() ||
      I.getType() != I.getArgOperand(0)->getType() || !I.onlyReadsMemory())
    return Intrinsic::not_intrinsic;

  return ValidIntrinsicID;
}
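// Illustrative example: a call `%r = call float @sinf(float %x)` that only
// reads memory satisfies all four conditions, so
// checkUnaryFloatSignature(*CI, Intrinsic::sin) returns Intrinsic::sin;
// a call `%r = call float @sinf(double %x)` fails condition (c) because
// the result and argument types differ, and yields not_intrinsic.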
/// \brief Check whether the call has a binary float signature.
/// It checks the following:
/// a) the call has two arguments,
/// b) the argument types are floating point types,
/// c) the call instruction type and the argument types are the same,
/// d) the call only reads memory.
/// If all these conditions are met, it returns ValidIntrinsicID;
/// otherwise it returns not_intrinsic.
Intrinsic::ID
llvm::checkBinaryFloatSignature(const CallInst &I,
                                Intrinsic::ID ValidIntrinsicID) {
  if (I.getNumArgOperands() != 2 ||
      !I.getArgOperand(0)->getType()->isFloatingPointTy() ||
      !I.getArgOperand(1)->getType()->isFloatingPointTy() ||
      I.getType() != I.getArgOperand(0)->getType() ||
      I.getType() != I.getArgOperand(1)->getType() || !I.onlyReadsMemory())
    return Intrinsic::not_intrinsic;

  return ValidIntrinsicID;
}

/// \brief Returns the intrinsic ID for the call.
/// For the input call instruction it finds the mapping intrinsic and returns
/// its ID; if no mapping is found, it returns not_intrinsic.
Intrinsic::ID llvm::getIntrinsicIDForCall(CallInst *CI,
                                          const TargetLibraryInfo *TLI) {
  // If we have an intrinsic call, check if it is trivially vectorizable.
  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
    Intrinsic::ID ID = II->getIntrinsicID();
    if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start ||
        ID == Intrinsic::lifetime_end || ID == Intrinsic::assume)
      return ID;
    return Intrinsic::not_intrinsic;
  }

  if (!TLI)
    return Intrinsic::not_intrinsic;

  LibFunc::Func Func;
  Function *F = CI->getCalledFunction();
  // We are going to make assumptions about the semantics of these functions,
  // so check that the target knows the function is available in this
  // environment and that it does not have local linkage.
  if (!F || F->hasLocalLinkage() || !TLI->getLibFunc(F->getName(), Func))
    return Intrinsic::not_intrinsic;

  // Otherwise check if we have a call to a function that can be turned into a
  // vector intrinsic.
  switch (Func) {
  default:
    break;
  case LibFunc::sin:
  case LibFunc::sinf:
  case LibFunc::sinl:
    return checkUnaryFloatSignature(*CI, Intrinsic::sin);
  case LibFunc::cos:
  case LibFunc::cosf:
  case LibFunc::cosl:
    return checkUnaryFloatSignature(*CI, Intrinsic::cos);
  case LibFunc::exp:
  case LibFunc::expf:
  case LibFunc::expl:
    return checkUnaryFloatSignature(*CI, Intrinsic::exp);
  case LibFunc::exp2:
  case LibFunc::exp2f:
  case LibFunc::exp2l:
    return checkUnaryFloatSignature(*CI, Intrinsic::exp2);
  case LibFunc::log:
  case LibFunc::logf:
  case LibFunc::logl:
    return checkUnaryFloatSignature(*CI, Intrinsic::log);
  case LibFunc::log10:
  case LibFunc::log10f:
  case LibFunc::log10l:
    return checkUnaryFloatSignature(*CI, Intrinsic::log10);
  case LibFunc::log2:
  case LibFunc::log2f:
  case LibFunc::log2l:
    return checkUnaryFloatSignature(*CI, Intrinsic::log2);
  case LibFunc::fabs:
  case LibFunc::fabsf:
  case LibFunc::fabsl:
    return checkUnaryFloatSignature(*CI, Intrinsic::fabs);
  case LibFunc::fmin:
  case LibFunc::fminf:
  case LibFunc::fminl:
    return checkBinaryFloatSignature(*CI, Intrinsic::minnum);
  case LibFunc::fmax:
  case LibFunc::fmaxf:
  case LibFunc::fmaxl:
    return checkBinaryFloatSignature(*CI, Intrinsic::maxnum);
  case LibFunc::copysign:
  case LibFunc::copysignf:
  case LibFunc::copysignl:
    return checkBinaryFloatSignature(*CI, Intrinsic::copysign);
  case LibFunc::floor:
  case LibFunc::floorf:
  case LibFunc::floorl:
    return checkUnaryFloatSignature(*CI, Intrinsic::floor);
  case LibFunc::ceil:
  case LibFunc::ceilf:
  case LibFunc::ceill:
    return checkUnaryFloatSignature(*CI, Intrinsic::ceil);
  case LibFunc::trunc:
  case LibFunc::truncf:
  case LibFunc::truncl:
    return checkUnaryFloatSignature(*CI, Intrinsic::trunc);
  case LibFunc::rint:
  case LibFunc::rintf:
  case LibFunc::rintl:
    return checkUnaryFloatSignature(*CI, Intrinsic::rint);
  case LibFunc::nearbyint:
  case LibFunc::nearbyintf:
  case LibFunc::nearbyintl:
    return checkUnaryFloatSignature(*CI, Intrinsic::nearbyint);
  case LibFunc::round:
  case LibFunc::roundf:
  case LibFunc::roundl:
    return checkUnaryFloatSignature(*CI, Intrinsic::round);
  case LibFunc::pow:
  case LibFunc::powf:
  case LibFunc::powl:
    return checkBinaryFloatSignature(*CI, Intrinsic::pow);
  }

  return Intrinsic::not_intrinsic;
}
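// Illustrative example: for `%r = call float @floorf(float %x)`, where TLI
// recognizes floorf and the call only reads memory, getIntrinsicIDForCall
// returns Intrinsic::floor via checkUnaryFloatSignature. For a direct call
// to `@llvm.fabs.f32` it returns Intrinsic::fabs through the first branch,
// since fabs is trivially vectorizable.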
/// \brief Find the operand of the GEP that should be checked for consecutive
/// stores. This ignores trailing indices that have no effect on the final
/// pointer.
unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) {
  const DataLayout &DL = Gep->getModule()->getDataLayout();
  unsigned LastOperand = Gep->getNumOperands() - 1;
  unsigned GEPAllocSize = DL.getTypeAllocSize(
      cast<PointerType>(Gep->getType()->getScalarType())->getElementType());

  // Walk backwards and try to peel off zeros.
  while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) {
    // Find the type we're currently indexing into.
    gep_type_iterator GEPTI = gep_type_begin(Gep);
    std::advance(GEPTI, LastOperand - 1);

    // If it's a type with the same allocation size as the result of the GEP
    // we can peel off the zero index.
    if (DL.getTypeAllocSize(*GEPTI) != GEPAllocSize)
      break;
    --LastOperand;
  }

  return LastOperand;
}

/// \brief If the argument is a GEP, then returns the operand identified by
/// getGEPInductionOperand. However, if there is some other non-loop-invariant
/// operand, it returns the original pointer instead.
Value *llvm::stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return Ptr;

  unsigned InductionOperand = getGEPInductionOperand(GEP);

  // Check that all of the gep indices are uniform except for our induction
  // operand.
  for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i)
    if (i != InductionOperand &&
        !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp))
      return Ptr;
  return GEP->getOperand(InductionOperand);
}

/// \brief If a value has only one user that is a CastInst to type Ty, return
/// that cast; otherwise return nullptr.
Value *llvm::getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) {
  Value *UniqueCast = nullptr;
  for (User *U : Ptr->users()) {
    CastInst *CI = dyn_cast<CastInst>(U);
    if (CI && CI->getType() == Ty) {
      if (!UniqueCast)
        UniqueCast = CI;
      else
        return nullptr;
    }
  }
  return UniqueCast;
}
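// Illustrative example: in
//   %p = getelementptr { float }, { float }* %A, i64 %i, i32 0
// the trailing zero indexes into { float }, whose allocation size (4 bytes)
// matches the GEP's float result, so it is peeled off and
// getGEPInductionOperand returns 1 (the %i operand). In
//   %q = getelementptr [4 x float], [4 x float]* %A, i64 %i, i64 0
// the zero indexes into [4 x float] (16 bytes, not 4), so it is kept and
// the function returns 2.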
/// \brief Get the stride of a pointer access in a loop. Looks for symbolic
/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
  auto *PtrTy = dyn_cast<PointerType>(Ptr->getType());
  if (!PtrTy || PtrTy->isAggregateType())
    return nullptr;

  // Try to remove a gep instruction to make the pointer (actually an index at
  // this point) easier to analyze. If OrigPtr is equal to Ptr we are
  // analyzing the pointer, otherwise we are analyzing the index.
  Value *OrigPtr = Ptr;

  // The size of the pointer access.
  int64_t PtrAccessSize = 1;

  Ptr = stripGetElementPtr(Ptr, SE, Lp);
  const SCEV *V = SE->getSCEV(Ptr);

  if (Ptr != OrigPtr)
    // Strip off casts.
    while (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V))
      V = C->getOperand();

  const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V);
  if (!S)
    return nullptr;

  V = S->getStepRecurrence(*SE);
  if (!V)
    return nullptr;

  // Strip off the size of access multiplication if we are still analyzing the
  // pointer.
  if (OrigPtr == Ptr) {
    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) {
      if (M->getOperand(0)->getSCEVType() != scConstant)
        return nullptr;

      const APInt &APStepVal =
          cast<SCEVConstant>(M->getOperand(0))->getValue()->getValue();

      // Huge step value - give up.
      if (APStepVal.getBitWidth() > 64)
        return nullptr;

      int64_t StepVal = APStepVal.getSExtValue();
      if (PtrAccessSize != StepVal)
        return nullptr;
      V = M->getOperand(1);
    }
  }

  // Strip off casts.
  Type *StripedOffRecurrenceCast = nullptr;
  if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) {
    StripedOffRecurrenceCast = C->getType();
    V = C->getOperand();
  }

  // Look for the loop invariant symbolic value.
  const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V);
  if (!U)
    return nullptr;

  Value *Stride = U->getValue();
  if (!Lp->isLoopInvariant(Stride))
    return nullptr;

  // If we have stripped off the recurrence cast we have to make sure that we
  // return the value that is used in this loop so that we can replace it
  // later.
  if (StripedOffRecurrenceCast)
    Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast);

  return Stride;
}
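// Illustrative example: for an access `A[i * Stride]` with a loop-invariant
// value %Stride, the stripped GEP index has the SCEV {0,+,%Stride}, whose
// step recurrence is the SCEVUnknown %Stride. Since %Stride is loop
// invariant it is returned, so callers (e.g. the loop vectorizer) can
// speculate that the stride is one and version the loop on that check.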
/// \brief Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then
/// extracted from the vector.
Value *llvm::findScalarElement(Value *V, unsigned EltNo) {
  assert(V->getType()->isVectorTy() && "Not looking at a vector?");
  VectorType *VTy = cast<VectorType>(V->getType());
  unsigned Width = VTy->getNumElements();
  if (EltNo >= Width) // Out of range access.
    return UndefValue::get(VTy->getElementType());

  if (Constant *C = dyn_cast<Constant>(V))
    return C->getAggregateElement(EltNo);

  if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) {
    // If this is an insert to a variable element, we don't know what it is.
    if (!isa<ConstantInt>(III->getOperand(2)))
      return nullptr;
    unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue();

    // If this is an insert to the element we are looking for, return the
    // inserted value.
    if (EltNo == IIElt)
      return III->getOperand(1);

    // Otherwise, the insertelement doesn't modify the value, recurse on its
    // vector input.
    return findScalarElement(III->getOperand(0), EltNo);
  }

  if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) {
    unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements();
    int InEl = SVI->getMaskValue(EltNo);
    if (InEl < 0)
      return UndefValue::get(VTy->getElementType());
    if (InEl < (int)LHSWidth)
      return findScalarElement(SVI->getOperand(0), InEl);
    return findScalarElement(SVI->getOperand(1), InEl - LHSWidth);
  }

  // Extract a value from a vector add operation with a constant zero.
  Value *Val = nullptr;
  Constant *Con = nullptr;
  if (match(V, m_Add(m_Value(Val), m_Constant(Con))))
    if (Constant *Elt = Con->getAggregateElement(EltNo))
      if (Elt->isNullValue())
        return findScalarElement(Val, EltNo);

  // Otherwise, we don't know.
  return nullptr;
}
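// Illustrative example: given
//   %v = insertelement <4 x float> %w, float %s, i32 2
// findScalarElement(%v, 2) returns %s directly, while findScalarElement(%v, 1)
// recurses into %w looking for element 1.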