1 //===------- VectorCombine.cpp - Optimize partial vector operations -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass optimizes scalar/vector interactions using target cost models. The 10 // transforms implemented here may not fit in traditional loop-based or SLP 11 // vectorization passes. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "llvm/Transforms/Vectorize/VectorCombine.h" 16 #include "llvm/ADT/Statistic.h" 17 #include "llvm/Analysis/GlobalsModRef.h" 18 #include "llvm/Analysis/TargetTransformInfo.h" 19 #include "llvm/Analysis/ValueTracking.h" 20 #include "llvm/IR/Dominators.h" 21 #include "llvm/IR/Function.h" 22 #include "llvm/IR/IRBuilder.h" 23 #include "llvm/IR/PatternMatch.h" 24 #include "llvm/InitializePasses.h" 25 #include "llvm/Pass.h" 26 #include "llvm/Support/CommandLine.h" 27 #include "llvm/Transforms/Vectorize.h" 28 #include "llvm/Transforms/Utils/Local.h" 29 30 using namespace llvm; 31 using namespace llvm::PatternMatch; 32 33 #define DEBUG_TYPE "vector-combine" 34 STATISTIC(NumVecCmp, "Number of vector compares formed"); 35 STATISTIC(NumVecBO, "Number of vector binops formed"); 36 37 static cl::opt<bool> DisableVectorCombine( 38 "disable-vector-combine", cl::init(false), cl::Hidden, 39 cl::desc("Disable all vector combine transforms")); 40 41 /// Compare the relative costs of extracts followed by scalar operation vs. 42 /// vector operation followed by extract: 43 /// opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C 44 /// Unless the vector op is much more expensive than the scalar op, this 45 /// eliminates an extract. 
static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
                                  unsigned Opcode,
                                  const TargetTransformInfo &TTI) {
  // Callers guarantee both extracts use the same constant index; the cost
  // formulas below rely on that (a single shared extract on the new side).
  assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
         (cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue() ==
          cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue()) &&
         "Expected same constant extract index");

  Type *ScalarTy = Ext0->getType();
  Type *VecTy = Ext0->getOperand(0)->getType();
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    // For compares, the result type (i1 or vector-of-i1) is derived from the
    // operand type so the target can price the whole cmp+extract sequence.
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get cost estimate for the extract element. This cost will factor into
  // both sequences.
  unsigned ExtIndex = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                           VecTy, ExtIndex);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0)) {
    // Handle a special case. If the 2 operands are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    // If Ext0 == Ext1, its only "free" uses are the two operands of I, hence
    // hasNUses(2); otherwise each distinct extract must itself be single-use.
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = ExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + ExtractCost + HasUseTax * ExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
    // An extract with other users survives the transform, so its cost is
    // charged (via the bool-to-int products) to the vector side as well.
    OldCost = 2 * ExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + ExtractCost + !Ext0->hasOneUse() * ExtractCost +
              !Ext1->hasOneUse() * ExtractCost;
  }
  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1,
                          Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<CmpInst>(&I) && "Expected a compare");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  IRBuilder<> Builder(&I);
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  // The scalar type of the extracts decides whether this was an fcmp or icmp.
  Value *VecCmp =
      Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1)
                                           : Builder.CreateICmp(Pred, V0, V1);
  Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
  // The old compare (and possibly the extracts) become dead; they are cleaned
  // up later by the caller, not erased here.
  I.replaceAllUsesWith(Extract);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
124 /// bo (ext0 V0, C), (ext1 V1, C) 125 static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1, 126 Instruction &I, const TargetTransformInfo &TTI) { 127 assert(isa<BinaryOperator>(&I) && "Expected a binary operator"); 128 129 // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C 130 ++NumVecBO; 131 IRBuilder<> Builder(&I); 132 Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0); 133 Value *VecBO = 134 Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1); 135 136 // All IR flags are safe to back-propagate because any potential poison 137 // created in unused vector elements is discarded by the extract. 138 if (auto *VecBOInst = dyn_cast<Instruction>(VecBO)) 139 VecBOInst->copyIRFlags(&I); 140 141 Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1)); 142 I.replaceAllUsesWith(Extract); 143 } 144 145 /// Match an instruction with extracted vector operands. 146 static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) { 147 // It is not safe to transform things like div, urem, etc. because we may 148 // create undefined behavior when executing those on unknown vector elements. 149 if (!isSafeToSpeculativelyExecute(&I)) 150 return false; 151 152 Instruction *Ext0, *Ext1; 153 CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE; 154 if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) && 155 !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1)))) 156 return false; 157 158 Value *V0, *V1; 159 uint64_t C0, C1; 160 if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) || 161 !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) || 162 V0->getType() != V1->getType()) 163 return false; 164 165 // TODO: Handle C0 != C1 by shuffling 1 of the operands. 
166 if (C0 != C1) 167 return false; 168 169 if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI)) 170 return false; 171 172 if (Pred != CmpInst::BAD_ICMP_PREDICATE) 173 foldExtExtCmp(Ext0, Ext1, I, TTI); 174 else 175 foldExtExtBinop(Ext0, Ext1, I, TTI); 176 177 return true; 178 } 179 180 /// This is the entry point for all transforms. Pass manager differences are 181 /// handled in the callers of this function. 182 static bool runImpl(Function &F, const TargetTransformInfo &TTI, 183 const DominatorTree &DT) { 184 if (DisableVectorCombine) 185 return false; 186 187 bool MadeChange = false; 188 for (BasicBlock &BB : F) { 189 // Ignore unreachable basic blocks. 190 if (!DT.isReachableFromEntry(&BB)) 191 continue; 192 // Do not delete instructions under here and invalidate the iterator. 193 // Walk the block backwards for efficiency. We're matching a chain of 194 // use->defs, so we're more likely to succeed by starting from the bottom. 195 // TODO: It could be more efficient to remove dead instructions 196 // iteratively in this loop rather than waiting until the end. 197 for (Instruction &I : make_range(BB.rbegin(), BB.rend())) 198 MadeChange |= foldExtractExtract(I, TTI); 199 } 200 201 // We're done with transforms, so remove dead instructions. 202 if (MadeChange) 203 for (BasicBlock &BB : F) 204 SimplifyInstructionsInBlock(&BB); 205 206 return MadeChange; 207 } 208 209 // Pass manager boilerplate below here. 

namespace {
/// Legacy pass manager adapter: fetches the required analyses and forwards
/// to the shared runImpl() implementation.
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Inputs: dominator tree (to skip unreachable blocks) and TTI (costs).
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    // The transforms only rewrite instructions within blocks, so the CFG
    // (and therefore the dominator tree) survives unchanged.
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    // Honor opt-bisect / optnone.
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)

/// Factory used by the legacy pipeline builders.
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

/// New pass manager entry point: same runImpl() core, with analysis
/// invalidation expressed via PreservedAnalyses.
PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  // CFG is untouched; declare that so CFG-based analyses are kept.
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}