//===----------- VectorUtils.cpp - Vectorizer utility functions -----------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file defines vectorizer utilities.
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"

/// \brief Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all
/// scalars for the scalar form of the intrinsic and all vectors for
/// the vector form of the intrinsic.
bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
  switch (ID) {
  case Intrinsic::sqrt:
  case Intrinsic::sin:
  case Intrinsic::cos:
  case Intrinsic::exp:
  case Intrinsic::exp2:
  case Intrinsic::log:
  case Intrinsic::log10:
  case Intrinsic::log2:
  case Intrinsic::fabs:
  case Intrinsic::minnum:
  case Intrinsic::maxnum:
  case Intrinsic::copysign:
  case Intrinsic::floor:
  case Intrinsic::ceil:
  case Intrinsic::trunc:
  case Intrinsic::rint:
  case Intrinsic::nearbyint:
  case Intrinsic::round:
  case Intrinsic::bswap:
  case Intrinsic::ctpop:
  case Intrinsic::pow:
  case Intrinsic::fma:
  case Intrinsic::fmuladd:
  case Intrinsic::ctlz:
  case Intrinsic::cttz:
  case Intrinsic::powi:
    return true;
  default:
    return false;
  }
}

/// \brief Identify if the intrinsic has a scalar operand. It checks for the
/// ctlz, cttz and powi special intrinsics, whose second argument is scalar.
bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID,
                                        unsigned ScalarOpdIdx) {
  switch (ID) {
  case Intrinsic::ctlz:
  case Intrinsic::cttz:
  case Intrinsic::powi:
    return (ScalarOpdIdx == 1);
  default:
    return false;
  }
}
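// Illustrative sketch (not part of the original file): a vectorizer that
// widens one of the calls recognized above is expected to consult
// hasVectorInstrinsicScalarOpd per operand, so that arguments such as the i32
// exponent of llvm.powi or the i1 'is_zero_undef' flag of ctlz/cttz stay
// scalar while the remaining operands are widened.  'widenOperand' below is a
// hypothetical helper standing in for whatever widening routine the caller
// actually uses.
//
//   SmallVector<Value *, 4> Args;
//   for (unsigned Idx = 0, E = CI->getNumArgOperands(); Idx != E; ++Idx) {
//     Value *Arg = CI->getArgOperand(Idx);
//     Args.push_back(hasVectorInstrinsicScalarOpd(ID, Idx)
//                        ? Arg                      // keep the scalar operand
//                        : widenOperand(Arg, VF));  // widen to VF lanes
//   }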
/// \brief Check whether the call has a unary float signature.
/// It checks the following:
/// a) the call has a single argument,
/// b) the argument type is a floating point type,
/// c) the call instruction type and the argument type are the same,
/// d) the call only reads memory.
/// If all these conditions are met, return ValidIntrinsicID;
/// otherwise return not_intrinsic.
llvm::Intrinsic::ID
llvm::checkUnaryFloatSignature(const CallInst &I,
                               Intrinsic::ID ValidIntrinsicID) {
  if (I.getNumArgOperands() != 1 ||
      !I.getArgOperand(0)->getType()->isFloatingPointTy() ||
      I.getType() != I.getArgOperand(0)->getType() || !I.onlyReadsMemory())
    return Intrinsic::not_intrinsic;

  return ValidIntrinsicID;
}

/// \brief Check whether the call has a binary float signature.
/// It checks the following:
/// a) the call has two arguments,
/// b) both argument types are floating point types,
/// c) the call instruction type and the argument types are the same,
/// d) the call only reads memory.
/// If all these conditions are met, return ValidIntrinsicID;
/// otherwise return not_intrinsic.
llvm::Intrinsic::ID
llvm::checkBinaryFloatSignature(const CallInst &I,
                                Intrinsic::ID ValidIntrinsicID) {
  if (I.getNumArgOperands() != 2 ||
      !I.getArgOperand(0)->getType()->isFloatingPointTy() ||
      !I.getArgOperand(1)->getType()->isFloatingPointTy() ||
      I.getType() != I.getArgOperand(0)->getType() ||
      I.getType() != I.getArgOperand(1)->getType() || !I.onlyReadsMemory())
    return Intrinsic::not_intrinsic;

  return ValidIntrinsicID;
}

/// \brief Returns the intrinsic ID for the call.
/// For the given call instruction it finds the corresponding intrinsic and
/// returns its ID; if no mapping is found it returns not_intrinsic.
llvm::Intrinsic::ID llvm::getIntrinsicIDForCall(CallInst *CI,
                                                const TargetLibraryInfo *TLI) {
  // If we have an intrinsic call, check if it is trivially vectorizable.
  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
    Intrinsic::ID ID = II->getIntrinsicID();
    if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start ||
        ID == Intrinsic::lifetime_end || ID == Intrinsic::assume)
      return ID;
    return Intrinsic::not_intrinsic;
  }

  if (!TLI)
    return Intrinsic::not_intrinsic;

  LibFunc::Func Func;
  Function *F = CI->getCalledFunction();
  // We are going to make assumptions about the semantics of the functions, so
  // check that the target knows the function is available in this environment
  // and that it does not have local linkage.
  if (!F || F->hasLocalLinkage() || !TLI->getLibFunc(F->getName(), Func))
    return Intrinsic::not_intrinsic;

  // Otherwise check if we have a call to a function that can be turned into a
  // vector intrinsic.
  switch (Func) {
  default:
    break;
  case LibFunc::sin:
  case LibFunc::sinf:
  case LibFunc::sinl:
    return checkUnaryFloatSignature(*CI, Intrinsic::sin);
  case LibFunc::cos:
  case LibFunc::cosf:
  case LibFunc::cosl:
    return checkUnaryFloatSignature(*CI, Intrinsic::cos);
  case LibFunc::exp:
  case LibFunc::expf:
  case LibFunc::expl:
    return checkUnaryFloatSignature(*CI, Intrinsic::exp);
  case LibFunc::exp2:
  case LibFunc::exp2f:
  case LibFunc::exp2l:
    return checkUnaryFloatSignature(*CI, Intrinsic::exp2);
  case LibFunc::log:
  case LibFunc::logf:
  case LibFunc::logl:
    return checkUnaryFloatSignature(*CI, Intrinsic::log);
  case LibFunc::log10:
  case LibFunc::log10f:
  case LibFunc::log10l:
    return checkUnaryFloatSignature(*CI, Intrinsic::log10);
  case LibFunc::log2:
  case LibFunc::log2f:
  case LibFunc::log2l:
    return checkUnaryFloatSignature(*CI, Intrinsic::log2);
  case LibFunc::fabs:
  case LibFunc::fabsf:
  case LibFunc::fabsl:
    return checkUnaryFloatSignature(*CI, Intrinsic::fabs);
  case LibFunc::fmin:
  case LibFunc::fminf:
  case LibFunc::fminl:
    return checkBinaryFloatSignature(*CI, Intrinsic::minnum);
  case LibFunc::fmax:
  case LibFunc::fmaxf:
  case LibFunc::fmaxl:
    return checkBinaryFloatSignature(*CI, Intrinsic::maxnum);
  case LibFunc::copysign:
  case LibFunc::copysignf:
  case LibFunc::copysignl:
    return checkBinaryFloatSignature(*CI, Intrinsic::copysign);
  case LibFunc::floor:
  case LibFunc::floorf:
  case LibFunc::floorl:
    return checkUnaryFloatSignature(*CI, Intrinsic::floor);
  case LibFunc::ceil:
  case LibFunc::ceilf:
  case LibFunc::ceill:
    return checkUnaryFloatSignature(*CI, Intrinsic::ceil);
  case LibFunc::trunc:
  case LibFunc::truncf:
  case LibFunc::truncl:
    return checkUnaryFloatSignature(*CI, Intrinsic::trunc);
  case LibFunc::rint:
  case LibFunc::rintf:
  case LibFunc::rintl:
    return checkUnaryFloatSignature(*CI, Intrinsic::rint);
  case LibFunc::nearbyint:
  case LibFunc::nearbyintf:
  case LibFunc::nearbyintl:
    return checkUnaryFloatSignature(*CI, Intrinsic::nearbyint);
  case LibFunc::round:
  case LibFunc::roundf:
  case LibFunc::roundl:
    return checkUnaryFloatSignature(*CI, Intrinsic::round);
  case LibFunc::pow:
  case LibFunc::powf:
  case LibFunc::powl:
    return checkBinaryFloatSignature(*CI, Intrinsic::pow);
  }

  return Intrinsic::not_intrinsic;
}
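// Illustrative usage (a sketch, not part of the original file): a transform
// that wants to treat recognized libm-style calls and intrinsic calls
// uniformly can query getIntrinsicIDForCall for every call it visits.
//
//   llvm::Intrinsic::ID ID = llvm::getIntrinsicIDForCall(CI, TLI);
//   if (ID != llvm::Intrinsic::not_intrinsic) {
//     // CI is either already a vectorizable intrinsic call or a library call
//     // (e.g. sinf) whose float signature matches and that only reads
//     // memory; it can be widened to the vector form of the intrinsic 'ID'.
//   } else {
//     // Unknown call; treat it as a barrier to vectorization.
//   }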
/// \brief Find the operand of the GEP that should be checked for consecutive
/// stores. This ignores trailing indices that have no effect on the final
/// pointer.
unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) {
  const DataLayout &DL = Gep->getModule()->getDataLayout();
  unsigned LastOperand = Gep->getNumOperands() - 1;
  unsigned GEPAllocSize = DL.getTypeAllocSize(
      cast<PointerType>(Gep->getType()->getScalarType())->getElementType());

  // Walk backwards and try to peel off zeros.
  while (LastOperand > 1 &&
         match(Gep->getOperand(LastOperand), llvm::PatternMatch::m_Zero())) {
    // Find the type we're currently indexing into.
    gep_type_iterator GEPTI = gep_type_begin(Gep);
    std::advance(GEPTI, LastOperand - 1);

    // If it's a type with the same allocation size as the result of the GEP we
    // can peel off the zero index.
    if (DL.getTypeAllocSize(*GEPTI) != GEPAllocSize)
      break;
    --LastOperand;
  }

  return LastOperand;
}

/// \brief If the argument is a GEP, then returns the operand identified by
/// getGEPInductionOperand. However, if some other operand is not loop
/// invariant, it returns the original pointer instead.
llvm::Value *llvm::stripGetElementPtr(llvm::Value *Ptr, ScalarEvolution *SE,
                                      Loop *Lp) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return Ptr;

  unsigned InductionOperand = getGEPInductionOperand(GEP);

  // Check that all of the gep indices are uniform except for our induction
  // operand.
  for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i)
    if (i != InductionOperand &&
        !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp))
      return Ptr;
  return GEP->getOperand(InductionOperand);
}

/// \brief If a value has only one user that is a CastInst, return it.
llvm::Value *llvm::getUniqueCastUse(llvm::Value *Ptr, Loop *Lp, Type *Ty) {
  llvm::Value *UniqueCast = nullptr;
  for (User *U : Ptr->users()) {
    CastInst *CI = dyn_cast<CastInst>(U);
    if (CI && CI->getType() == Ty) {
      if (!UniqueCast)
        UniqueCast = CI;
      else
        return nullptr;
    }
  }
  return UniqueCast;
}
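// Illustrative example (a sketch, not part of the original file): for an
// access through a single-element struct such as
//
//   %struct.S = type { float }
//   %gep = getelementptr %struct.S, %struct.S* %A, i64 %i, i32 0
//
// the trailing zero index selects a field with the same allocation size as
// the GEP's result element type, so it has no effect on the final pointer and
// is peeled off; getGEPInductionOperand is expected to return operand index 1
// (the %i index), and stripGetElementPtr then returns %i itself when %A and
// the remaining indices are loop invariant.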
/// \brief Get the stride of a pointer access in a loop. Looks for symbolic
/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
llvm::Value *llvm::getStrideFromPointer(llvm::Value *Ptr, ScalarEvolution *SE,
                                        Loop *Lp) {
  const PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType());
  if (!PtrTy || PtrTy->isAggregateType())
    return nullptr;

  // Try to remove a gep instruction to make the pointer (actually the index
  // at this point) easier to analyze. If OrigPtr is equal to Ptr we are
  // analyzing the pointer, otherwise we are analyzing the index.
  llvm::Value *OrigPtr = Ptr;

  // The size of the pointer access.
  int64_t PtrAccessSize = 1;

  Ptr = stripGetElementPtr(Ptr, SE, Lp);
  const SCEV *V = SE->getSCEV(Ptr);

  if (Ptr != OrigPtr)
    // Strip off casts.
    while (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V))
      V = C->getOperand();

  const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V);
  if (!S)
    return nullptr;

  V = S->getStepRecurrence(*SE);
  if (!V)
    return nullptr;

  // Strip off the size of access multiplication if we are still analyzing the
  // pointer.
  if (OrigPtr == Ptr) {
    const DataLayout &DL = Lp->getHeader()->getModule()->getDataLayout();
    // The constant factor of the step must match the access size in bytes.
    PtrAccessSize = DL.getTypeAllocSize(PtrTy->getElementType());
    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) {
      if (M->getOperand(0)->getSCEVType() != scConstant)
        return nullptr;

      const APInt &APStepVal =
          cast<SCEVConstant>(M->getOperand(0))->getValue()->getValue();

      // Huge step value - give up.
      if (APStepVal.getBitWidth() > 64)
        return nullptr;

      int64_t StepVal = APStepVal.getSExtValue();
      if (PtrAccessSize != StepVal)
        return nullptr;
      V = M->getOperand(1);
    }
  }

  // Strip off casts.
  Type *StrippedOffRecurrenceCast = nullptr;
  if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) {
    StrippedOffRecurrenceCast = C->getType();
    V = C->getOperand();
  }

  // Look for the loop invariant symbolic value.
  const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V);
  if (!U)
    return nullptr;

  llvm::Value *Stride = U->getValue();
  if (!Lp->isLoopInvariant(Stride))
    return nullptr;

  // If we have stripped off the recurrence cast we have to make sure that we
  // return the value that is used in this loop so that we can replace it
  // later.
  if (StrippedOffRecurrenceCast)
    Stride = getUniqueCastUse(Stride, Lp, StrippedOffRecurrenceCast);

  return Stride;
}
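// Illustrative example (a sketch, not part of the original file): for a
// source loop such as
//
//   for (long i = 0; i < n; ++i)
//     A[i * Stride] += 1.0f;   // A is 'float *', Stride is loop invariant
//
// the access pointer is a GEP whose index is 'i * Stride'.  stripGetElementPtr
// peels off the GEP, the index's SCEV is the add recurrence {0,+,%Stride},
// and its step is the loop-invariant value %Stride, which
// getStrideFromPointer returns.  A vectorizer can then guard the loop with a
// runtime check that Stride == 1 and treat the access as consecutive.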
/// \brief Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then
/// extracted from the vector.
llvm::Value *llvm::findScalarElement(llvm::Value *V, unsigned EltNo) {
  assert(V->getType()->isVectorTy() && "Not looking at a vector?");
  VectorType *VTy = cast<VectorType>(V->getType());
  unsigned Width = VTy->getNumElements();
  if (EltNo >= Width)  // Out of range access.
    return UndefValue::get(VTy->getElementType());

  if (Constant *C = dyn_cast<Constant>(V))
    return C->getAggregateElement(EltNo);

  if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) {
    // If this is an insert to a variable element, we don't know what it is.
    if (!isa<ConstantInt>(III->getOperand(2)))
      return nullptr;
    unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue();

    // If this is an insert to the element we are looking for, return the
    // inserted value.
    if (EltNo == IIElt)
      return III->getOperand(1);

    // Otherwise, the insertelement doesn't modify the value, recurse on its
    // vector input.
    return findScalarElement(III->getOperand(0), EltNo);
  }

  if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) {
    unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements();
    int InEl = SVI->getMaskValue(EltNo);
    if (InEl < 0)
      return UndefValue::get(VTy->getElementType());
    if (InEl < (int)LHSWidth)
      return findScalarElement(SVI->getOperand(0), InEl);
    return findScalarElement(SVI->getOperand(1), InEl - LHSWidth);
  }

  // Extract a value from a vector add operation with a constant zero.
  Value *Val = nullptr;
  Constant *Con = nullptr;
  if (match(V,
            llvm::PatternMatch::m_Add(llvm::PatternMatch::m_Value(Val),
                                      llvm::PatternMatch::m_Constant(Con)))) {
    if (Con->getAggregateElement(EltNo)->isNullValue())
      return findScalarElement(Val, EltNo);
  }

  // Otherwise, we don't know.
  return nullptr;
}
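// Illustrative example (a sketch, not part of the original file): given the
// following sequence, where %v0 is an incoming function argument,
//
//   %v1 = insertelement <4 x float> %v0, float %x, i32 1
//   %v2 = insertelement <4 x float> %v1, float %y, i32 2
//
// findScalarElement(%v2, 1) walks past the second insertelement (it writes
// lane 2, not lane 1) and returns %x, while findScalarElement(%v2, 0) returns
// nullptr because lane 0 of %v0 is not known to be available as a scalar.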