//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction.
static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
                                  unsigned Opcode,
                                  const TargetTransformInfo &TTI,
                                  Instruction *&ConvertToShuffle) {
  assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
         isa<ConstantInt>(Ext1->getOperand(1)) &&
         "Expected constant extract indexes");
  Type *ScalarTy = Ext0->getType();
  Type *VecTy = Ext0->getOperand(0)->getType();
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();

  int Extract0Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                            VecTy, Ext0Index);
  int Extract1Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                            VecTy, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  //       check the cost of creating a broadcast shuffle and shuffling both
  //       operands to element 0.
  int CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  if (Ext0Index == Ext1Index) {
    // If the extract indexes are identical, no shuffle is needed.
    ConvertToShuffle = nullptr;
  } else {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // undefined except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { undef, undef, 0, undef }
    // TODO: The cost model has an option for a "broadcast" shuffle
    //       (splat-from-element-0), but no option for a more general splat.
    NewCost +=
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);

    // The more expensive extract will be replaced by a shuffle. If the extracts
    // have the same cost, replace the extract with the higher index.
    if (Extract0Cost > Extract1Cost)
      ConvertToShuffle = Ext0;
    else if (Extract1Cost > Extract0Cost)
      ConvertToShuffle = Ext1;
    else
      ConvertToShuffle = Ext0Index > Ext1Index ? Ext0 : Ext1;
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1,
                          Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<CmpInst>(&I) && "Expected a compare");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  IRBuilder<> Builder(&I);
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecCmp =
      Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1)
                                           : Builder.CreateICmp(Pred, V0, V1);
  Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1,
                            Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  IRBuilder<> Builder(&I);
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Match an instruction with extracted vector operands.
static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *Ext0, *Ext1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) &&
      !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) ||
      !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  Instruction *ConvertToShuffle;
  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI, ConvertToShuffle))
    return false;

  if (ConvertToShuffle) {
    // The shuffle mask is undefined except for 1 lane that is being translated
    // to the cheap extraction lane. Example:
    // ShufMask = { 2, undef, undef, undef }
    uint64_t SplatIndex = ConvertToShuffle == Ext0 ? C0 : C1;
    uint64_t CheapExtIndex = ConvertToShuffle == Ext0 ? C1 : C0;
    Type *VecTy = V0->getType();
    Type *I32Ty = IntegerType::getInt32Ty(I.getContext());
    UndefValue *Undef = UndefValue::get(I32Ty);
    SmallVector<Constant *, 32> ShufMask(VecTy->getVectorNumElements(), Undef);
    ShufMask[CheapExtIndex] = ConstantInt::get(I32Ty, SplatIndex);
    IRBuilder<> Builder(ConvertToShuffle);

    // extelt X, C --> extelt (splat X), C'
    Value *Shuf = Builder.CreateShuffleVector(ConvertToShuffle->getOperand(0),
                                              UndefValue::get(VecTy),
                                              ConstantVector::get(ShufMask));
    Value *NewExt = Builder.CreateExtractElement(Shuf, CheapExtIndex);
    if (ConvertToShuffle == Ext0)
      Ext0 = cast<Instruction>(NewExt);
    else
      Ext1 = cast<Instruction>(NewExt);
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I, TTI);
  else
    foldExtExtBinop(Ext0, Ext1, I, TTI);

  return true;
}

/// If this is a bitcast to narrow elements from a shuffle of wider elements,
/// try to bitcast the source vector to the narrow type followed by shuffle.
/// This can enable further transforms by moving bitcasts or shuffles together.
static bool foldBitcastShuf(Instruction &I, const TargetTransformInfo &TTI) {
  Value *V;
  ArrayRef<int> Mask;
  if (!match(&I, m_BitCast(m_OneUse(m_ShuffleVector(m_Value(V), m_Undef(),
                                                    m_Mask(Mask))))))
    return false;

  Type *DestTy = I.getType();
  Type *SrcTy = V->getType();
  if (!DestTy->isVectorTy() || I.getOperand(0)->getType() != SrcTy)
    return false;

  // TODO: Handle bitcast from narrow element type to wide element type.
  assert(SrcTy->isVectorTy() && "Shuffle of non-vector type?");
  unsigned DestNumElts = DestTy->getVectorNumElements();
  unsigned SrcNumElts = SrcTy->getVectorNumElements();
  if (SrcNumElts > DestNumElts)
    return false;

  // The new shuffle must not cost more than the old shuffle. The bitcast is
  // moved ahead of the shuffle, so assume that it has the same cost as before.
  if (TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, DestTy) >
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy))
    return false;

  // Bitcast the source vector and expand the shuffle mask to the equivalent
  // for narrow elements.
  // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
  IRBuilder<> Builder(&I);
  Value *CastV = Builder.CreateBitCast(V, DestTy);
  SmallVector<int, 16> NewMask;
  assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask");
  unsigned ScaleFactor = DestNumElts / SrcNumElts;
  scaleShuffleMask(ScaleFactor, Mask, NewMask);
  Value *Shuf = Builder.CreateShuffleVector(CastV, UndefValue::get(DestTy),
                                            NewMask);
  I.replaceAllUsesWith(Shuf);
  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    const DominatorTree &DT) {
  if (DisableVectorCombine)
    return false;

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions under here and invalidate the iterator.
    // Walk the block backwards for efficiency. We're matching a chain of
    // use->defs, so we're more likely to succeed by starting from the bottom.
    // TODO: It could be more efficient to remove dead instructions
    //       iteratively in this loop rather than waiting until the end.
    for (Instruction &I : make_range(BB.rbegin(), BB.rend())) {
      if (isa<DbgInfoIntrinsic>(I))
        continue;
      MadeChange |= foldExtractExtract(I, TTI);
      MadeChange |= foldBitcastShuf(I, TTI);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}
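
//===----------------------------------------------------------------------===//
// Illustrative IR sketches of the two rewrites above. These are added examples,
// not part of the pass logic: the value names and types (%x, %y, <4 x i32>,
// <2 x i64>) are made up, and whether either rewrite actually fires depends
// entirely on the target's cost model queries in isExtractExtractCheap and
// foldBitcastShuf.
//
// foldExtractExtract with a binop (extracts from the same index, so no
// splat shuffle is needed):
//   %e0 = extractelement <4 x i32> %x, i32 1
//   %e1 = extractelement <4 x i32> %y, i32 1
//   %r  = add i32 %e0, %e1
// -->
//   %v  = add <4 x i32> %x, %y
//   %r  = extractelement <4 x i32> %v, i32 1
//
// foldBitcastShuf, where the <2 x i64> mask {1, 0} is scaled by a factor of 2
// to the <4 x i32> mask {2, 3, 0, 1}:
//   %s = shufflevector <2 x i64> %x, <2 x i64> undef, <2 x i32> <i32 1, i32 0>
//   %b = bitcast <2 x i64> %s to <4 x i32>
// -->
//   %c = bitcast <2 x i64> %x to <4 x i32>
//   %b = shufflevector <4 x i32> %c, <4 x i32> undef,
//                      <4 x i32> <i32 2, i32 3, i32 0, i32 1>
//===----------------------------------------------------------------------===//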