//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

/// Compare the relative costs of extracts followed by scalar operation vs.
/// vector operation followed by extract:
/// opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
/// Unless the vector op is much more expensive than the scalar op, this
/// eliminates an extract.
static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
                                  unsigned Opcode,
                                  const TargetTransformInfo &TTI) {
  assert(Ext0->getOperand(1) == Ext1->getOperand(1) &&
         isa<ConstantInt>(Ext0->getOperand(1)) &&
         "Expected same constant extract index");

  Type *ScalarTy = Ext0->getType();
  Type *VecTy = Ext0->getOperand(0)->getType();
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get cost estimate for the extract element. This cost will factor into
  // both sequences.
  unsigned ExtIndex = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                           VecTy, ExtIndex);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0)) {
    // Handle a special case. If the 2 operands are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = ExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + ExtractCost + HasUseTax * ExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
    OldCost = 2 * ExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + ExtractCost + !Ext0->hasOneUse() * ExtractCost +
              !Ext1->hasOneUse() * ExtractCost;
  }
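  // Worked example with illustrative costs (real values come from TTI and
  // vary by target): in the general case with single-use extracts, if
  // ExtractCost, ScalarOpCost, and VectorOpCost are all 1, then
  // OldCost = 2 * 1 + 1 = 3 and NewCost = 1 + 1 = 2, so the vector form is
  // preferred. A hypothetical target reporting VectorOpCost = 4 would give
  // NewCost = 5 and keep the scalar form.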
  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
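///
/// For example (illustrative IR, assuming the cost model favors the vector
/// form):
///   %e0 = extractelement <4 x i32> %x, i32 1
///   %e1 = extractelement <4 x i32> %y, i32 1
///   %r = icmp sgt i32 %e0, %e1
/// becomes:
///   %v = icmp sgt <4 x i32> %x, %y
///   %r = extractelement <4 x i1> %v, i32 1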
static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1,
                          Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<CmpInst>(&I) && "Expected a compare");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  IRBuilder<> Builder(&I);
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecCmp =
      Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1)
                                           : Builder.CreateICmp(Pred, V0, V1);
  Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
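///
/// For example (illustrative IR):
///   %e0 = extractelement <2 x double> %x, i32 0
///   %e1 = extractelement <2 x double> %y, i32 0
///   %r = fadd double %e0, %e1
/// becomes:
///   %v = fadd <2 x double> %x, %y
///   %r = extractelement <2 x double> %v, i32 0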
static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1,
                            Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  IRBuilder<> Builder(&I);
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
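  // For example, 'add nsw' may create poison in a lane that overflows, but
  // only the extracted lane is observed, and it computes the same value as
  // the original scalar op did.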
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Match an instruction with extracted vector operands.
static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
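  // For example, a scalar 'udiv' of lane 0 never sees the other lanes, but
  // the widened vector 'udiv' executes every lane, so a zero anywhere else in
  // the divisor vector would introduce immediate undefined behavior.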
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *Ext0, *Ext1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) &&
      !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) ||
      !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // TODO: Handle C0 != C1 by shuffling 1 of the operands.
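  // For example, 'add (extelt V0, 0), (extelt V1, 1)' would first need V1's
  // element 1 moved into lane 0 by a shuffle before forming the vector add.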
  if (C0 != C1)
    return false;

  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI))
    return false;

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I, TTI);
  else
    foldExtExtBinop(Ext0, Ext1, I, TTI);

  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    const DominatorTree &DT) {
  if (DisableVectorCombine)
    return false;

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions under here and invalidate the iterator.
    // Walk the block backwards for efficiency. We're matching a chain of
    // use->defs, so we're more likely to succeed by starting from the bottom.
    // TODO: It could be more efficient to remove dead instructions
    //       iteratively in this loop rather than waiting until the end.
    for (Instruction &I : make_range(BB.rbegin(), BB.rend()))
      MadeChange |= foldExtractExtract(I, TTI);
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}