//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
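//
// As an illustration (not tied to any particular target's costs), a scalar
// binop on two extracted elements may become a vector binop followed by a
// single extract:
//
//   %e0 = extractelement <4 x i32> %x, i32 0
//   %e1 = extractelement <4 x i32> %y, i32 0
//   %r  = add i32 %e0, %e1
//   -->
//   %v  = add <4 x i32> %x, %y
//   %r  = extractelement <4 x i32> %v, i32 0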
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"

STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");

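/// Try to reduce extract element costs by converting a scalar compare of
/// extracted vector elements into a vector compare followed by an extract.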
static bool foldExtractCmp(Instruction &I, const TargetTransformInfo &TTI) {
  // Match a cmp with extracted vector operands.
  CmpInst::Predicate Pred;
  Instruction *Ext0, *Ext1;
  if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *V0, *V1;
  ConstantInt *C;
  if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C))) ||
      !match(Ext1, m_ExtractElement(m_Value(V1), m_Specific(C))) ||
      V0->getType() != V1->getType())
    return false;

  Type *ScalarTy = Ext0->getType();
  Type *VecTy = V0->getType();
  bool IsFP = ScalarTy->isFloatingPointTy();
  unsigned CmpOpcode = IsFP ? Instruction::FCmp : Instruction::ICmp;

  // Check if the existing scalar code or the vector alternative is cheaper.
  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  // ((2 * extract) + scalar cmp) < (vector cmp + extract) ?
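  // As a purely illustrative example, if a target charged extract=2,
  // scalar cmp=1, and vector cmp=1 (made-up numbers, not from any real cost
  // table), then scalar = 2 * 2 + 1 = 5 and vector = 1 + 2 = 3, so the
  // vector form would win for single-use extracts.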
  int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                           VecTy, C->getZExtValue());
  int ScalarCmpCost = TTI.getCmpSelInstrCost(CmpOpcode, ScalarTy, I.getType());
  int VecCmpCost = TTI.getCmpSelInstrCost(CmpOpcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));

  int ScalarCost = 2 * ExtractCost + ScalarCmpCost;
  int VecCost = VecCmpCost + ExtractCost +
                !Ext0->hasOneUse() * ExtractCost +
                !Ext1->hasOneUse() * ExtractCost;
  if (ScalarCost < VecCost)
    return false;

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  IRBuilder<> Builder(&I);
  Value *VecCmp = IsFP ? Builder.CreateFCmp(Pred, V0, V1)
                       : Builder.CreateICmp(Pred, V0, V1);
  Value *Ext = Builder.CreateExtractElement(VecCmp, C);
  I.replaceAllUsesWith(Ext);
  return true;
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
static bool foldExtractBinop(Instruction &I, const TargetTransformInfo &TTI) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  // Match a scalar binop with extracted vector operands:
  // bo (extelt X, C0), (extelt Y, C1)
  Instruction *Ext0, *Ext1;
  if (!match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *X, *Y;
  uint64_t C0, C1;
  if (!match(Ext0, m_ExtractElement(m_Value(X), m_ConstantInt(C0))) ||
      !match(Ext1, m_ExtractElement(m_Value(Y), m_ConstantInt(C1))) ||
      X->getType() != Y->getType())
    return false;

  // Check if using a vector binop would be cheaper.
  Instruction::BinaryOps BOpcode = cast<BinaryOperator>(I).getOpcode();
  Type *ScalarTy = I.getType();
  Type *VecTy = X->getType();
  int ScalarBOCost = TTI.getArithmeticInstrCost(BOpcode, ScalarTy);
  int VecBOCost = TTI.getArithmeticInstrCost(BOpcode, VecTy);
  int Extract0Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                            VecTy, C0);

  // Handle a special case - if the extract indexes are the same, the
  // replacement sequence does not require a shuffle. Unless the vector binop is
  // much more expensive than the scalar binop, this eliminates an extract.
  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  if (C0 == C1) {
    assert(Extract0Cost ==
               TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, C1) &&
           "Different costs for same extract?");
    int ExtractCost = Extract0Cost;
    if (X != Y) {
      int ScalarCost = ExtractCost + ExtractCost + ScalarBOCost;
      int VecCost = VecBOCost + ExtractCost +
                    !Ext0->hasOneUse() * ExtractCost +
                    !Ext1->hasOneUse() * ExtractCost;
      if (ScalarCost <= VecCost)
        return false;
    } else {
      // Handle an extra-special case. If the 2 binop operands are identical,
      // adjust the formulas to account for that:
      // bo (extelt X, C), (extelt X, C) --> extelt (bo X, X), C
      // The extra use charge allows for either the CSE'd pattern or an
      // unoptimized form with identical values.
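      // For instance (illustrative IR), both the CSE'd form:
      //   %e = extractelement <4 x i32> %x, i32 0
      //   %r = mul i32 %e, %e
      // and the unoptimized form:
      //   %e0 = extractelement <4 x i32> %x, i32 0
      //   %e1 = extractelement <4 x i32> %x, i32 0
      //   %r  = mul i32 %e0, %e1
      // are costed the same way by the check below.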
      bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                    : !Ext0->hasOneUse() || !Ext1->hasOneUse();
      int ScalarCost = ExtractCost + ScalarBOCost;
      int VecCost = VecBOCost + ExtractCost + HasUseTax * ExtractCost;
      if (ScalarCost <= VecCost)
        return false;
    }

    // bo (extelt X, C), (extelt Y, C) --> extelt (bo X, Y), C
    ++NumVecBO;
    IRBuilder<> Builder(&I);
    Value *NewBO = Builder.CreateBinOp(BOpcode, X, Y);
    if (auto *VecBOInst = dyn_cast<Instruction>(NewBO)) {
      // All IR flags are safe to back-propagate because any potential poison
      // created in unused vector elements is discarded by the extract.
      VecBOInst->copyIRFlags(&I);
    }
    Value *Extract = Builder.CreateExtractElement(NewBO, Ext0->getOperand(1));
    I.replaceAllUsesWith(Extract);
    return true;
  }

  // TODO: Handle C0 != C1 by shuffling 1 of the operands.
  return false;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    const DominatorTree &DT) {
  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions in this loop; that would invalidate the
    // iterator. Dead instructions are cleaned up after all transforms.
    // Walk the block backwards for efficiency. We're matching a chain of
    // use->defs, so we're more likely to succeed by starting from the bottom.
    // TODO: It could be more efficient to remove dead instructions
    //       iteratively in this loop rather than waiting until the end.
    for (Instruction &I : make_range(BB.rbegin(), BB.rend())) {
      MadeChange |= foldExtractCmp(I, TTI);
      MadeChange |= foldExtractBinop(I, TTI);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)

Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}
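
// Note: assuming this pass is also registered with the new pass manager under
// the same "vector-combine" name (that registration lives outside this file),
// it can be exercised in isolation with, e.g.:
//   opt -vector-combine -S input.ll          (legacy pass manager)
//   opt -passes=vector-combine -S input.ll   (new pass manager)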