//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
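//
// Example (illustrative IR, not taken from a test): on a target where a
// vector 'add' costs about the same as a scalar 'add', the pass rewrites
//   %e0 = extractelement <4 x i32> %v0, i32 1
//   %e1 = extractelement <4 x i32> %v1, i32 1
//   %r  = add i32 %e0, %e1
// into
//   %v = add <4 x i32> %v0, %v1
//   %r = extractelement <4 x i32> %v, i32 1
// which eliminates one of the two extracts.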
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"

STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
/// Compare the relative costs of extracts followed by a scalar operation vs.
/// a vector operation followed by an extract:
/// opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
/// Unless the vector op is much more expensive than the scalar op, this
/// eliminates an extract. Returns true when the scalar form is estimated to
/// be cheap enough to keep, in which case the caller should not transform.
static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
                                  unsigned Opcode,
                                  const TargetTransformInfo &TTI) {
  assert(Ext0->getOperand(1) == Ext1->getOperand(1) &&
         isa<ConstantInt>(Ext0->getOperand(1)) &&
         "Expected same constant extract index");

  Type *ScalarTy = Ext0->getType();
  Type *VecTy = Ext0->getOperand(0)->getType();
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get cost estimate for the extract element. This cost will factor into
  // both sequences.
  unsigned ExtIndex = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                           VecTy, ExtIndex);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0)) {
    // Handle a special case. If the 2 operands are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = ExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + ExtractCost + HasUseTax * ExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
    OldCost = 2 * ExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + ExtractCost + !Ext0->hasOneUse() * ExtractCost +
              !Ext1->hasOneUse() * ExtractCost;
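    // For illustration only, with made-up unit costs (actual values come from
    // TTI): if each extract and each op costs 1 and both extracts have one
    // use, OldCost = 2 * 1 + 1 = 3 and NewCost = 1 + 1 = 2, so the vector
    // form is preferred; each extra-used extract adds one ExtractCost back.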
  }
  // TODO: The cost comparison should not differ based on opcode. We should be
  //       uniformly more aggressive or uniformly less aggressive in deciding
  //       whether a vector operation should replace the scalar operation.
  return IsBinOp ? OldCost <= NewCost : OldCost < NewCost;
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
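/// For example (illustrative IR, not from a test):
///   %e0 = extractelement <4 x float> %v0, i32 0
///   %e1 = extractelement <4 x float> %v1, i32 0
///   %c  = fcmp ogt float %e0, %e1
/// becomes:
///   %vc = fcmp ogt <4 x float> %v0, %v1
///   %c  = extractelement <4 x i1> %vc, i32 0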
static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1,
                          Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<CmpInst>(&I) && "Expected a compare");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  IRBuilder<> Builder(&I);
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecCmp =
      Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1)
                                           : Builder.CreateICmp(Pred, V0, V1);
  Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
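/// For example (illustrative IR, not from a test):
///   %e0 = extractelement <4 x i32> %v0, i32 1
///   %e1 = extractelement <4 x i32> %v1, i32 1
///   %b  = add nsw i32 %e0, %e1
/// becomes:
///   %vb = add nsw <4 x i32> %v0, %v1
///   %b  = extractelement <4 x i32> %vb, i32 1
/// Note that 'nsw' survives: poison in the other lanes is never observed.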
static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1,
                            Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  IRBuilder<> Builder(&I);
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Match an instruction with extracted vector operands.
static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *Ext0, *Ext1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) &&
      !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) ||
      !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // TODO: Handle C0 != C1 by shuffling 1 of the operands.
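  //       (Sketch only, not implemented: a shufflevector could move V1's
  //       element C1 into lane C0 so both extracts use the same index, with
  //       the shuffle's cost charged to the new sequence.)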
  if (C0 != C1)
    return false;

  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI))
    return false;

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I, TTI);
  else
    foldExtExtBinop(Ext0, Ext1, I, TTI);

  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    const DominatorTree &DT) {
  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions in this loop; that would invalidate the
    // iterator.
    // Walk the block backwards for efficiency. We're matching a chain of
    // use->defs, so we're more likely to succeed by starting from the bottom.
    // TODO: It could be more efficient to remove dead instructions
    //       iteratively in this loop rather than waiting until the end.
    for (Instruction &I : make_range(BB.rbegin(), BB.rend()))
      MadeChange |= foldExtractExtract(I, TTI);
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)

Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}
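
// To exercise this pass in isolation (usage sketch; assumes a typical 'opt'
// build where the pass is registered under the name "vector-combine"):
//   opt -vector-combine -S in.ll          ; legacy pass manager
//   opt -passes=vector-combine -S in.ll   ; new pass manager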

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}