//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"

STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");

/// Compare the relative costs of extracts followed by a scalar operation vs.
/// a vector operation followed by an extract:
/// opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
/// Unless the vector op is much more expensive than the scalar op, this
/// eliminates an extract.
static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
                                  unsigned Opcode,
                                  const TargetTransformInfo &TTI) {
  assert(Ext0->getOperand(1) == Ext1->getOperand(1) &&
         isa<ConstantInt>(Ext0->getOperand(1)) &&
         "Expected same constant extract index");

  Type *ScalarTy = Ext0->getType();
  Type *VecTy = Ext0->getOperand(0)->getType();
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get a cost estimate for the extractelement. This cost factors into both
  // sequences.
  unsigned ExtIndex = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  int ExtractCost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, ExtIndex);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0)) {
    // Handle a special case. If the 2 operands are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = ExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + ExtractCost + HasUseTax * ExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
    OldCost = 2 * ExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + ExtractCost + !Ext0->hasOneUse() * ExtractCost +
              !Ext1->hasOneUse() * ExtractCost;
  }
  // TODO: The cost comparison should not differ based on opcode. We should be
  //       uniformly more or less aggressive in deciding whether a vector
  //       operation should replace the scalar operation.
  return IsBinOp ? OldCost <= NewCost : OldCost < NewCost;
}
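// Illustrative example of the fold guarded by the cost model above (the IR
// below is hypothetical, not taken from a test):
//
//   %e0 = extractelement <4 x float> %v0, i32 1
//   %e1 = extractelement <4 x float> %v1, i32 1
//   %r = fadd float %e0, %e1
// -->
//   %vbo = fadd <4 x float> %v0, %v1
//   %r = extractelement <4 x float> %vbo, i32 1
//
// With single-use extracts this is the general case above: OldCost =
// 2 * ExtractCost + ScalarOpCost and NewCost = VectorOpCost + ExtractCost,
// so the fold proceeds exactly when the vector fadd costs less than the
// scalar fadd plus one extract.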
/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1, Instruction &I,
                          const TargetTransformInfo &TTI) {
  assert(isa<CmpInst>(&I) && "Expected a compare");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  IRBuilder<> Builder(&I);
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecCmp =
      Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1)
                                           : Builder.CreateICmp(Pred, V0, V1);
  Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1,
                            Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  IRBuilder<> Builder(&I);
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Match an instruction with extracted vector operands.
static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *Ext0, *Ext1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) &&
      !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) ||
      !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // TODO: Handle C0 != C1 by shuffling 1 of the operands.
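  // A sketch of that extension (not implemented here): emit a shufflevector
  // that moves element C1 of one operand into lane C0, then form the vector
  // op and extract lane C0 as below. The extra shuffle would also need to be
  // charged to NewCost in isExtractExtractCheap for the fold to stay
  // profitable.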
  if (C0 != C1)
    return false;

  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI))
    return false;

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I, TTI);
  else
    foldExtExtBinop(Ext0, Ext1, I, TTI);

  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    const DominatorTree &DT) {
  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions in this loop; that would invalidate the
    // iterator. Walk the block backwards for efficiency. We're matching a
    // chain of use->defs, so we're more likely to succeed by starting from
    // the bottom.
    // TODO: It could be more efficient to remove dead instructions
    //       iteratively in this loop rather than waiting until the end.
    for (Instruction &I : make_range(BB.rbegin(), BB.rend()))
      MadeChange |= foldExtractExtract(I, TTI);
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() { return new VectorCombineLegacyPass(); }

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}
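// Usage note: the pass is registered above under the name "vector-combine",
// so these folds can be exercised in isolation with the legacy pass manager:
//   opt -vector-combine -S input.ll
// With the new pass manager, the equivalent invocation is expected to be
// `opt -passes=vector-combine -S` (that registration lives in
// PassRegistry.def, outside this file).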