//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Vectorize.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");

static bool foldExtractCmp(Instruction &I, const TargetTransformInfo &TTI) {
  // Match a cmp with extracted vector operands.
  CmpInst::Predicate Pred;
  Instruction *Ext0, *Ext1;
  if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *V0, *V1;
  ConstantInt *C;
  if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C))) ||
      !match(Ext1, m_ExtractElement(m_Value(V1), m_Specific(C))) ||
      V0->getType() != V1->getType())
    return false;

  Type *ScalarTy = Ext0->getType();
  Type *VecTy = V0->getType();
  bool IsFP = ScalarTy->isFloatingPointTy();
  unsigned CmpOpcode = IsFP ? Instruction::FCmp : Instruction::ICmp;

  // Check if the existing scalar code or the vector alternative is cheaper.
  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  // ((2 * extract) + scalar cmp) < (vector cmp + extract) ?
  int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                           VecTy, C->getZExtValue());
  int ScalarCmpCost = TTI.getCmpSelInstrCost(CmpOpcode, ScalarTy, I.getType());
  int VecCmpCost = TTI.getCmpSelInstrCost(CmpOpcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));

  int ScalarCost = 2 * ExtractCost + ScalarCmpCost;
  int VecCost = VecCmpCost + ExtractCost +
                !Ext0->hasOneUse() * ExtractCost +
                !Ext1->hasOneUse() * ExtractCost;
  if (ScalarCost < VecCost)
    return false;

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  IRBuilder<> Builder(&I);
  Value *VecCmp = IsFP ? Builder.CreateFCmp(Pred, V0, V1)
                       : Builder.CreateICmp(Pred, V0, V1);
  Value *Ext = Builder.CreateExtractElement(VecCmp, C);
  I.replaceAllUsesWith(Ext);
  return true;
}
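
// Illustrative example (added commentary, not from the original source): for
// two <4 x i32> inputs %a and %b, the fold above turns
//   %e0 = extractelement <4 x i32> %a, i32 1
//   %e1 = extractelement <4 x i32> %b, i32 1
//   %c = icmp sgt i32 %e0, %e1
// into
//   %vc = icmp sgt <4 x i32> %a, %b
//   %c = extractelement <4 x i1> %vc, i32 1
// whenever the cost model says the vector compare plus one extract is no more
// expensive than two extracts plus the scalar compare.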

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
static bool foldExtractBinop(Instruction &I, const TargetTransformInfo &TTI) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  // Match a scalar binop with extracted vector operands:
  // bo (extelt X, C0), (extelt Y, C1)
  Instruction *Ext0, *Ext1;
  if (!match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *X, *Y;
  uint64_t C0, C1;
  if (!match(Ext0, m_ExtractElement(m_Value(X), m_ConstantInt(C0))) ||
      !match(Ext1, m_ExtractElement(m_Value(Y), m_ConstantInt(C1))) ||
      X->getType() != Y->getType())
    return false;

  // Check if using a vector binop would be cheaper.
  Instruction::BinaryOps BOpcode = cast<BinaryOperator>(I).getOpcode();
  Type *ScalarTy = I.getType();
  Type *VecTy = X->getType();
  int ScalarBOCost = TTI.getArithmeticInstrCost(BOpcode, ScalarTy);
  int VecBOCost = TTI.getArithmeticInstrCost(BOpcode, VecTy);
  int Extract0Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                            VecTy, C0);

  // Handle a special case - if the extract indexes are the same, the
  // replacement sequence does not require a shuffle. Unless the vector binop
  // is much more expensive than the scalar binop, this eliminates an extract.
  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  if (C0 == C1) {
    assert(Extract0Cost ==
               TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, C1) &&
           "Different costs for same extract?");
    int ExtractCost = Extract0Cost;
    if (X != Y) {
      int ScalarCost = ExtractCost + ExtractCost + ScalarBOCost;
      int VecCost = VecBOCost + ExtractCost +
                    !Ext0->hasOneUse() * ExtractCost +
                    !Ext1->hasOneUse() * ExtractCost;
      if (ScalarCost <= VecCost)
        return false;
    } else {
      // Handle an extra-special case. If the 2 binop operands are identical,
      // adjust the formulas to account for that:
      // bo (extelt X, C), (extelt X, C) --> extelt (bo X, X), C
      // The extra use charge allows for either the CSE'd pattern or an
      // unoptimized form with identical values.
      bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                    : !Ext0->hasOneUse() || !Ext1->hasOneUse();
      int ScalarCost = ExtractCost + ScalarBOCost;
      int VecCost = VecBOCost + ExtractCost + HasUseTax * ExtractCost;
      if (ScalarCost <= VecCost)
        return false;
    }

    // bo (extelt X, C), (extelt Y, C) --> extelt (bo X, Y), C
    ++NumVecBO;
    IRBuilder<> Builder(&I);
    Value *NewBO = Builder.CreateBinOp(BOpcode, X, Y);
    if (auto *VecBOInst = dyn_cast<Instruction>(NewBO)) {
      // All IR flags are safe to back-propagate because any potential poison
      // created in unused vector elements is discarded by the extract.
      VecBOInst->copyIRFlags(&I);
    }
    Value *Extract = Builder.CreateExtractElement(NewBO, Ext0->getOperand(1));
    I.replaceAllUsesWith(Extract);
    return true;
  }

  // TODO: Handle C0 != C1 by shuffling 1 of the operands.
  return false;
}
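
// Illustrative example (added commentary, not from the original source): for
// two <4 x i32> inputs %x and %y, the fold above turns
//   %e0 = extractelement <4 x i32> %x, i32 2
//   %e1 = extractelement <4 x i32> %y, i32 2
//   %s = add i32 %e0, %e1
// into
//   %v = add <4 x i32> %x, %y
//   %s = extractelement <4 x i32> %v, i32 2
// when the cost model favors the vector form; mismatched extract indexes are
// left to the TODO above.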

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    const DominatorTree &DT) {
  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions under here and invalidate the iterator.
    // Walk the block backwards for efficiency. We're matching a chain of
    // use->defs, so we're more likely to succeed by starting from the bottom.
    // TODO: It could be more efficient to remove dead instructions
    //       iteratively in this loop rather than waiting until the end.
    for (Instruction &I : make_range(BB.rbegin(), BB.rend())) {
      MadeChange |= foldExtractCmp(I, TTI);
      MadeChange |= foldExtractBinop(I, TTI);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}
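
// Usage sketch (added commentary, not part of the original file): with the
// legacy pass manager this pass is exposed under the "vector-combine" argument
// registered by the INITIALIZE_PASS macros above, e.g.:
//   opt -vector-combine -S input.ll
// Running it through the new pass manager additionally depends on the entry in
// PassRegistry.def (not shown in this file); if it is registered under the
// same name, the equivalent invocation is:
//   opt -passes=vector-combine -S input.ll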