//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"
#include <limits>

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));
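
// As a representative example of what this pass does (illustrative IR; the
// types and indexes are hypothetical, not taken from any particular test),
// a scalar op on two extracted lanes can become one vector op and a single
// extract:
//
//   %e0 = extractelement <4 x i32> %v0, i32 1
//   %e1 = extractelement <4 x i32> %v1, i32 1
//   %r  = add i32 %e0, %e1
// -->
//   %v = add <4 x i32> %v0, %v1
//   %r = extractelement <4 x i32> %v, i32 1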

/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction. \p PreferredExtractIndex
/// is used to break ties when the extract costs are equal.
static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
                                  unsigned Opcode,
                                  const TargetTransformInfo &TTI,
                                  Instruction *&ConvertToShuffle,
                                  unsigned PreferredExtractIndex) {
  assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
         isa<ConstantInt>(Ext1->getOperand(1)) &&
         "Expected constant extract indexes");
  Type *ScalarTy = Ext0->getType();
  auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();

  int Extract0Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext0Index);
  int Extract1Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, C0), (extelt V1, C1) -->
  // extelt (opcode (splat V0, C0), V1), C1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  //       check the cost of creating a broadcast shuffle and shuffling both
  //       operands to element 0.
  int CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  if (Ext0Index == Ext1Index) {
    // If the extract indexes are identical, no shuffle is needed.
    ConvertToShuffle = nullptr;
  } else {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // undefined except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { undef, undef, 0, undef }
    // TODO: The cost model has an option for a "broadcast" shuffle
    //       (splat-from-element-0), but no option for a more general splat.
    NewCost +=
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);

    // The more expensive extract will be replaced by a shuffle. If the costs
    // are equal and there is a preferred extract index, shuffle the opposite
    // operand. Otherwise, replace the extract with the higher index.
    if (Extract0Cost > Extract1Cost)
      ConvertToShuffle = Ext0;
    else if (Extract1Cost > Extract0Cost)
      ConvertToShuffle = Ext1;
    else if (PreferredExtractIndex == Ext0Index)
      ConvertToShuffle = Ext1;
    else if (PreferredExtractIndex == Ext1Index)
      ConvertToShuffle = Ext0;
    else
      ConvertToShuffle = Ext0Index > Ext1Index ? Ext0 : Ext1;
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}
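
// Worked example of the general-case formulas above (the costs are
// hypothetical, for illustration only). Assume ScalarOpCost = 1,
// VectorOpCost = 1, and one-use extracts at different indexes costing 2 each:
//   OldCost = 2 + 2 + 1 = 5
//   NewCost = 1 + 2 + ShuffleCost = 3 + ShuffleCost
// The fold proceeds when NewCost <= OldCost, i.e. when the splat shuffle
// costs no more than 2.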

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (extelt V0, C), (extelt V1, C)
static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1,
                          Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<CmpInst>(&I) && "Expected a compare");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  IRBuilder<> Builder(&I);
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecCmp =
      Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1)
                                           : Builder.CreateICmp(Pred, V0, V1);
  Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (extelt V0, C), (extelt V1, C)
static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1,
                            Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  IRBuilder<> Builder(&I);
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}
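
// Concrete IR for the compare fold above (illustrative sketch; the vector
// type and predicate are hypothetical):
//
//   %e0 = extractelement <4 x float> %x, i32 0
//   %e1 = extractelement <4 x float> %y, i32 0
//   %c  = fcmp ogt float %e0, %e1
// -->
//   %vc = fcmp ogt <4 x float> %x, %y
//   %c  = extractelement <4 x i1> %vc, i32 0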

/// Match an instruction with extracted vector operands.
static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *Ext0, *Ext1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) &&
      !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) ||
      !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  //       probably becomes unnecessary.
  uint64_t InsertIndex = std::numeric_limits<uint64_t>::max();
  if (I.hasOneUse())
    match(I.user_back(), m_InsertElement(m_Value(), m_Value(),
                                         m_ConstantInt(InsertIndex)));

  Instruction *ConvertToShuffle;
  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI, ConvertToShuffle,
                            InsertIndex))
    return false;

  if (ConvertToShuffle) {
    // The shuffle mask is undefined except for 1 lane that is being translated
    // to the cheap extraction lane. Example:
    // ShufMask = { 2, undef, undef, undef }
    uint64_t SplatIndex = ConvertToShuffle == Ext0 ? C0 : C1;
    uint64_t CheapExtIndex = ConvertToShuffle == Ext0 ? C1 : C0;
    auto *VecTy = cast<VectorType>(V0->getType());
    SmallVector<int, 32> ShufMask(VecTy->getNumElements(), -1);
    ShufMask[CheapExtIndex] = SplatIndex;
    IRBuilder<> Builder(ConvertToShuffle);

    // extelt X, C --> extelt (splat X), C'
    Value *Shuf = Builder.CreateShuffleVector(ConvertToShuffle->getOperand(0),
                                              UndefValue::get(VecTy), ShufMask);
    Value *NewExt = Builder.CreateExtractElement(Shuf, CheapExtIndex);
    if (ConvertToShuffle == Ext0)
      Ext0 = cast<Instruction>(NewExt);
    else
      Ext1 = cast<Instruction>(NewExt);
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I, TTI);
  else
    foldExtExtBinop(Ext0, Ext1, I, TTI);

  return true;
}

/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
static bool foldBitcastShuf(Instruction &I, const TargetTransformInfo &TTI) {
  Value *V;
  ArrayRef<int> Mask;
  if (!match(&I, m_BitCast(m_OneUse(m_ShuffleVector(m_Value(V), m_Undef(),
                                                    m_Mask(Mask))))))
    return false;

  // Disallow non-vector casts and length-changing shuffles.
  // TODO: We could allow any shuffle.
  auto *DestTy = dyn_cast<VectorType>(I.getType());
  auto *SrcTy = cast<VectorType>(V->getType());
  if (!DestTy || I.getOperand(0)->getType() != SrcTy)
    return false;

  // The new shuffle must not cost more than the old shuffle. The bitcast is
  // moved ahead of the shuffle, so assume that it has the same cost as before.
  if (TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, DestTy) >
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy))
    return false;

  unsigned DestNumElts = DestTy->getNumElements();
  unsigned SrcNumElts = SrcTy->getNumElements();
  SmallVector<int, 16> NewMask;
  if (SrcNumElts <= DestNumElts) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = DestNumElts / SrcNumElts;
    narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(SrcNumElts % DestNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = SrcNumElts / DestNumElts;
    if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
      return false;
  }

  // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
  IRBuilder<> Builder(&I);
  Value *CastV = Builder.CreateBitCast(V, DestTy);
  Value *Shuf =
      Builder.CreateShuffleVector(CastV, UndefValue::get(DestTy), NewMask);
  I.replaceAllUsesWith(Shuf);
  return true;
}
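
// Example of the mask rewrite in foldBitcastShuf (illustrative only; the
// concrete types are hypothetical). Bitcasting <4 x i32> to <8 x i16> scales
// each mask element M to { 2*M, 2*M+1 }:
//
//   %s = shufflevector <4 x i32> %v, <4 x i32> undef,
//                      <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
//   %b = bitcast <4 x i32> %s to <8 x i16>
// -->
//   %c = bitcast <4 x i32> %v to <8 x i16>
//   %b = shufflevector <8 x i16> %c, <8 x i16> undef,
//                      <8 x i32> <i32 2, i32 3, i32 undef, ..., i32 undef>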

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    const DominatorTree &DT) {
  if (DisableVectorCombine)
    return false;

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions in this loop; that would invalidate the
    // iterator. Walk the block backwards for efficiency: we're matching a
    // chain of use->defs, so we're more likely to succeed by starting from
    // the bottom.
    // TODO: It could be more efficient to remove dead instructions
    //       iteratively in this loop rather than waiting until the end.
    for (Instruction &I : make_range(BB.rbegin(), BB.rend())) {
      if (isa<DbgInfoIntrinsic>(I))
        continue;
      MadeChange |= foldExtractExtract(I, TTI);
      MadeChange |= foldBitcastShuf(I, TTI);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}
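
// Usage note (illustrative, not part of the pass itself): with the legacy
// pass manager this pass is created via createVectorCombinePass(); with the
// new pass manager it is typically run by its registered name, e.g.:
//   opt -passes=vector-combine -S input.ll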