//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"

STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction.
static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
                                  unsigned Opcode,
                                  const TargetTransformInfo &TTI,
                                  Instruction *&ConvertToShuffle) {
  assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
         isa<ConstantInt>(Ext1->getOperand(1)) &&
         "Expected constant extract indexes");
  Type *ScalarTy = Ext0->getType();
  Type *VecTy = Ext0->getOperand(0)->getType();
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();

  int Extract0Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                            VecTy, Ext0Index);
  int Extract1Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                            VecTy, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (extelt V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  //       check the cost of creating a broadcast shuffle and shuffling both
  //       operands to element 0.
  int CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
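    // Illustrative arithmetic (assuming unit TTI costs and no extra uses):
    //   OldCost = 1 (extract) + 1 (scalar op) = 2
    //   NewCost = 1 (vector op) + 1 (extract) = 2
    // A tie favors the vector form (see the final comparison below).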
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
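    // Illustrative arithmetic (assuming unit TTI costs, single-use extracts):
    //   OldCost = 1 + 1 (extracts) + 1 (scalar op) = 3
    //   NewCost = 1 (vector op) + 1 (cheap extract) = 2
    // A shuffle cost is added below if the extract indexes differ.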
  }

  if (Ext0Index == Ext1Index) {
    // If the extract indexes are identical, no shuffle is needed.
    ConvertToShuffle = nullptr;
  } else {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // undefined except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { undef, undef, 0, undef }
    // TODO: The cost model has an option for a "broadcast" shuffle
    //       (splat-from-element-0), but no option for a more general splat.
    NewCost +=
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);

    // The more expensive extract will be replaced by a shuffle. If the extracts
    // have the same cost, replace the extract with the higher index.
    if (Extract0Cost > Extract1Cost)
      ConvertToShuffle = Ext0;
    else if (Extract1Cost > Extract0Cost)
      ConvertToShuffle = Ext1;
    else
      ConvertToShuffle = Ext0Index > Ext1Index ? Ext0 : Ext1;
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1,
                          Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<CmpInst>(&I) && "Expected a compare");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
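  // Illustrative IR (assuming <4 x i32> operands and C = 1):
  //   %e0 = extractelement <4 x i32> %v0, i32 1
  //   %e1 = extractelement <4 x i32> %v1, i32 1
  //   %r  = icmp sgt i32 %e0, %e1
  // becomes:
  //   %vc = icmp sgt <4 x i32> %v0, %v1
  //   %r  = extractelement <4 x i1> %vc, i32 1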
  ++NumVecCmp;
  IRBuilder<> Builder(&I);
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecCmp =
      Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1)
                                           : Builder.CreateICmp(Pred, V0, V1);
  Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1,
                            Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
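  // Illustrative IR (assuming <4 x float> operands and C = 0):
  //   %e0 = extractelement <4 x float> %v0, i32 0
  //   %e1 = extractelement <4 x float> %v1, i32 0
  //   %r  = fadd float %e0, %e1
  // becomes:
  //   %vb = fadd <4 x float> %v0, %v1
  //   %r  = extractelement <4 x float> %vb, i32 0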
  ++NumVecBO;
  IRBuilder<> Builder(&I);
  Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
}

/// Match an instruction with extracted vector operands.
static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
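  // For example, a widened 'udiv' would divide by whatever values happen to be
  // in the other vector lanes, and any of those could be zero.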
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *Ext0, *Ext1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) &&
      !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) ||
      !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  Instruction *ConvertToShuffle;
  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI, ConvertToShuffle))
    return false;

  if (ConvertToShuffle) {
    // The shuffle mask is undefined except for 1 lane that is being translated
    // to the cheap extraction lane. Example:
    // ShufMask = { 2, undef, undef, undef }
    uint64_t SplatIndex = ConvertToShuffle == Ext0 ? C0 : C1;
    uint64_t CheapExtIndex = ConvertToShuffle == Ext0 ? C1 : C0;
    Type *VecTy = V0->getType();
    Type *I32Ty = IntegerType::getInt32Ty(I.getContext());
    UndefValue *Undef = UndefValue::get(I32Ty);
    SmallVector<Constant *, 32> ShufMask(VecTy->getVectorNumElements(), Undef);
    ShufMask[CheapExtIndex] = ConstantInt::get(I32Ty, SplatIndex);
    IRBuilder<> Builder(ConvertToShuffle);

    // extelt X, C --> extelt (splat X), C'
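    // Illustrative IR (assuming a <4 x i32> source, C = 2, C' = 0):
    //   %s = shufflevector <4 x i32> %x, <4 x i32> undef,
    //                      <4 x i32> <i32 2, i32 undef, i32 undef, i32 undef>
    //   %e = extractelement <4 x i32> %s, i32 0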
    Value *Shuf = Builder.CreateShuffleVector(ConvertToShuffle->getOperand(0),
                                              UndefValue::get(VecTy),
                                              ConstantVector::get(ShufMask));
    Value *NewExt = Builder.CreateExtractElement(Shuf, CheapExtIndex);
    if (ConvertToShuffle == Ext0)
      Ext0 = cast<Instruction>(NewExt);
    else
      Ext1 = cast<Instruction>(NewExt);
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I, TTI);
  else
    foldExtExtBinop(Ext0, Ext1, I, TTI);

  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    const DominatorTree &DT) {
  if (DisableVectorCombine)
    return false;

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions inside this loop; that would invalidate the
    // iterator. Walk the block backwards for efficiency: we are matching a
    // chain of use->defs, so we are more likely to succeed by starting from
    // the bottom.
    // TODO: It could be more efficient to remove dead instructions
    //       iteratively in this loop rather than waiting until the end.
    for (Instruction &I : make_range(BB.rbegin(), BB.rend()))
      MadeChange |= foldExtractExtract(I, TTI);
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}