//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

/// Compare the relative costs of extracts followed by scalar operation vs.
/// vector operation followed by extract:
/// opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
/// Unless the vector op is much more expensive than the scalar op, this
/// eliminates an extract. Returns true if the existing extract-based (old)
/// sequence is strictly cheaper, i.e. the transform should be skipped.
static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
                                  unsigned Opcode,
                                  const TargetTransformInfo &TTI) {
  assert(Ext0->getOperand(1) == Ext1->getOperand(1) &&
         isa<ConstantInt>(Ext0->getOperand(1)) &&
         "Expected same constant extract index");

  Type *ScalarTy = Ext0->getType();
  Type *VecTy = Ext0->getOperand(0)->getType();
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get a cost estimate for the extract element. This cost will factor into
  // both sequences.
  unsigned ExtIndex = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  int ExtractCost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, ExtIndex);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0)) {
    // Handle a special case. If the 2 operands are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = ExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + ExtractCost + HasUseTax * ExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
    OldCost = 2 * ExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + ExtractCost + !Ext0->hasOneUse() * ExtractCost +
              !Ext1->hasOneUse() * ExtractCost;
  }
  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}
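// Worked example of the cost comparison above, with illustrative numbers only
// (real values come from the target's cost model via TTI): if ExtractCost,
// ScalarOpCost, and VectorOpCost are all 1 and both extracts of distinct
// vectors have a single use, then OldCost = 2 * 1 + 1 = 3 and
// NewCost = 1 + 1 = 2, so the old sequence is not cheaper and the fold
// proceeds. If both extracts had extra uses, NewCost would rise to
// 1 + 1 + 1 + 1 = 4 and the fold would be skipped.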
/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1,
                          Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<CmpInst>(&I) && "Expected a compare");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  IRBuilder<> Builder(&I);
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecCmp = Ext0->getType()->isFloatingPointTy()
                      ? Builder.CreateFCmp(Pred, V0, V1)
                      : Builder.CreateICmp(Pred, V0, V1);
  Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1,
                            Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  IRBuilder<> Builder(&I);
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}
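// Illustrative IR for the two folds above (value names are hypothetical).
// With single-use extracts, the binop fold rewrites:
//   %e0 = extractelement <4 x float> %x, i32 1
//   %e1 = extractelement <4 x float> %y, i32 1
//   %r = fadd fast float %e0, %e1
// into:
//   %v = fadd fast <4 x float> %x, %y
//   %r = extractelement <4 x float> %v, i32 1
// The 'fast' flags carry over via copyIRFlags(). Similarly, the compare fold
// turns "fcmp ogt float %e0, %e1" into an extract of element 1 from
// "fcmp ogt <4 x float> %x, %y". The now-dead extracts are removed later by
// the cleanup in runImpl().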
/// Match an instruction with extracted vector operands.
static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *Ext0, *Ext1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) &&
      !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) ||
      !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // TODO: Handle C0 != C1 by shuffling 1 of the operands.
  if (C0 != C1)
    return false;

  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI))
    return false;

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I, TTI);
  else
    foldExtExtBinop(Ext0, Ext1, I, TTI);

  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    const DominatorTree &DT) {
  if (DisableVectorCombine)
    return false;

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions inside this loop; that would invalidate the
    // iterator. Walk the block backwards for efficiency. We're matching a
    // chain of use->defs, so we're more likely to succeed by starting from
    // the bottom.
    // TODO: It could be more efficient to remove dead instructions
    //       iteratively in this loop rather than waiting until the end.
    for (Instruction &I : make_range(BB.rbegin(), BB.rend()))
      MadeChange |= foldExtractExtract(I, TTI);
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}
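// A minimal way to exercise these transforms in isolation (a sketch; assumes
// a build of 'opt' where the pass is also registered under the same name for
// the new pass manager):
//   opt -vector-combine -S in.ll          ; legacy pass manager
//   opt -passes=vector-combine -S in.ll   ; new pass manager
// The -disable-vector-combine flag above turns off all transforms in this
// pass.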
// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}
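// A minimal sketch (not part of this pass) of running VectorCombinePass
// programmatically with the new pass manager; assumes the standard analysis
// registration performed by PassBuilder:
//
//   PassBuilder PB;
//   FunctionAnalysisManager FAM;
//   PB.registerFunctionAnalyses(FAM);
//   FunctionPassManager FPM;
//   FPM.addPass(VectorCombinePass());
//   FPM.run(F, FAM); // F is the Function to optimize.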