//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
STATISTIC(NumScalarBO, "Number of scalar binops formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction.
static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
                                  unsigned Opcode,
                                  const TargetTransformInfo &TTI,
                                  Instruction *&ConvertToShuffle,
                                  unsigned PreferredExtractIndex) {
  assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
         isa<ConstantInt>(Ext1->getOperand(1)) &&
         "Expected constant extract indexes");
  Type *ScalarTy = Ext0->getType();
  auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }
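
  // For illustration, a hypothetical instance of the pattern being costed
  // (assuming the lane-2 extract is the one converted to a shuffle):
  //   %e0 = extractelement <4 x i32> %x, i32 0
  //   %e1 = extractelement <4 x i32> %y, i32 2
  //   %r  = add i32 %e0, %e1
  // versus the vector alternative that foldExtractExtract would form:
  //   %s  = shufflevector <4 x i32> %y, <4 x i32> undef,
  //                       <4 x i32> <i32 2, i32 undef, i32 undef, i32 undef>
  //   %v  = add <4 x i32> %x, %s
  //   %r  = extractelement <4 x i32> %v, i32 0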

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();

  int Extract0Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                            VecTy, Ext0Index);
  int Extract1Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                            VecTy, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  // check the cost of creating a broadcast shuffle and shuffling both
  // operands to element 0.
  int CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  if (Ext0Index == Ext1Index) {
    // If the extract indexes are identical, no shuffle is needed.
    ConvertToShuffle = nullptr;
  } else {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // undefined except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { undef, undef, 0, undef }
    // TODO: The cost model has an option for a "broadcast" shuffle
    // (splat-from-element-0), but no option for a more general splat.
    NewCost +=
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);

    // The more expensive extract will be replaced by a shuffle. If the costs
    // are equal and there is a preferred extract index, shuffle the opposite
    // operand. Otherwise, replace the extract with the higher index.
    if (Extract0Cost > Extract1Cost)
      ConvertToShuffle = Ext0;
    else if (Extract1Cost > Extract0Cost)
      ConvertToShuffle = Ext1;
    else if (PreferredExtractIndex == Ext0Index)
      ConvertToShuffle = Ext1;
    else if (PreferredExtractIndex == Ext1Index)
      ConvertToShuffle = Ext0;
    else
      ConvertToShuffle = Ext0Index > Ext1Index ? Ext0 : Ext1;
  }
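
  // As a hypothetical illustration of the formulas above: with unit costs for
  // the scalar op, the vector op, and both single-use extracts from differing
  // indexes, the general case gives OldCost = 1 + 1 + 1 = 3 and
  // NewCost = 1 + 1 = 2 plus the splat-shuffle cost. With a unit shuffle cost
  // the totals tie and, as noted below, a tie still goes to the vector form.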

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1,
                          Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<CmpInst>(&I) && "Expected a compare");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  IRBuilder<> Builder(&I);
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecCmp =
      Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1)
                                           : Builder.CreateICmp(Pred, V0, V1);
  Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1,
                            Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  IRBuilder<> Builder(&I);
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Match an instruction with extracted vector operands.
static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *Ext0, *Ext1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) &&
      !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) ||
      !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  // probably becomes unnecessary.
  uint64_t InsertIndex = std::numeric_limits<uint64_t>::max();
  if (I.hasOneUse())
    match(I.user_back(), m_InsertElement(m_Value(), m_Value(),
                                         m_ConstantInt(InsertIndex)));

  Instruction *ConvertToShuffle;
  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI, ConvertToShuffle,
                            InsertIndex))
    return false;

  if (ConvertToShuffle) {
    // The shuffle mask is undefined except for 1 lane that is being translated
    // to the cheap extraction lane. Example:
    // ShufMask = { 2, undef, undef, undef }
    uint64_t SplatIndex = ConvertToShuffle == Ext0 ? C0 : C1;
    uint64_t CheapExtIndex = ConvertToShuffle == Ext0 ? C1 : C0;
    auto *VecTy = cast<VectorType>(V0->getType());
    SmallVector<int, 32> ShufMask(VecTy->getNumElements(), -1);
    ShufMask[CheapExtIndex] = SplatIndex;
    IRBuilder<> Builder(ConvertToShuffle);

    // extelt X, C --> extelt (splat X), C'
    Value *Shuf = Builder.CreateShuffleVector(ConvertToShuffle->getOperand(0),
                                              UndefValue::get(VecTy), ShufMask);
    Value *NewExt = Builder.CreateExtractElement(Shuf, CheapExtIndex);
    if (ConvertToShuffle == Ext0)
      Ext0 = cast<Instruction>(NewExt);
    else
      Ext1 = cast<Instruction>(NewExt);
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I, TTI);
  else
    foldExtExtBinop(Ext0, Ext1, I, TTI);

  return true;
}

/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
static bool foldBitcastShuf(Instruction &I, const TargetTransformInfo &TTI) {
  Value *V;
  ArrayRef<int> Mask;
  if (!match(&I, m_BitCast(m_OneUse(m_ShuffleVector(m_Value(V), m_Undef(),
                                                    m_Mask(Mask))))))
    return false;

  // Disallow non-vector casts and length-changing shuffles.
  // TODO: We could allow any shuffle.
  auto *DestTy = dyn_cast<VectorType>(I.getType());
  auto *SrcTy = cast<VectorType>(V->getType());
  if (!DestTy || I.getOperand(0)->getType() != SrcTy)
    return false;

  // The new shuffle must not cost more than the old shuffle. The bitcast is
  // moved ahead of the shuffle, so assume that it has the same cost as before.
  if (TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, DestTy) >
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy))
    return false;

  unsigned DestNumElts = DestTy->getNumElements();
  unsigned SrcNumElts = SrcTy->getNumElements();
  SmallVector<int, 16> NewMask;
  if (SrcNumElts <= DestNumElts) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = DestNumElts / SrcNumElts;
    narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(SrcNumElts % DestNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = SrcNumElts / DestNumElts;
    if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
      return false;
  }
  // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
  IRBuilder<> Builder(&I);
  Value *CastV = Builder.CreateBitCast(V, DestTy);
  Value *Shuf = Builder.CreateShuffleVector(CastV, UndefValue::get(DestTy),
                                            NewMask);
  I.replaceAllUsesWith(Shuf);
  return true;
}

/// Match a vector binop instruction with inserted scalar operands and convert
/// to scalar binop followed by insertelement.
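/// For example (an illustrative sketch with hypothetical constants, and only
/// when the cost model favors it), with matching insert indexes of 0:
///   %i0 = insertelement <2 x i64> <i64 42, i64 undef>, i64 %x, i32 0
///   %i1 = insertelement <2 x i64> <i64 -7, i64 undef>, i64 %y, i32 0
///   %r  = add <2 x i64> %i0, %i1
/// becomes:
///   %s = add i64 %x, %y
///   %r = insertelement <2 x i64> <i64 35, i64 undef>, i64 %s, i32 0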
static bool scalarizeBinop(Instruction &I, const TargetTransformInfo &TTI) {
  Instruction *Ins0, *Ins1;
  if (!match(&I, m_BinOp(m_Instruction(Ins0), m_Instruction(Ins1))))
    return false;

  // TODO: Deal with mismatched index constants and variable indexes?
  Constant *VecC0, *VecC1;
  Value *V0, *V1;
  uint64_t Index;
  if (!match(Ins0, m_InsertElement(m_Constant(VecC0), m_Value(V0),
                                   m_ConstantInt(Index))) ||
      !match(Ins1, m_InsertElement(m_Constant(VecC1), m_Value(V1),
                                   m_SpecificInt(Index))))
    return false;

  Type *ScalarTy = V0->getType();
  Type *VecTy = I.getType();
  assert(VecTy->isVectorTy() && ScalarTy == V1->getType() &&
         (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy()) &&
         "Unexpected types for insert into binop");

  Instruction::BinaryOps Opcode = cast<BinaryOperator>(&I)->getOpcode();
  int ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
  int VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);

  // Get cost estimate for the insert element. This cost will factor into
  // both sequences.
  int InsertCost =
      TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, Index);
  int OldCost = InsertCost + InsertCost + VectorOpCost;
  int NewCost = ScalarOpCost + InsertCost +
                !Ins0->hasOneUse() * InsertCost +
                !Ins1->hasOneUse() * InsertCost;

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost)
    return false;

  // vec_bo (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
  // inselt NewVecC, (scalar_bo V0, V1), Index
  ++NumScalarBO;
  IRBuilder<> Builder(&I);
  Value *Scalar = Builder.CreateBinOp(Opcode, V0, V1, I.getName() + ".scalar");

  // All IR flags are safe to back-propagate. There is no potential for extra
  // poison to be created by the scalar instruction.
  if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
    ScalarInst->copyIRFlags(&I);

  // Fold the vector constants in the original vectors into a new base vector.
  Constant *NewVecC = ConstantExpr::get(Opcode, VecC0, VecC1);
  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
  I.replaceAllUsesWith(Insert);
  Insert->takeName(&I);
  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    const DominatorTree &DT) {
  if (DisableVectorCombine)
    return false;

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions under here and invalidate the iterator.
    // Walk the block forwards to enable simple iterative chains of transforms.
    // TODO: It could be more efficient to remove dead instructions
    // iteratively in this loop rather than waiting until the end.
    for (Instruction &I : BB) {
      if (isa<DbgInfoIntrinsic>(I))
        continue;
      MadeChange |= foldExtractExtract(I, TTI);
      MadeChange |= foldBitcastShuf(I, TTI);
      MadeChange |= scalarizeBinop(I, TTI);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.
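
// For reference: the "vector-combine" name registered below is also how the
// pass is typically invoked, e.g. `opt -vector-combine` with the legacy pass
// manager or (assuming the usual registration in PassRegistry.def)
// `opt -passes=vector-combine` with the new pass manager.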

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}