//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");

static bool foldExtractCmp(Instruction &I, const TargetTransformInfo &TTI) {
  // Match a cmp with extracted vector operands.
  CmpInst::Predicate Pred;
  Instruction *Ext0, *Ext1;
  if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *V0, *V1;
  ConstantInt *C;
  if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C))) ||
      !match(Ext1, m_ExtractElement(m_Value(V1), m_Specific(C))) ||
      V0->getType() != V1->getType())
    return false;

  Type *ScalarTy = Ext0->getType();
  Type *VecTy = V0->getType();
  bool IsFP = ScalarTy->isFloatingPointTy();
  unsigned CmpOpcode = IsFP ? Instruction::FCmp : Instruction::ICmp;

  // Check if the existing scalar code or the vector alternative is cheaper.
  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  // ((2 * extract) + scalar cmp) < (vector cmp + extract) ?
  int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                           VecTy, C->getZExtValue());
  int ScalarCmpCost = TTI.getOperationCost(CmpOpcode, ScalarTy);
  int VecCmpCost = TTI.getOperationCost(CmpOpcode, VecTy);

  int ScalarCost = 2 * ExtractCost + ScalarCmpCost;
  int VecCost = VecCmpCost + ExtractCost +
                !Ext0->hasOneUse() * ExtractCost +
                !Ext1->hasOneUse() * ExtractCost;
  if (ScalarCost < VecCost)
    return false;

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  IRBuilder<> Builder(&I);
  Value *VecCmp = IsFP ? Builder.CreateFCmp(Pred, V0, V1)
                       : Builder.CreateICmp(Pred, V0, V1);
  Value *Ext = Builder.CreateExtractElement(VecCmp, C);
  I.replaceAllUsesWith(Ext);
  return true;
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
static bool foldExtractBinop(Instruction &I, const TargetTransformInfo &TTI) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
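  // For example, speculating a scalar udiv as a full vector udiv could divide
  // by zero in a lane that the original code never extracted.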
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  // Match a scalar binop with extracted vector operands:
  // bo (extelt X, C0), (extelt Y, C1)
  Instruction *Ext0, *Ext1;
  if (!match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *X, *Y;
  uint64_t C0, C1;
  if (!match(Ext0, m_ExtractElement(m_Value(X), m_ConstantInt(C0))) ||
      !match(Ext1, m_ExtractElement(m_Value(Y), m_ConstantInt(C1))) ||
      X->getType() != Y->getType())
    return false;

  // Check if using a vector binop would be cheaper.
  Instruction::BinaryOps BOpcode = cast<BinaryOperator>(I).getOpcode();
  Type *ScalarTy = I.getType();
  Type *VecTy = X->getType();
  int ScalarBOCost = TTI.getArithmeticInstrCost(BOpcode, ScalarTy);
  int VecBOCost = TTI.getArithmeticInstrCost(BOpcode, VecTy);
  int Extract0Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                            VecTy, C0);

  // Handle a special case - if the extract indexes are the same, the
  // replacement sequence does not require a shuffle. Unless the vector binop
  // is much more expensive than the scalar binop, this eliminates an extract.
  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  if (C0 == C1) {
    assert(Extract0Cost ==
           TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, C1) &&
           "Different costs for same extract?");
    int ExtractCost = Extract0Cost;
    if (X != Y) {
      int ScalarCost = ExtractCost + ExtractCost + ScalarBOCost;
      int VecCost = VecBOCost + ExtractCost +
                    !Ext0->hasOneUse() * ExtractCost +
                    !Ext1->hasOneUse() * ExtractCost;
      if (ScalarCost <= VecCost)
        return false;
    } else {
      // Handle an extra-special case. If the 2 binop operands are identical,
      // adjust the formulas to account for that:
      // bo (extelt X, C), (extelt X, C) --> extelt (bo X, X), C
      // The extra use charge allows for either the CSE'd pattern or an
      // unoptimized form with identical values.
      bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                    : !Ext0->hasOneUse() || !Ext1->hasOneUse();
      int ScalarCost = ExtractCost + ScalarBOCost;
      int VecCost = VecBOCost + ExtractCost + HasUseTax * ExtractCost;
      if (ScalarCost <= VecCost)
        return false;
    }

    // bo (extelt X, C), (extelt Y, C) --> extelt (bo X, Y), C
    ++NumVecBO;
    IRBuilder<> Builder(&I);
    Value *NewBO = Builder.CreateBinOp(BOpcode, X, Y);
    if (auto *VecBOInst = dyn_cast<Instruction>(NewBO)) {
      // All IR flags are safe to back-propagate because any potential poison
      // created in unused vector elements is discarded by the extract.
      VecBOInst->copyIRFlags(&I);
    }
    Value *Extract = Builder.CreateExtractElement(NewBO, Ext0->getOperand(1));
    I.replaceAllUsesWith(Extract);
    return true;
  }

  // TODO: Handle C0 != C1 by shuffling 1 of the operands.
  return false;
}
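
// Illustrative example (not part of the pass logic): for the C0 == C1 case
// above, the fold rewrites
//   %e0 = extractelement <4 x i32> %x, i32 1
//   %e1 = extractelement <4 x i32> %y, i32 1
//   %b  = add i32 %e0, %e1
// into
//   %v  = add <4 x i32> %x, %y
//   %b  = extractelement <4 x i32> %v, i32 1
// assuming the target's cost model says the vector add plus one extract is
// strictly cheaper than two extracts plus the scalar add.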

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    const DominatorTree &DT) {
  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions under here and invalidate the iterator.
    // Walk the block backwards for efficiency. We're matching a chain of
    // use->defs, so we're more likely to succeed by starting from the bottom.
    // TODO: It could be more efficient to remove dead instructions
    // iteratively in this loop rather than waiting until the end.
    for (Instruction &I : make_range(BB.rbegin(), BB.rend())) {
      MadeChange |= foldExtractCmp(I, TTI);
      MadeChange |= foldExtractBinop(I, TTI);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}