1 //===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass optimizes scalar/vector interactions using target cost models. The
10 // transforms implemented here may not fit in traditional loop-based or SLP
11 // vectorization passes.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/Vectorize/VectorCombine.h"
16 #include "llvm/ADT/Statistic.h"
17 #include "llvm/Analysis/GlobalsModRef.h"
18 #include "llvm/Analysis/TargetTransformInfo.h"
19 #include "llvm/Analysis/ValueTracking.h"
20 #include "llvm/IR/Dominators.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/PatternMatch.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Transforms/Vectorize.h"
28 #include "llvm/Transforms/Utils/Local.h"
29 
using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"
// Pass-wide statistics; reported with -stats.
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");

// Debug/triage escape hatch: -disable-vector-combine makes the whole pass a
// no-op (checked at the top of runImpl).
static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));
40 
41 /// Compare the relative costs of extracts followed by scalar operation vs.
42 /// vector operation followed by extract:
43 /// opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
44 /// Unless the vector op is much more expensive than the scalar op, this
45 /// eliminates an extract.
46 static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
47                                   unsigned Opcode,
48                                   const TargetTransformInfo &TTI) {
49   assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
50          (cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue() ==
51           cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue()) &&
52          "Expected same constant extract index");
53 
54   Type *ScalarTy = Ext0->getType();
55   Type *VecTy = Ext0->getOperand(0)->getType();
56   int ScalarOpCost, VectorOpCost;
57 
58   // Get cost estimates for scalar and vector versions of the operation.
59   bool IsBinOp = Instruction::isBinaryOp(Opcode);
60   if (IsBinOp) {
61     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
62     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
63   } else {
64     assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
65            "Expected a compare");
66     ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
67                                           CmpInst::makeCmpResultType(ScalarTy));
68     VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
69                                           CmpInst::makeCmpResultType(VecTy));
70   }
71 
72   // Get cost estimate for the extract element. This cost will factor into
73   // both sequences.
74   unsigned ExtIndex = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
75   int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement,
76                                            VecTy, ExtIndex);
77 
78   // Extra uses of the extracts mean that we include those costs in the
79   // vector total because those instructions will not be eliminated.
80   int OldCost, NewCost;
81   if (Ext0->getOperand(0) == Ext1->getOperand(0)) {
82     // Handle a special case. If the 2 operands are identical, adjust the
83     // formulas to account for that. The extra use charge allows for either the
84     // CSE'd pattern or an unoptimized form with identical values:
85     // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
86     bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
87                                   : !Ext0->hasOneUse() || !Ext1->hasOneUse();
88     OldCost = ExtractCost + ScalarOpCost;
89     NewCost = VectorOpCost + ExtractCost + HasUseTax * ExtractCost;
90   } else {
91     // Handle the general case. Each extract is actually a different value:
92     // opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
93     OldCost = 2 * ExtractCost + ScalarOpCost;
94     NewCost = VectorOpCost + ExtractCost + !Ext0->hasOneUse() * ExtractCost +
95               !Ext1->hasOneUse() * ExtractCost;
96   }
97   // Aggressively form a vector op if the cost is equal because the transform
98   // may enable further optimization.
99   // Codegen can reverse this transform (scalarize) if it was not profitable.
100   return OldCost < NewCost;
101 }
102 
103 /// Try to reduce extract element costs by converting scalar compares to vector
104 /// compares followed by extract.
105 /// cmp (ext0 V0, C), (ext1 V1, C)
106 static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1,
107                           Instruction &I, const TargetTransformInfo &TTI) {
108   assert(isa<CmpInst>(&I) && "Expected a compare");
109 
110   // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
111   ++NumVecCmp;
112   IRBuilder<> Builder(&I);
113   CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
114   Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
115   Value *VecCmp =
116       Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1)
117                                            : Builder.CreateICmp(Pred, V0, V1);
118   Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
119   I.replaceAllUsesWith(Extract);
120 }
121 
122 /// Try to reduce extract element costs by converting scalar binops to vector
123 /// binops followed by extract.
124 /// bo (ext0 V0, C), (ext1 V1, C)
125 static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1,
126                             Instruction &I, const TargetTransformInfo &TTI) {
127   assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
128 
129   // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
130   ++NumVecBO;
131   IRBuilder<> Builder(&I);
132   Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
133   Value *VecBO =
134       Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);
135 
136   // All IR flags are safe to back-propagate because any potential poison
137   // created in unused vector elements is discarded by the extract.
138   if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
139     VecBOInst->copyIRFlags(&I);
140 
141   Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1));
142   I.replaceAllUsesWith(Extract);
143 }
144 
145 /// Match an instruction with extracted vector operands.
146 static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
147   // It is not safe to transform things like div, urem, etc. because we may
148   // create undefined behavior when executing those on unknown vector elements.
149   if (!isSafeToSpeculativelyExecute(&I))
150     return false;
151 
152   Instruction *Ext0, *Ext1;
153   CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
154   if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) &&
155       !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
156     return false;
157 
158   Value *V0, *V1;
159   uint64_t C0, C1;
160   if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) ||
161       !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) ||
162       V0->getType() != V1->getType())
163     return false;
164 
165   // TODO: Handle C0 != C1 by shuffling 1 of the operands.
166   if (C0 != C1)
167     return false;
168 
169   if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI))
170     return false;
171 
172   if (Pred != CmpInst::BAD_ICMP_PREDICATE)
173     foldExtExtCmp(Ext0, Ext1, I, TTI);
174   else
175     foldExtExtBinop(Ext0, Ext1, I, TTI);
176 
177   return true;
178 }
179 
180 /// This is the entry point for all transforms. Pass manager differences are
181 /// handled in the callers of this function.
182 static bool runImpl(Function &F, const TargetTransformInfo &TTI,
183                     const DominatorTree &DT) {
184   if (DisableVectorCombine)
185     return false;
186 
187   bool MadeChange = false;
188   for (BasicBlock &BB : F) {
189     // Ignore unreachable basic blocks.
190     if (!DT.isReachableFromEntry(&BB))
191       continue;
192     // Do not delete instructions under here and invalidate the iterator.
193     // Walk the block backwards for efficiency. We're matching a chain of
194     // use->defs, so we're more likely to succeed by starting from the bottom.
195     // TODO: It could be more efficient to remove dead instructions
196     //       iteratively in this loop rather than waiting until the end.
197     for (Instruction &I : make_range(BB.rbegin(), BB.rend()))
198       MadeChange |= foldExtractExtract(I, TTI);
199   }
200 
201   // We're done with transforms, so remove dead instructions.
202   if (MadeChange)
203     for (BasicBlock &BB : F)
204       SimplifyInstructionsInBlock(&BB);
205 
206   return MadeChange;
207 }
208 
209 // Pass manager boilerplate below here.
210 
namespace {
/// Legacy pass manager adapter: gathers the required analyses and forwards
/// to runImpl().
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Analyses the transform consumes.
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    // runImpl never alters the CFG, so CFG-based analyses stay valid.
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    // Respect -O0 / optnone and pass-disabling flags.
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace
237 
char VectorCombineLegacyPass::ID = 0;
// Register the legacy pass (and its analysis dependency) with the pass
// registry so it can be constructed by name, e.g. via 'opt -vector-combine'.
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
// Factory used by the legacy pipeline builder.
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}
248 
249 PreservedAnalyses VectorCombinePass::run(Function &F,
250                                          FunctionAnalysisManager &FAM) {
251   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
252   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
253   if (!runImpl(F, TTI, DT))
254     return PreservedAnalyses::all();
255   PreservedAnalyses PA;
256   PA.preserveSet<CFGAnalyses>();
257   PA.preserve<GlobalsAA>();
258   return PA;
259 }
260