//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

#define DEBUG_TYPE "vector-combine"
#include "llvm/Transforms/Utils/InstructionWorklist.h"

using namespace llvm;
using namespace llvm::PatternMatch;

STATISTIC(NumVecLoad, "Number of vector loads formed");
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
STATISTIC(NumScalarBO, "Number of scalar binops formed");
STATISTIC(NumScalarCmp, "Number of scalar compares formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

static cl::opt<unsigned> MaxInstrsToScan(
    "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
    cl::desc("Max number of instructions to scan for vector combining."));

static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();

namespace {
class VectorCombine {
public:
  VectorCombine(Function &F, const TargetTransformInfo &TTI,
                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC)
      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC) {}

  bool run();

private:
  Function &F;
  IRBuilder<> Builder;
  const TargetTransformInfo &TTI;
  const DominatorTree &DT;
  AAResults &AA;
  AssumptionCache &AC;
  InstructionWorklist Worklist;

  bool vectorizeLoadInsert(Instruction &I);
  ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
                                        ExtractElementInst *Ext1,
                                        unsigned PreferredExtractIndex) const;
  bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                             unsigned Opcode,
                             ExtractElementInst *&ConvertToShuffle,
                             unsigned PreferredExtractIndex);
  void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                     Instruction &I);
  void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                       Instruction &I);
  bool foldExtractExtract(Instruction &I);
  bool foldBitcastShuf(Instruction &I);
  bool scalarizeBinopOrCmp(Instruction &I);
  bool foldExtractedCmps(Instruction &I);
  bool foldSingleElementStore(Instruction &I);
  bool scalarizeLoadExtract(Instruction &I);

  void replaceValue(Value &Old, Value &New) {
    Old.replaceAllUsesWith(&New);
    New.takeName(&Old);
    if (auto *NewI = dyn_cast<Instruction>(&New)) {
      Worklist.pushUsersToWorkList(*NewI);
      Worklist.pushValue(NewI);
    }
    Worklist.pushValue(&Old);
  }

  void eraseInstruction(Instruction &I) {
    for (Value *Op : I.operands())
      Worklist.pushValue(Op);
    Worklist.remove(&I);
    I.eraseFromParent();
  }
};
} // namespace

bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
  // Match insert into fixed vector of scalar value.
  // TODO: Handle non-zero insert index.
  auto *Ty = dyn_cast<FixedVectorType>(I.getType());
  Value *Scalar;
  if (!Ty || !match(&I, m_InsertElt(m_Undef(), m_Value(Scalar), m_ZeroInt())) ||
      !Scalar->hasOneUse())
    return false;

  // Optionally match an extract from another vector.
  Value *X;
  bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
  if (!HasExtract)
    X = Scalar;

  // Match source value as load of scalar or vector.
  // Do not vectorize scalar load (widening) if atomic/volatile or under
  // asan/hwasan/memtag/tsan. The widened load may load data from dirty regions
  // or create data races non-existent in the source.
  auto *Load = dyn_cast<LoadInst>(X);
  if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
      Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
      mustSuppressSpeculation(*Load))
    return false;

  const DataLayout &DL = I.getModule()->getDataLayout();
  Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");

  // If original AS != Load's AS, we can't bitcast the original pointer and have
  // to use Load's operand instead. Ideally we would want to strip pointer casts
  // without changing AS, but there's no API to do that ATM.
  unsigned AS = Load->getPointerAddressSpace();
  if (AS != SrcPtr->getType()->getPointerAddressSpace())
    SrcPtr = Load->getPointerOperand();

  // We are potentially transforming byte-sized (8-bit) memory accesses, so make
  // sure we have all of our type-based constraints in place for this target.
  Type *ScalarTy = Scalar->getType();
  uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
  unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
  if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
      ScalarSize % 8 != 0)
    return false;

  // Check safety of replacing the scalar load with a larger vector load.
  // We use minimal alignment (maximum flexibility) because we only care about
  // the dereferenceable region. When calculating cost and creating a new op,
  // we may use a larger value based on alignment attributes.
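  // Illustrative numbers (not tied to a specific target): with a 32-bit scalar
  // and a minimum vector register width of 128 bits, MinVecNumElts below is 4,
  // so the candidate widened load is a <4 x i32> (or <4 x float>) vector.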
  unsigned MinVecNumElts = MinVectorSize / ScalarSize;
  auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
  unsigned OffsetEltIndex = 0;
  Align Alignment = Load->getAlign();
  if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), DL, Load, &DT)) {
    // It is not safe to load directly from the pointer, but we can still peek
    // through gep offsets and check if it is safe to load from a base address
    // with updated alignment. If it is, we can shuffle the element(s) into
    // place after loading.
    unsigned OffsetBitWidth = DL.getIndexTypeSizeInBits(SrcPtr->getType());
    APInt Offset(OffsetBitWidth, 0);
    SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(DL, Offset);

    // We want to shuffle the result down from a high element of a vector, so
    // the offset must be positive.
    if (Offset.isNegative())
      return false;

    // The offset must be a multiple of the scalar element to shuffle cleanly
    // in the element's size.
    uint64_t ScalarSizeInBytes = ScalarSize / 8;
    if (Offset.urem(ScalarSizeInBytes) != 0)
      return false;

    // If we load MinVecNumElts, will our target element still be loaded?
    OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
    if (OffsetEltIndex >= MinVecNumElts)
      return false;

    if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), DL, Load, &DT))
      return false;

    // Update alignment with offset value. Note that the offset could be
    // negated to more accurately represent "(new) SrcPtr - Offset = (old)
    // SrcPtr", but negation does not change the result of the alignment
    // calculation.
    Alignment = commonAlignment(Alignment, Offset.getZExtValue());
  }

  // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
  // Use the greater of the alignment on the load or its source pointer.
  Alignment = std::max(SrcPtr->getPointerAlignment(DL), Alignment);
  Type *LoadTy = Load->getType();
  InstructionCost OldCost =
      TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS);
  APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
  OldCost += TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
                                          /* Insert */ true, HasExtract);

  // New pattern: load VecPtr
  InstructionCost NewCost =
      TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS);
  // Optionally, we are shuffling the loaded vector element(s) into place.
  // For the mask set everything but element 0 to undef to prevent poison from
  // propagating from the extra loaded memory. This will also optionally
  // shrink/grow the vector from the loaded size to the output size.
  // We assume this operation has no cost in codegen if there was no offset.
  // Note that we could use freeze to avoid poison problems, but then we might
  // still need a shuffle to change the vector size.
  unsigned OutputNumElts = Ty->getNumElements();
  SmallVector<int, 16> Mask(OutputNumElts, UndefMaskElem);
  assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
  Mask[0] = OffsetEltIndex;
  if (OffsetEltIndex)
    NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, MinVecTy, Mask);

  // We can aggressively convert to the vector form because the backend can
  // invert this transform if it does not result in a performance win.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // It is safe and potentially profitable to load a vector directly:
  // inselt undef, load Scalar, 0 --> load VecPtr
  IRBuilder<> Builder(Load);
  Value *CastedPtr = Builder.CreateBitCast(SrcPtr, MinVecTy->getPointerTo(AS));
  Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
  VecLd = Builder.CreateShuffleVector(VecLd, Mask);

  replaceValue(I, *VecLd);
  ++NumVecLoad;
  return true;
}

/// Determine which, if any, of the inputs should be replaced by a shuffle
/// followed by extract from a different index.
ExtractElementInst *VectorCombine::getShuffleExtract(
    ExtractElementInst *Ext0, ExtractElementInst *Ext1,
    unsigned PreferredExtractIndex = InvalidIndex) const {
  assert(isa<ConstantInt>(Ext0->getIndexOperand()) &&
         isa<ConstantInt>(Ext1->getIndexOperand()) &&
         "Expected constant extract indexes");

  unsigned Index0 = cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue();
  unsigned Index1 = cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue();

  // If the extract indexes are identical, no shuffle is needed.
  if (Index0 == Index1)
    return nullptr;

  Type *VecTy = Ext0->getVectorOperand()->getType();
  assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
  InstructionCost Cost0 =
      TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
  InstructionCost Cost1 =
      TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);

  // If both costs are invalid, no shuffle is needed.
  if (!Cost0.isValid() && !Cost1.isValid())
    return nullptr;

  // We are extracting from 2 different indexes, so one operand must be shuffled
  // before performing a vector operation and/or extract. The more expensive
  // extract will be replaced by a shuffle.
  if (Cost0 > Cost1)
    return Ext0;
  if (Cost1 > Cost0)
    return Ext1;

  // If the costs are equal and there is a preferred extract index, shuffle the
  // opposite operand.
  if (PreferredExtractIndex == Index0)
    return Ext1;
  if (PreferredExtractIndex == Index1)
    return Ext0;

  // Otherwise, replace the extract with the higher index.
  return Index0 > Index1 ? Ext0 : Ext1;
}

/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction.
bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
                                          ExtractElementInst *Ext1,
                                          unsigned Opcode,
                                          ExtractElementInst *&ConvertToShuffle,
                                          unsigned PreferredExtractIndex) {
  assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
         isa<ConstantInt>(Ext1->getOperand(1)) &&
         "Expected constant extract indexes");
  Type *ScalarTy = Ext0->getType();
  auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
  InstructionCost ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();

  InstructionCost Extract0Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext0Index);
  InstructionCost Extract1Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  //       check the cost of creating a broadcast shuffle and shuffling both
  //       operands to element 0.
  InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  InstructionCost OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
  if (ConvertToShuffle) {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // undefined except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { undef, undef, 0, undef }
    // TODO: The cost model has an option for a "broadcast" shuffle
    //       (splat-from-element-0), but no option for a more general splat.
    NewCost +=
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Create a shuffle that translates (shifts) 1 element from the input vector
/// to a new element location.
static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
                                 unsigned NewIndex, IRBuilder<> &Builder) {
  // The shuffle mask is undefined except for 1 lane that is being translated
  // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
  // ShufMask = { 2, undef, undef, undef }
  auto *VecTy = cast<FixedVectorType>(Vec->getType());
  SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem);
  ShufMask[NewIndex] = OldIndex;
  return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
}

/// Given an extract element instruction with constant index operand, shuffle
/// the source vector (shift the scalar element) to a NewIndex for extraction.
/// Return null if the input can be constant folded, so that we are not creating
/// unnecessary instructions.
static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
                                            unsigned NewIndex,
                                            IRBuilder<> &Builder) {
  // If the extract can be constant-folded, this code is unsimplified. Defer
  // to other passes to handle that.
  Value *X = ExtElt->getVectorOperand();
  Value *C = ExtElt->getIndexOperand();
  assert(isa<ConstantInt>(C) && "Expected a constant index operand");
  if (isa<Constant>(X))
    return nullptr;

  Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
                                   NewIndex, Builder);
  return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
                                  ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<CmpInst>(&I) && "Expected a compare");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
  Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
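/// For example (schematic IR):
///   add i32 (extractelement <4 x i32> %x, i32 1),
///           (extractelement <4 x i32> %y, i32 1)
/// can become:
///   extractelement (add <4 x i32> %x, %y), i32 1
/// The matched pattern is: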
/// bo (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
                                    ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Match an instruction with extracted vector operands.
bool VectorCombine::foldExtractExtract(Instruction &I) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *I0, *I1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
      !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
      !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  //       probably becomes unnecessary.
  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  uint64_t InsertIndex = InvalidIndex;
  if (I.hasOneUse())
    match(I.user_back(),
          m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));

  ExtractElementInst *ExtractToChange;
  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), ExtractToChange,
                            InsertIndex))
    return false;

  if (ExtractToChange) {
    unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
    ExtractElementInst *NewExtract =
        translateExtract(ExtractToChange, CheapExtractIdx, Builder);
    if (!NewExtract)
      return false;
    if (ExtractToChange == Ext0)
      Ext0 = NewExtract;
    else
      Ext1 = NewExtract;
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I);
  else
    foldExtExtBinop(Ext0, Ext1, I);

  Worklist.push(Ext0);
  Worklist.push(Ext1);
  return true;
}

/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
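/// For example (hypothetical IR, assuming equal shuffle costs for the source
/// and destination vector types):
///   bitcast (shufflevector <4 x i32> %v, <4 x i32> undef,
///            <4 x i32> <i32 3, i32 2, i32 1, i32 0>) to <8 x i16>
/// can become:
///   shufflevector (bitcast <4 x i32> %v to <8 x i16>), <8 x i16> poison,
///       <8 x i32> <i32 6, i32 7, i32 4, i32 5, i32 2, i32 3, i32 0, i32 1>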
bool VectorCombine::foldBitcastShuf(Instruction &I) {
  Value *V;
  ArrayRef<int> Mask;
  if (!match(&I, m_BitCast(
                     m_OneUse(m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))))))
    return false;

  // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
  //    scalable type is unknown; Second, we cannot reason if the narrowed
  //    shuffle mask for scalable type is a splat or not.
  // 2) Disallow non-vector casts and length-changing shuffles.
  // TODO: We could allow any shuffle.
  auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
  auto *SrcTy = dyn_cast<FixedVectorType>(V->getType());
  if (!SrcTy || !DestTy || I.getOperand(0)->getType() != SrcTy)
    return false;

  unsigned DestNumElts = DestTy->getNumElements();
  unsigned SrcNumElts = SrcTy->getNumElements();
  SmallVector<int, 16> NewMask;
  if (SrcNumElts <= DestNumElts) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = DestNumElts / SrcNumElts;
    narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(SrcNumElts % DestNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = SrcNumElts / DestNumElts;
    if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
      return false;
  }

  // The new shuffle must not cost more than the old shuffle. The bitcast is
  // moved ahead of the shuffle, so assume that it has the same cost as before.
  InstructionCost DestCost = TTI.getShuffleCost(
      TargetTransformInfo::SK_PermuteSingleSrc, DestTy, NewMask);
  InstructionCost SrcCost =
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy, Mask);
  if (DestCost > SrcCost || !DestCost.isValid())
    return false;

  // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
  ++NumShufOfBitcast;
  Value *CastV = Builder.CreateBitCast(V, DestTy);
  Value *Shuf = Builder.CreateShuffleVector(CastV, NewMask);
  replaceValue(I, *Shuf);
  return true;
}

/// Match a vector binop or compare instruction with at least one inserted
/// scalar operand and convert to scalar binop/cmp followed by insertelement.
bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  Value *Ins0, *Ins1;
  if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
      !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
    return false;

  // Do not convert the vector condition of a vector select into a scalar
  // condition. That may cause problems for codegen because of differences in
  // boolean formats and register-file transfers.
  // TODO: Can we account for that in the cost model?
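  // E.g. (hypothetical IR): if this compare feeds
  //   select <4 x i1> %cmp, <4 x i32> %a, <4 x i32> %b
  // as its condition, we bail out below and keep the compare in vector form.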
  bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE;
  if (IsCmp)
    for (User *U : I.users())
      if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
        return false;

  // Match against one or both scalar values being inserted into constant
  // vectors:
  // vec_op VecC0, (inselt VecC1, V1, Index)
  // vec_op (inselt VecC0, V0, Index), VecC1
  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
  // TODO: Deal with mismatched index constants and variable indexes?
  Constant *VecC0 = nullptr, *VecC1 = nullptr;
  Value *V0 = nullptr, *V1 = nullptr;
  uint64_t Index0 = 0, Index1 = 0;
  if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
                               m_ConstantInt(Index0))) &&
      !match(Ins0, m_Constant(VecC0)))
    return false;
  if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
                               m_ConstantInt(Index1))) &&
      !match(Ins1, m_Constant(VecC1)))
    return false;

  bool IsConst0 = !V0;
  bool IsConst1 = !V1;
  if (IsConst0 && IsConst1)
    return false;
  if (!IsConst0 && !IsConst1 && Index0 != Index1)
    return false;

  // Bail for single insertion if it is a load.
  // TODO: Handle this once getVectorInstrCost can cost for load/stores.
  auto *I0 = dyn_cast_or_null<Instruction>(V0);
  auto *I1 = dyn_cast_or_null<Instruction>(V1);
  if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
      (IsConst1 && I0 && I0->mayReadFromMemory()))
    return false;

  uint64_t Index = IsConst0 ? Index1 : Index0;
  Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
  Type *VecTy = I.getType();
  assert(VecTy->isVectorTy() &&
         (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
         (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
          ScalarTy->isPointerTy()) &&
         "Unexpected types for insert element into binop or cmp");

  unsigned Opcode = I.getOpcode();
  InstructionCost ScalarOpCost, VectorOpCost;
  if (IsCmp) {
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy);
  } else {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  }

  // Get cost estimate for the insert element. This cost will factor into
  // both sequences.
  InstructionCost InsertCost =
      TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, Index);
  InstructionCost OldCost =
      (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost;
  InstructionCost NewCost = ScalarOpCost + InsertCost +
                            (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
                            (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
  // inselt NewVecC, (scalar_op V0, V1), Index
  if (IsCmp)
    ++NumScalarCmp;
  else
    ++NumScalarBO;

  // For constant cases, extract the scalar element; this should constant fold.
  if (IsConst0)
    V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
  if (IsConst1)
    V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));

  Value *Scalar =
      IsCmp ? Builder.CreateCmp(Pred, V0, V1)
            : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);

  Scalar->setName(I.getName() + ".scalar");

  // All IR flags are safe to back-propagate. There is no potential for extra
  // poison to be created by the scalar instruction.
  if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
    ScalarInst->copyIRFlags(&I);

  // Fold the vector constants in the original vectors into a new base vector.
  Constant *NewVecC = IsCmp ? ConstantExpr::getCompare(Pred, VecC0, VecC1)
                            : ConstantExpr::get(Opcode, VecC0, VecC1);
  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
  replaceValue(I, *Insert);
  return true;
}

/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
/// a vector into vector operations followed by extract. Note: The SLP pass
/// may miss this pattern because of implementation problems.
bool VectorCombine::foldExtractedCmps(Instruction &I) {
  // We are looking for a scalar binop of booleans.
  // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
  if (!I.isBinaryOp() || !I.getType()->isIntegerTy(1))
    return false;

  // The compare predicates should match, and each compare should have a
  // constant operand.
  // TODO: Relax the one-use constraints.
  Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
  Instruction *I0, *I1;
  Constant *C0, *C1;
  CmpInst::Predicate P0, P1;
  if (!match(B0, m_OneUse(m_Cmp(P0, m_Instruction(I0), m_Constant(C0)))) ||
      !match(B1, m_OneUse(m_Cmp(P1, m_Instruction(I1), m_Constant(C1)))) ||
      P0 != P1)
    return false;

  // The compare operands must be extracts of the same vector with constant
  // extract indexes.
  // TODO: Relax the one-use constraints.
  Value *X;
  uint64_t Index0, Index1;
  if (!match(I0, m_OneUse(m_ExtractElt(m_Value(X), m_ConstantInt(Index0)))) ||
      !match(I1, m_OneUse(m_ExtractElt(m_Specific(X), m_ConstantInt(Index1)))))
    return false;

  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1);
  if (!ConvertToShuf)
    return false;

  // The original scalar pattern is:
  // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
  CmpInst::Predicate Pred = P0;
  unsigned CmpOpcode = CmpInst::isFPPredicate(Pred) ? Instruction::FCmp
                                                    : Instruction::ICmp;
  auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
  if (!VecTy)
    return false;

  InstructionCost OldCost =
      TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
  OldCost += TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);
  OldCost += TTI.getCmpSelInstrCost(CmpOpcode, I0->getType()) * 2;
  OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType());

  // The proposed vector pattern is:
  // vcmp = cmp Pred X, VecC
  // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
  int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
  int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
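  // The expensive lane will be shifted onto the cheap lane by the shuffle
  // created below, so the final extract only reads CheapIndex.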
  auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
  InstructionCost NewCost = TTI.getCmpSelInstrCost(CmpOpcode, X->getType());
  SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem);
  ShufMask[CheapIndex] = ExpensiveIndex;
  NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,
                                ShufMask);
  NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy);
  NewCost += TTI.getVectorInstrCost(Ext0->getOpcode(), CmpTy, CheapIndex);

  // Aggressively form vector ops if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // Create a vector constant from the 2 scalar constants.
  SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
                                   UndefValue::get(VecTy->getElementType()));
  CmpC[Index0] = C0;
  CmpC[Index1] = C1;
  Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));

  Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
  Value *VecLogic = Builder.CreateBinOp(cast<BinaryOperator>(I).getOpcode(),
                                        VCmp, Shuf);
  Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
  replaceValue(I, *NewExt);
  ++NumVecCmpBO;
  return true;
}

// Check if the memory location is modified between two instructions in the
// same basic block.
static bool isMemModifiedBetween(BasicBlock::iterator Begin,
                                 BasicBlock::iterator End,
                                 const MemoryLocation &Loc, AAResults &AA) {
  unsigned NumScanned = 0;
  return std::any_of(Begin, End, [&](const Instruction &Instr) {
    return isModSet(AA.getModRefInfo(&Instr, Loc)) ||
           ++NumScanned > MaxInstrsToScan;
  });
}

/// Helper class to indicate whether a vector index can be safely scalarized and
/// if a freeze needs to be inserted.
class ScalarizationResult {
  enum class StatusTy { Unsafe, Safe, SafeWithFreeze };

  StatusTy Status;
  Value *ToFreeze;

  ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
      : Status(Status), ToFreeze(ToFreeze) {}

public:
  ScalarizationResult(const ScalarizationResult &Other) = default;
  ~ScalarizationResult() {
    assert(!ToFreeze && "freeze() not called with ToFreeze being set");
  }

  static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
  static ScalarizationResult safe() { return {StatusTy::Safe}; }
  static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
    return {StatusTy::SafeWithFreeze, ToFreeze};
  }

  /// Returns true if the index can be scalarized without requiring a freeze.
  bool isSafe() const { return Status == StatusTy::Safe; }
  /// Returns true if the index cannot be scalarized.
  bool isUnsafe() const { return Status == StatusTy::Unsafe; }
  /// Returns true if the index can be scalarized, but requires inserting a
  /// freeze.
  bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }

  /// Reset the state to Unsafe and clear ToFreeze if set.
  void discard() {
    ToFreeze = nullptr;
    Status = StatusTy::Unsafe;
  }

  /// Freeze the ToFreeze value and update the use in \p UserI to use it.
  void freeze(IRBuilder<> &Builder, Instruction &UserI) {
    assert(isSafeWithFreeze() &&
           "should only be used when freezing is required");
    assert(is_contained(ToFreeze->users(), &UserI) &&
           "UserI must be a user of ToFreeze");
    IRBuilder<>::InsertPointGuard Guard(Builder);
    Builder.SetInsertPoint(cast<Instruction>(&UserI));
    Value *Frozen =
        Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
    for (Use &U : make_early_inc_range((UserI.operands())))
      if (U.get() == ToFreeze)
        U.set(Frozen);

    ToFreeze = nullptr;
  }
};

/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
/// Idx. \p Idx must access a valid vector element.
static ScalarizationResult canScalarizeAccess(FixedVectorType *VecTy,
                                              Value *Idx, Instruction *CtxI,
                                              AssumptionCache &AC,
                                              const DominatorTree &DT) {
  if (auto *C = dyn_cast<ConstantInt>(Idx)) {
    if (C->getValue().ult(VecTy->getNumElements()))
      return ScalarizationResult::safe();
    return ScalarizationResult::unsafe();
  }

  unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
  APInt Zero(IntWidth, 0);
  APInt MaxElts(IntWidth, VecTy->getNumElements());
  ConstantRange ValidIndices(Zero, MaxElts);
  ConstantRange IdxRange(IntWidth, true);

  if (isGuaranteedNotToBePoison(Idx, &AC)) {
    if (ValidIndices.contains(computeConstantRange(Idx, true, &AC, CtxI, &DT)))
      return ScalarizationResult::safe();
    return ScalarizationResult::unsafe();
  }

  // If the index may be poison, check if we can insert a freeze before the
  // range of the index is restricted.
  Value *IdxBase;
  ConstantInt *CI;
  if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
    IdxRange = IdxRange.binaryAnd(CI->getValue());
  } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
    IdxRange = IdxRange.urem(CI->getValue());
  }

  if (ValidIndices.contains(IdxRange))
    return ScalarizationResult::safeWithFreeze(IdxBase);
  return ScalarizationResult::unsafe();
}

/// The memory operation on a vector of \p ScalarType had alignment of
/// \p VectorAlignment. Compute the maximal, but conservatively correct,
/// alignment that will be valid for the memory operation on a single scalar
/// element of the same type with index \p Idx.
static Align computeAlignmentAfterScalarization(Align VectorAlignment,
                                                Type *ScalarType, Value *Idx,
                                                const DataLayout &DL) {
  if (auto *C = dyn_cast<ConstantInt>(Idx))
    return commonAlignment(VectorAlignment,
                           C->getZExtValue() * DL.getTypeStoreSize(ScalarType));
  return commonAlignment(VectorAlignment, DL.getTypeStoreSize(ScalarType));
}

// Combine patterns like:
//   %0 = load <4 x i32>, <4 x i32>* %a
//   %1 = insertelement <4 x i32> %0, i32 %b, i32 1
//   store <4 x i32> %1, <4 x i32>* %a
// to:
//   %0 = bitcast <4 x i32>* %a to i32*
//   %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1
//   store i32 %b, i32* %1
bool VectorCombine::foldSingleElementStore(Instruction &I) {
  StoreInst *SI = dyn_cast<StoreInst>(&I);
  if (!SI || !SI->isSimple() ||
      !isa<FixedVectorType>(SI->getValueOperand()->getType()))
    return false;

  // TODO: Combine more complicated patterns (multiple insert) by referencing
  //       TargetTransformInfo.
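  // Match a store of an insertelement into a value defined by another
  // instruction; the load-specific legality checks follow below.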
  Instruction *Source;
  Value *NewElement;
  Value *Idx;
  if (!match(SI->getValueOperand(),
             m_InsertElt(m_Instruction(Source), m_Value(NewElement),
                         m_Value(Idx))))
    return false;

  if (auto *Load = dyn_cast<LoadInst>(Source)) {
    auto VecTy = cast<FixedVectorType>(SI->getValueOperand()->getType());
    const DataLayout &DL = I.getModule()->getDataLayout();
    Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
    // Don't optimize for atomic/volatile load or store. Ensure memory is not
    // modified in between, the vector type matches the store size, and the
    // index is in bounds.
    if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
        !DL.typeSizeEqualsStoreSize(Load->getType()) ||
        SrcAddr != SI->getPointerOperand()->stripPointerCasts())
      return false;

    auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT);
    if (ScalarizableIdx.isUnsafe() ||
        isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
                             MemoryLocation::get(SI), AA))
      return false;

    if (ScalarizableIdx.isSafeWithFreeze())
      ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
    Value *GEP = Builder.CreateInBoundsGEP(
        SI->getValueOperand()->getType(), SI->getPointerOperand(),
        {ConstantInt::get(Idx->getType(), 0), Idx});
    StoreInst *NSI = Builder.CreateStore(NewElement, GEP);
    NSI->copyMetadata(*SI);
    Align ScalarOpAlignment = computeAlignmentAfterScalarization(
        std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx,
        DL);
    NSI->setAlignment(ScalarOpAlignment);
    replaceValue(I, *NSI);
    eraseInstruction(I);
    return true;
  }

  return false;
}

/// Try to scalarize vector loads feeding extractelement instructions.
bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
  Value *Ptr;
  if (!match(&I, m_Load(m_Value(Ptr))))
    return false;

  auto *LI = cast<LoadInst>(&I);
  const DataLayout &DL = I.getModule()->getDataLayout();
  if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(LI->getType()))
    return false;

  auto *FixedVT = dyn_cast<FixedVectorType>(LI->getType());
  if (!FixedVT)
    return false;

  InstructionCost OriginalCost = TTI.getMemoryOpCost(
      Instruction::Load, LI->getType(), Align(LI->getAlignment()),
      LI->getPointerAddressSpace());
  InstructionCost ScalarizedCost = 0;

  Instruction *LastCheckedInst = LI;
  unsigned NumInstChecked = 0;
  // Check if all users of the load are extracts with no memory modifications
  // between the load and the extract. Compute the cost of both the original
  // code and the scalarized version.
  for (User *U : LI->users()) {
    auto *UI = dyn_cast<ExtractElementInst>(U);
    if (!UI || UI->getParent() != LI->getParent())
      return false;

    if (!isGuaranteedNotToBePoison(UI->getOperand(1), &AC, LI, &DT))
      return false;

    // Check if any instruction between the load and the extract may modify
    // memory.
    if (LastCheckedInst->comesBefore(UI)) {
      for (Instruction &I :
           make_range(std::next(LI->getIterator()), UI->getIterator())) {
        // Bail out if we reached the check limit or the instruction may write
        // to memory.
        if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
          return false;
        NumInstChecked++;
      }
    }

    if (!LastCheckedInst)
      LastCheckedInst = UI;
    else if (LastCheckedInst->comesBefore(UI))
      LastCheckedInst = UI;

    auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
    if (!ScalarIdx.isSafe()) {
      // TODO: Freeze index if it is safe to do so.
      ScalarIdx.discard();
      return false;
    }

    auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
    OriginalCost +=
        TTI.getVectorInstrCost(Instruction::ExtractElement, LI->getType(),
                               Index ? Index->getZExtValue() : -1);
    ScalarizedCost +=
        TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(),
                            Align(1), LI->getPointerAddressSpace());
    ScalarizedCost += TTI.getAddressComputationCost(FixedVT->getElementType());
  }

  if (ScalarizedCost >= OriginalCost)
    return false;

  // Replace extracts with narrow scalar loads.
  for (User *U : LI->users()) {
    auto *EI = cast<ExtractElementInst>(U);
    Builder.SetInsertPoint(EI);

    Value *Idx = EI->getOperand(1);
    Value *GEP =
        Builder.CreateInBoundsGEP(FixedVT, Ptr, {Builder.getInt32(0), Idx});
    auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
        FixedVT->getElementType(), GEP, EI->getName() + ".scalar"));

    Align ScalarOpAlignment = computeAlignmentAfterScalarization(
        LI->getAlign(), FixedVT->getElementType(), Idx, DL);
    NewLoad->setAlignment(ScalarOpAlignment);

    replaceValue(*EI, *NewLoad);
  }

  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
  if (DisableVectorCombine)
    return false;

  // Don't attempt vectorization if the target does not support vectors.
  if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
    return false;

  bool MadeChange = false;
  auto FoldInst = [this, &MadeChange](Instruction &I) {
    Builder.SetInsertPoint(&I);
    MadeChange |= vectorizeLoadInsert(I);
    MadeChange |= foldExtractExtract(I);
    MadeChange |= foldBitcastShuf(I);
    MadeChange |= scalarizeBinopOrCmp(I);
    MadeChange |= foldExtractedCmps(I);
    MadeChange |= scalarizeLoadExtract(I);
    MadeChange |= foldSingleElementStore(I);
  };
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Use early increment range so that we can erase instructions in loop.
    for (Instruction &I : make_early_inc_range(BB)) {
      if (isa<DbgInfoIntrinsic>(I))
        continue;
      FoldInst(I);
    }
  }

  while (!Worklist.isEmpty()) {
    Instruction *I = Worklist.removeOne();
    if (!I)
      continue;

    if (isInstructionTriviallyDead(I)) {
      eraseInstruction(*I);
      continue;
    }

    FoldInst(*I);
  }

  return MadeChange;
}

// Pass manager boilerplate below here.
namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.addPreserved<AAResultsWrapperPass>();
    AU.addPreserved<BasicAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    VectorCombine Combiner(F, TTI, DT, AA, AC);
    return Combiner.run();
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  auto &AC = FAM.getResult<AssumptionAnalysis>(F);
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  AAResults &AA = FAM.getResult<AAManager>(F);
  VectorCombine Combiner(F, TTI, DT, AA, AC);
  if (!Combiner.run())
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}