//===------- VectorCombine.cpp - Optimize partial vector operations ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"
STATISTIC(NumVecLoad, "Number of vector loads formed");
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
STATISTIC(NumScalarBO, "Number of scalar binops formed");
STATISTIC(NumScalarCmp, "Number of scalar compares formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();

namespace {
class VectorCombine {
public:
  VectorCombine(Function &F, const TargetTransformInfo &TTI,
                const DominatorTree &DT)
      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT) {}

  bool run();

private:
  Function &F;
  IRBuilder<> Builder;
  const TargetTransformInfo &TTI;
  const DominatorTree &DT;

  bool vectorizeLoadInsert(Instruction &I);
  ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
                                        ExtractElementInst *Ext1,
                                        unsigned PreferredExtractIndex) const;
  bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                             unsigned Opcode,
                             ExtractElementInst *&ConvertToShuffle,
                             unsigned PreferredExtractIndex);
  void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                     Instruction &I);
  void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                       Instruction &I);
  bool foldExtractExtract(Instruction &I);
  bool foldBitcastShuf(Instruction &I);
  bool scalarizeBinopOrCmp(Instruction &I);
  bool foldExtractedCmps(Instruction &I);
};
} // namespace

static void replaceValue(Value &Old, Value &New) {
  Old.replaceAllUsesWith(&New);
  New.takeName(&Old);
}

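// Illustrative example for the transform below, assuming the target reports a
// 128-bit minimum vector register width:
//   %s = load float, float* %p, align 4
//   %r = insertelement <4 x float> undef, float %s, i32 0
// may become (when the safety and cost checks allow):
//   %vp = bitcast float* %p to <4 x float>*
//   %ld = load <4 x float>, <4 x float>* %vp, align 4
//   %r = shufflevector <4 x float> %ld, <4 x float> undef,
//                      <4 x i32> <i32 0, i32 undef, i32 undef, i32 undef>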
bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
  // Match insert into fixed vector of scalar value.
  auto *Ty = dyn_cast<FixedVectorType>(I.getType());
  Value *Scalar;
  if (!Ty || !match(&I, m_InsertElt(m_Undef(), m_Value(Scalar), m_ZeroInt())) ||
      !Scalar->hasOneUse())
    return false;

  // Optionally match an extract from another vector.
  Value *X;
  bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
  if (!HasExtract)
    X = Scalar;

  // Match source value as load of scalar or vector.
  // Do not vectorize scalar load (widening) if atomic/volatile or under
  // asan/hwasan/memtag/tsan. The widened load may load data from dirty regions
  // or create data races non-existent in the source.
  auto *Load = dyn_cast<LoadInst>(X);
  if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
      Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
      mustSuppressSpeculation(*Load))
    return false;

  // TODO: Extend this to match GEP with constant offsets.
  const DataLayout &DL = I.getModule()->getDataLayout();
  Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");

  // If original AS != Load's AS, we can't bitcast the original pointer and have
  // to use Load's operand instead. Ideally we would want to strip pointer casts
  // without changing AS, but there's no API to do that ATM.
  unsigned AS = Load->getPointerAddressSpace();
  if (AS != SrcPtr->getType()->getPointerAddressSpace())
    SrcPtr = Load->getPointerOperand();

  Type *ScalarTy = Scalar->getType();
  uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
  unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
  if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0)
    return false;

  // Check safety of replacing the scalar load with a larger vector load.
  // We use minimal alignment (maximum flexibility) because we only care about
  // the dereferenceable region. When calculating cost and creating a new op,
  // we may use a larger value based on alignment attributes.
  unsigned MinVecNumElts = MinVectorSize / ScalarSize;
  auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
  if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), DL, Load, &DT))
    return false;

  // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
  // Use the greater of the alignment on the load or its source pointer.
  Align Alignment = std::max(SrcPtr->getPointerAlignment(DL), Load->getAlign());
  Type *LoadTy = Load->getType();
  int OldCost = TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS);
  APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
  OldCost += TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
                                          /* Insert */ true, HasExtract);

  // New pattern: load VecPtr
  int NewCost = TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS);

  // We can aggressively convert to the vector form because the backend can
  // invert this transform if it does not result in a performance win.
  if (OldCost < NewCost)
    return false;

  // It is safe and potentially profitable to load a vector directly:
  // inselt undef, load Scalar, 0 --> load VecPtr
  IRBuilder<> Builder(Load);
  Value *CastedPtr = Builder.CreateBitCast(SrcPtr, MinVecTy->getPointerTo(AS));
  Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);

  // Set everything but element 0 to undef to prevent poison from propagating
  // from the extra loaded memory. This will also optionally shrink/grow the
  // vector from the loaded size to the output size.
  // We assume this operation has no cost in codegen.
  // Note that we could use freeze to avoid poison problems, but then we might
  // still need a shuffle to change the vector size.
  unsigned OutputNumElts = Ty->getNumElements();
  SmallVector<int, 16> Mask(OutputNumElts, UndefMaskElem);
  Mask[0] = 0;
  VecLd = Builder.CreateShuffleVector(VecLd, Mask);

  replaceValue(I, *VecLd);
  ++NumVecLoad;
  return true;
}

/// Determine which, if any, of the inputs should be replaced by a shuffle
/// followed by extract from a different index.
ExtractElementInst *VectorCombine::getShuffleExtract(
    ExtractElementInst *Ext0, ExtractElementInst *Ext1,
    unsigned PreferredExtractIndex = InvalidIndex) const {
  assert(isa<ConstantInt>(Ext0->getIndexOperand()) &&
         isa<ConstantInt>(Ext1->getIndexOperand()) &&
         "Expected constant extract indexes");

  unsigned Index0 = cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue();
  unsigned Index1 = cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue();

  // If the extract indexes are identical, no shuffle is needed.
  if (Index0 == Index1)
    return nullptr;

  Type *VecTy = Ext0->getVectorOperand()->getType();
  assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
  int Cost0 = TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
  int Cost1 = TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);

  // We are extracting from 2 different indexes, so one operand must be shuffled
  // before performing a vector operation and/or extract. The more expensive
  // extract will be replaced by a shuffle.
  if (Cost0 > Cost1)
    return Ext0;
  if (Cost1 > Cost0)
    return Ext1;

  // If the costs are equal and there is a preferred extract index, shuffle the
  // opposite operand.
  if (PreferredExtractIndex == Index0)
    return Ext1;
  if (PreferredExtractIndex == Index1)
    return Ext0;

  // Otherwise, replace the extract with the higher index.
  return Index0 > Index1 ? Ext0 : Ext1;
}

/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction.
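///
/// For example (illustrative), for a one-use fadd of two extracts from
/// different vectors at the same constant index C, this compares roughly:
///   fadd (extractelement %v0, C), (extractelement %v1, C)
/// against:
///   extractelement (fadd %v0, %v1), C
/// using the target's extract, arithmetic, and (when the indexes differ)
/// shuffle costs.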
bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
                                          ExtractElementInst *Ext1,
                                          unsigned Opcode,
                                          ExtractElementInst *&ConvertToShuffle,
                                          unsigned PreferredExtractIndex) {
  assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
         isa<ConstantInt>(Ext1->getOperand(1)) &&
         "Expected constant extract indexes");
  Type *ScalarTy = Ext0->getType();
  auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();

  int Extract0Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext0Index);
  int Extract1Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  //       check the cost of creating a broadcast shuffle and shuffling both
  //       operands to element 0.
  int CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
  if (ConvertToShuffle) {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // undefined except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { undef, undef, 0, undef }
    // TODO: The cost model has an option for a "broadcast" shuffle
    //       (splat-from-element-0), but no option for a more general splat.
    NewCost +=
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Create a shuffle that translates (shifts) 1 element from the input vector
/// to a new element location.
static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
                                 unsigned NewIndex, IRBuilder<> &Builder) {
  // The shuffle mask is undefined except for 1 lane that is being translated
  // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
  // ShufMask = { 2, undef, undef, undef }
  auto *VecTy = cast<FixedVectorType>(Vec->getType());
  SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem);
  ShufMask[NewIndex] = OldIndex;
  return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
}

/// Given an extract element instruction with constant index operand, shuffle
/// the source vector (shift the scalar element) to a NewIndex for extraction.
/// Return null if the input can be constant folded, so that we are not creating
/// unnecessary instructions.
static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
                                            unsigned NewIndex,
                                            IRBuilder<> &Builder) {
  // If the extract can be constant-folded, this code is unsimplified. Defer
  // to other passes to handle that.
  Value *X = ExtElt->getVectorOperand();
  Value *C = ExtElt->getIndexOperand();
  assert(isa<ConstantInt>(C) && "Expected a constant index operand");
  if (isa<Constant>(X))
    return nullptr;

  Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
                                   NewIndex, Builder);
  return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
                                  ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<CmpInst>(&I) && "Expected a compare");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
  Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
                                    ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Match an instruction with extracted vector operands.
bool VectorCombine::foldExtractExtract(Instruction &I) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *I0, *I1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
      !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
      !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  //       probably becomes unnecessary.
  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  uint64_t InsertIndex = InvalidIndex;
  if (I.hasOneUse())
    match(I.user_back(),
          m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));

  ExtractElementInst *ExtractToChange;
  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), ExtractToChange,
                            InsertIndex))
    return false;

  if (ExtractToChange) {
    unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
    ExtractElementInst *NewExtract =
        translateExtract(ExtractToChange, CheapExtractIdx, Builder);
    if (!NewExtract)
      return false;
    if (ExtractToChange == Ext0)
      Ext0 = NewExtract;
    else
      Ext1 = NewExtract;
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I);
  else
    foldExtExtBinop(Ext0, Ext1, I);

  return true;
}

/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
bool VectorCombine::foldBitcastShuf(Instruction &I) {
  Value *V;
  ArrayRef<int> Mask;
  if (!match(&I, m_BitCast(
                     m_OneUse(m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))))))
    return false;

  // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
  // scalable type is unknown; Second, we cannot reason if the narrowed shuffle
  // mask for scalable type is a splat or not.
  // 2) Disallow non-vector casts and length-changing shuffles.
  // TODO: We could allow any shuffle.
  auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
  auto *SrcTy = dyn_cast<FixedVectorType>(V->getType());
  if (!SrcTy || !DestTy || I.getOperand(0)->getType() != SrcTy)
    return false;

  // The new shuffle must not cost more than the old shuffle. The bitcast is
  // moved ahead of the shuffle, so assume that it has the same cost as before.
  if (TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, DestTy) >
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy))
    return false;

  unsigned DestNumElts = DestTy->getNumElements();
  unsigned SrcNumElts = SrcTy->getNumElements();
  SmallVector<int, 16> NewMask;
  if (SrcNumElts <= DestNumElts) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = DestNumElts / SrcNumElts;
    narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(SrcNumElts % DestNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = SrcNumElts / DestNumElts;
    if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
      return false;
  }
  // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
  ++NumShufOfBitcast;
  Value *CastV = Builder.CreateBitCast(V, DestTy);
  Value *Shuf = Builder.CreateShuffleVector(CastV, NewMask);
  replaceValue(I, *Shuf);
  return true;
}

/// Match a vector binop or compare instruction with at least one inserted
/// scalar operand and convert to scalar binop/cmp followed by insertelement.
bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  Value *Ins0, *Ins1;
  if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
      !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
    return false;

  // Do not convert the vector condition of a vector select into a scalar
  // condition. That may cause problems for codegen because of differences in
  // boolean formats and register-file transfers.
  // TODO: Can we account for that in the cost model?
  bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE;
  if (IsCmp)
    for (User *U : I.users())
      if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
        return false;

  // Match against one or both scalar values being inserted into constant
  // vectors:
  // vec_op VecC0, (inselt VecC1, V1, Index)
  // vec_op (inselt VecC0, V0, Index), VecC1
  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
  // TODO: Deal with mismatched index constants and variable indexes?
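  // For example (illustrative, with arbitrarily chosen constants):
  //   %ins = insertelement <2 x i32> <i32 1, i32 1>, i32 %x, i32 0
  //   %r = add <2 x i32> %ins, <i32 7, i32 7>
  // may be rewritten (subject to the cost checks below) as:
  //   %r.scalar = add i32 %x, 7
  //   %r = insertelement <2 x i32> <i32 8, i32 8>, i32 %r.scalar, i32 0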
  Constant *VecC0 = nullptr, *VecC1 = nullptr;
  Value *V0 = nullptr, *V1 = nullptr;
  uint64_t Index0 = 0, Index1 = 0;
  if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
                               m_ConstantInt(Index0))) &&
      !match(Ins0, m_Constant(VecC0)))
    return false;
  if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
                               m_ConstantInt(Index1))) &&
      !match(Ins1, m_Constant(VecC1)))
    return false;

  bool IsConst0 = !V0;
  bool IsConst1 = !V1;
  if (IsConst0 && IsConst1)
    return false;
  if (!IsConst0 && !IsConst1 && Index0 != Index1)
    return false;

  // Bail for single insertion if it is a load.
  // TODO: Handle this once getVectorInstrCost can cost for load/stores.
  auto *I0 = dyn_cast_or_null<Instruction>(V0);
  auto *I1 = dyn_cast_or_null<Instruction>(V1);
  if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
      (IsConst1 && I0 && I0->mayReadFromMemory()))
    return false;

  uint64_t Index = IsConst0 ? Index1 : Index0;
  Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
  Type *VecTy = I.getType();
  assert(VecTy->isVectorTy() &&
         (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
         (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
          ScalarTy->isPointerTy()) &&
         "Unexpected types for insert element into binop or cmp");

  unsigned Opcode = I.getOpcode();
  int ScalarOpCost, VectorOpCost;
  if (IsCmp) {
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy);
  } else {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  }

  // Get cost estimate for the insert element. This cost will factor into
  // both sequences.
  int InsertCost =
      TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, Index);
  int OldCost = (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) +
                VectorOpCost;
  int NewCost = ScalarOpCost + InsertCost +
                (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
                (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost)
    return false;

  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
  // inselt NewVecC, (scalar_op V0, V1), Index
  if (IsCmp)
    ++NumScalarCmp;
  else
    ++NumScalarBO;

  // For constant cases, extract the scalar element; this should constant fold.
  if (IsConst0)
    V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
  if (IsConst1)
    V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));

  Value *Scalar =
      IsCmp ? Builder.CreateCmp(Pred, V0, V1)
            : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);

  Scalar->setName(I.getName() + ".scalar");

  // All IR flags are safe to back-propagate. There is no potential for extra
  // poison to be created by the scalar instruction.
  if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
    ScalarInst->copyIRFlags(&I);

  // Fold the vector constants in the original vectors into a new base vector.
  Constant *NewVecC = IsCmp ? ConstantExpr::getCompare(Pred, VecC0, VecC1)
                            : ConstantExpr::get(Opcode, VecC0, VecC1);
  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
  replaceValue(I, *Insert);
  return true;
}

/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
/// a vector into vector operations followed by extract. Note: The SLP pass
/// may miss this pattern because of implementation problems.
bool VectorCombine::foldExtractedCmps(Instruction &I) {
  // We are looking for a scalar binop of booleans.
  // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
  if (!I.isBinaryOp() || !I.getType()->isIntegerTy(1))
    return false;

  // The compare predicates should match, and each compare should have a
  // constant operand.
  // TODO: Relax the one-use constraints.
  Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
  Instruction *I0, *I1;
  Constant *C0, *C1;
  CmpInst::Predicate P0, P1;
  if (!match(B0, m_OneUse(m_Cmp(P0, m_Instruction(I0), m_Constant(C0)))) ||
      !match(B1, m_OneUse(m_Cmp(P1, m_Instruction(I1), m_Constant(C1)))) ||
      P0 != P1)
    return false;

  // The compare operands must be extracts of the same vector with constant
  // extract indexes.
  // TODO: Relax the one-use constraints.
  Value *X;
  uint64_t Index0, Index1;
  if (!match(I0, m_OneUse(m_ExtractElt(m_Value(X), m_ConstantInt(Index0)))) ||
      !match(I1, m_OneUse(m_ExtractElt(m_Specific(X), m_ConstantInt(Index1)))))
    return false;

  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1);
  if (!ConvertToShuf)
    return false;

  // The original scalar pattern is:
  // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
  CmpInst::Predicate Pred = P0;
  unsigned CmpOpcode = CmpInst::isFPPredicate(Pred) ? Instruction::FCmp
                                                    : Instruction::ICmp;
  auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
  if (!VecTy)
    return false;

  int OldCost = TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
  OldCost += TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);
  OldCost += TTI.getCmpSelInstrCost(CmpOpcode, I0->getType()) * 2;
  OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType());

  // The proposed vector pattern is:
  // vcmp = cmp Pred X, VecC
  // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
  int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
  int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
  auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
  int NewCost = TTI.getCmpSelInstrCost(CmpOpcode, X->getType());
  NewCost +=
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy);
  NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy);
  NewCost += TTI.getVectorInstrCost(Ext0->getOpcode(), CmpTy, CheapIndex);

  // Aggressively form vector ops if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  if (OldCost < NewCost)
    return false;

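  // As an illustrative example, with <4 x i32> %x, Index0 == 0, Index1 == 1,
  // and arbitrary constants:
  //   %e0 = extractelement <4 x i32> %x, i32 0
  //   %e1 = extractelement <4 x i32> %x, i32 1
  //   %c0 = icmp sgt i32 %e0, 42
  //   %c1 = icmp sgt i32 %e1, 7
  //   %r = and i1 %c0, %c1
  // becomes a vector compare of %x against
  // <i32 42, i32 7, i32 undef, i32 undef>, a shift shuffle, a vector 'and',
  // and a single extract, as built below.
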
  // Create a vector constant from the 2 scalar constants.
  SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
                                   UndefValue::get(VecTy->getElementType()));
  CmpC[Index0] = C0;
  CmpC[Index1] = C1;
  Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));

  Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
  Value *VecLogic = Builder.CreateBinOp(cast<BinaryOperator>(I).getOpcode(),
                                        VCmp, Shuf);
  Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
  replaceValue(I, *NewExt);
  ++NumVecCmpBO;
  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
  if (DisableVectorCombine)
    return false;

  // Don't attempt vectorization if the target does not support vectors.
  if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
    return false;

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions inside this loop; that would invalidate the
    // iterator. Walk the block forwards to enable simple iterative chains of
    // transforms.
    // TODO: It could be more efficient to remove dead instructions
    //       iteratively in this loop rather than waiting until the end.
    for (Instruction &I : BB) {
      if (isa<DbgInfoIntrinsic>(I))
        continue;
      Builder.SetInsertPoint(&I);
      MadeChange |= vectorizeLoadInsert(I);
      MadeChange |= foldExtractExtract(I);
      MadeChange |= foldBitcastShuf(I);
      MadeChange |= scalarizeBinopOrCmp(I);
      MadeChange |= foldExtractedCmps(I);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.
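// The transforms above can be exercised in isolation with, for example,
// "opt -vector-combine -S" (legacy pass manager); with the new pass manager
// the equivalent is usually "-passes=vector-combine" (registration lives
// outside this file).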

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.addPreserved<AAResultsWrapperPass>();
    AU.addPreserved<BasicAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    VectorCombine Combiner(F, TTI, DT);
    return Combiner.run();
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  VectorCombine Combiner(F, TTI, DT);
  if (!Combiner.run())
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  PA.preserve<AAManager>();
  PA.preserve<BasicAA>();
  return PA;
}