//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"

STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction.
static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
                                  unsigned Opcode,
                                  const TargetTransformInfo &TTI,
                                  Instruction *&ConvertToShuffle,
                                  unsigned PreferredExtractIndex) {
  assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
         isa<ConstantInt>(Ext1->getOperand(1)) &&
         "Expected constant extract indexes");
  Type *ScalarTy = Ext0->getType();
  Type *VecTy = Ext0->getOperand(0)->getType();
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();

  int Extract0Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                            VecTy, Ext0Index);
  int Extract1Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                            VecTy, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (extelt V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  // check the cost of creating a broadcast shuffle and shuffling both
  // operands to element 0.
  int CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

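  // A quick worked example of the formulas above, using made-up cost numbers
  // rather than any real target model: extracting the same lane from two
  // different vectors with Extract0Cost = Extract1Cost = 2, ScalarOpCost = 1,
  // and VectorOpCost = 1 (both extracts single-use) gives
  //   OldCost = 2 + 2 + 1 = 5 and NewCost = 1 + 2 = 3,
  // so the comparison at the bottom of this function favors the vector form.
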
  if (Ext0Index == Ext1Index) {
    // If the extract indexes are identical, no shuffle is needed.
    ConvertToShuffle = nullptr;
  } else {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // undefined except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { undef, undef, 0, undef }
    // TODO: The cost model has an option for a "broadcast" shuffle
    // (splat-from-element-0), but no option for a more general splat.
    NewCost +=
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);

    // The more expensive extract will be replaced by a shuffle. If the costs
    // are equal and there is a preferred extract index, shuffle the opposite
    // operand. Otherwise, replace the extract with the higher index.
    if (Extract0Cost > Extract1Cost)
      ConvertToShuffle = Ext0;
    else if (Extract1Cost > Extract0Cost)
      ConvertToShuffle = Ext1;
    else if (PreferredExtractIndex == Ext0Index)
      ConvertToShuffle = Ext1;
    else if (PreferredExtractIndex == Ext1Index)
      ConvertToShuffle = Ext0;
    else
      ConvertToShuffle = Ext0Index > Ext1Index ? Ext0 : Ext1;
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1,
                          Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<CmpInst>(&I) && "Expected a compare");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  IRBuilder<> Builder(&I);
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecCmp =
      Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1)
                                           : Builder.CreateICmp(Pred, V0, V1);
  Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

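// An illustrative before/after in IR for the compare fold above (an assumed
// example for exposition, not taken from a specific test):
//   %e0 = extractelement <4 x i32> %x, i32 1
//   %e1 = extractelement <4 x i32> %y, i32 1
//   %r  = icmp sgt i32 %e0, %e1
// -->
//   %v  = icmp sgt <4 x i32> %x, %y
//   %r  = extractelement <4 x i1> %v, i32 1
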
/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1,
                            Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  IRBuilder<> Builder(&I);
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

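// An illustrative before/after in IR for the binop fold above (an assumed
// example); note that IR flags such as 'nsw' are carried over to the vector
// op, as copyIRFlags does above:
//   %e0 = extractelement <4 x i32> %x, i32 2
//   %e1 = extractelement <4 x i32> %y, i32 2
//   %r  = add nsw i32 %e0, %e1
// -->
//   %v  = add nsw <4 x i32> %x, %y
//   %r  = extractelement <4 x i32> %v, i32 2
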
/// Match an instruction with extracted vector operands.
static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *Ext0, *Ext1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) &&
      !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) ||
      !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  // probably becomes unnecessary.
  uint64_t InsertIndex = std::numeric_limits<uint64_t>::max();
  if (I.hasOneUse())
    match(I.user_back(), m_InsertElement(m_Value(), m_Value(),
                                         m_ConstantInt(InsertIndex)));

  Instruction *ConvertToShuffle;
  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI, ConvertToShuffle,
                            InsertIndex))
    return false;

  if (ConvertToShuffle) {
    // The shuffle mask is undefined except for 1 lane that is being translated
    // to the cheap extraction lane. Example:
    // ShufMask = { 2, undef, undef, undef }
    uint64_t SplatIndex = ConvertToShuffle == Ext0 ? C0 : C1;
    uint64_t CheapExtIndex = ConvertToShuffle == Ext0 ? C1 : C0;
    Type *VecTy = V0->getType();
    Type *I32Ty = IntegerType::getInt32Ty(I.getContext());
    UndefValue *Undef = UndefValue::get(I32Ty);
    SmallVector<Constant *, 32> ShufMask(VecTy->getVectorNumElements(), Undef);
    ShufMask[CheapExtIndex] = ConstantInt::get(I32Ty, SplatIndex);
    IRBuilder<> Builder(ConvertToShuffle);

    // extelt X, C --> extelt (splat X), C'
    Value *Shuf = Builder.CreateShuffleVector(ConvertToShuffle->getOperand(0),
                                              UndefValue::get(VecTy),
                                              ConstantVector::get(ShufMask));
    Value *NewExt = Builder.CreateExtractElement(Shuf, CheapExtIndex);
    if (ConvertToShuffle == Ext0)
      Ext0 = cast<Instruction>(NewExt);
    else
      Ext1 = cast<Instruction>(NewExt);
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I, TTI);
  else
    foldExtExtBinop(Ext0, Ext1, I, TTI);

  return true;
}

/// If this is a bitcast to narrow elements from a shuffle of wider elements,
/// try to bitcast the source vector to the narrow type followed by shuffle.
/// This can enable further transforms by moving bitcasts or shuffles together.
static bool foldBitcastShuf(Instruction &I, const TargetTransformInfo &TTI) {
  Value *V;
  ArrayRef<int> Mask;
  if (!match(&I, m_BitCast(m_OneUse(m_ShuffleVector(m_Value(V), m_Undef(),
                                                    m_Mask(Mask))))))
    return false;

  Type *DestTy = I.getType();
  Type *SrcTy = V->getType();
  if (!DestTy->isVectorTy() || I.getOperand(0)->getType() != SrcTy)
    return false;

  // TODO: Handle bitcast from narrow element type to wide element type.
  assert(SrcTy->isVectorTy() && "Shuffle of non-vector type?");
  unsigned DestNumElts = DestTy->getVectorNumElements();
  unsigned SrcNumElts = SrcTy->getVectorNumElements();
  if (SrcNumElts > DestNumElts)
    return false;

  // The new shuffle must not cost more than the old shuffle. The bitcast is
  // moved ahead of the shuffle, so assume that it has the same cost as before.
  if (TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, DestTy) >
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy))
    return false;

  // Bitcast the source vector and expand the shuffle mask to the equivalent
  // for narrow elements.
  // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
  IRBuilder<> Builder(&I);
  Value *CastV = Builder.CreateBitCast(V, DestTy);
  SmallVector<int, 16> NewMask;
  assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask");
  unsigned ScaleFactor = DestNumElts / SrcNumElts;
  scaleShuffleMask(ScaleFactor, Mask, NewMask);
  Value *Shuf = Builder.CreateShuffleVector(CastV, UndefValue::get(DestTy),
                                            NewMask);
  I.replaceAllUsesWith(Shuf);
  return true;
}

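// A sketch of the mask expansion performed by scaleShuffleMask in the
// bitcast/shuffle fold above (an assumed example): bitcasting <2 x i64> to
// <4 x i32> gives ScaleFactor = 2, so the wide-element mask <1, 0> becomes
// the narrow-element mask <2, 3, 0, 1>, and the bitcast is hoisted ahead of
// the shuffle:
//   %s = shufflevector <2 x i64> %v, <2 x i64> undef, <2 x i32> <i32 1, i32 0>
//   %b = bitcast <2 x i64> %s to <4 x i32>
// -->
//   %b = bitcast <2 x i64> %v to <4 x i32>
//   %s = shufflevector <4 x i32> %b, <4 x i32> undef,
//                      <4 x i32> <i32 2, i32 3, i32 0, i32 1>
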
/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    const DominatorTree &DT) {
  if (DisableVectorCombine)
    return false;

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions under here and invalidate the iterator.
    // Walk the block backwards for efficiency. We're matching a chain of
    // use->defs, so we're more likely to succeed by starting from the bottom.
    // TODO: It could be more efficient to remove dead instructions
    // iteratively in this loop rather than waiting until the end.
    for (Instruction &I : make_range(BB.rbegin(), BB.rend())) {
      if (isa<DbgInfoIntrinsic>(I))
        continue;
      MadeChange |= foldExtractExtract(I, TTI);
      MadeChange |= foldBitcastShuf(I, TTI);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)

Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}