//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C0), (ext1 V1, C1)
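/// For example, with C0 == C1 and illustrative <4 x i32> operands:
///   %e0 = extractelement <4 x i32> %v0, i32 1
///   %e1 = extractelement <4 x i32> %v1, i32 1
///   %c  = icmp sgt i32 %e0, %e1
/// -->
///   %vc = icmp sgt <4 x i32> %v0, %v1
///   %c  = extractelement <4 x i1> %vc, i32 1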
static bool foldExtExtCmp(Instruction *Ext0, Value *V0, uint64_t C0,
                          Instruction *Ext1, Value *V1, uint64_t C1,
                          Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<CmpInst>(&I) && "Expected a compare");
  Type *ScalarTy = Ext0->getType();
  Type *VecTy = V0->getType();
  bool IsFP = ScalarTy->isFloatingPointTy();
  unsigned CmpOpcode = IsFP ? Instruction::FCmp : Instruction::ICmp;

  // TODO: Handle C0 != C1 by shuffling 1 of the operands.
  if (C0 != C1)
    return false;

  // Check if the existing scalar code or the vector alternative is cheaper.
  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  // ((2 * extract) + scalar cmp) < (vector cmp + extract) ?
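  // For instance, if (hypothetically) every cost here is 1 and both extracts
  // have a single use, then ScalarCost = 2 + 1 = 3 and VecCost = 1 + 1 = 2,
  // so the vector form is chosen.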
  int ExtractCost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, C0);
  int ScalarCmpCost = TTI.getCmpSelInstrCost(CmpOpcode, ScalarTy, I.getType());
  int VecCmpCost = TTI.getCmpSelInstrCost(CmpOpcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));

  int ScalarCost = 2 * ExtractCost + ScalarCmpCost;
  int VecCost = VecCmpCost + ExtractCost +
                !Ext0->hasOneUse() * ExtractCost +
                !Ext1->hasOneUse() * ExtractCost;
  if (ScalarCost < VecCost)
    return false;

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  IRBuilder<> Builder(&I);
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *VecCmp = IsFP ? Builder.CreateFCmp(Pred, V0, V1)
                       : Builder.CreateICmp(Pred, V0, V1);
  Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
  I.replaceAllUsesWith(Extract);
  return true;
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C0), (ext1 V1, C1)
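/// For example, with C0 == C1 and illustrative <2 x double> operands:
///   %e0 = extractelement <2 x double> %v0, i32 0
///   %e1 = extractelement <2 x double> %v1, i32 0
///   %r  = fmul double %e0, %e1
/// -->
///   %vr = fmul <2 x double> %v0, %v1
///   %r  = extractelement <2 x double> %vr, i32 0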
static bool foldExtExtBinop(Instruction *Ext0, Value *V0, uint64_t C0,
                            Instruction *Ext1, Value *V1, uint64_t C1,
                            Instruction &I, const TargetTransformInfo &TTI) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
  Type *ScalarTy = Ext0->getType();
  Type *VecTy = V0->getType();
  Instruction::BinaryOps BOpcode = cast<BinaryOperator>(I).getOpcode();

  // Check if using a vector binop would be cheaper.
  int ScalarBOCost = TTI.getArithmeticInstrCost(BOpcode, ScalarTy);
  int VecBOCost = TTI.getArithmeticInstrCost(BOpcode, VecTy);
  int Extract0Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                            VecTy, C0);

  // Handle a special case - if the extract indexes are the same, the
  // replacement sequence does not require a shuffle. Unless the vector binop is
  // much more expensive than the scalar binop, this eliminates an extract.
  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  if (C0 == C1) {
    assert(Extract0Cost ==
               TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, C1) &&
           "Different costs for same extract?");
    int ExtractCost = Extract0Cost;
    if (V0 != V1) {
      int ScalarCost = ExtractCost + ExtractCost + ScalarBOCost;
      int VecCost = VecBOCost + ExtractCost +
                    !Ext0->hasOneUse() * ExtractCost +
                    !Ext1->hasOneUse() * ExtractCost;
      if (ScalarCost <= VecCost)
        return false;
    } else {
      // Handle an extra-special case. If the 2 binop operands are identical,
      // adjust the formulas to account for that:
      // bo (extelt V, C), (extelt V, C) --> extelt (bo V, V), C
      // The extra use charge allows for either the CSE'd pattern or an
      // unoptimized form with identical values.
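      // For example, with the CSE'd form and an illustrative <4 x float>:
      //   %e = extractelement <4 x float> %v, i32 0
      //   %r = fadd float %e, %e
      // -->
      //   %vr = fadd <4 x float> %v, %v
      //   %r  = extractelement <4 x float> %vr, i32 0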
      bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                    : !Ext0->hasOneUse() || !Ext1->hasOneUse();
      int ScalarCost = ExtractCost + ScalarBOCost;
      int VecCost = VecBOCost + ExtractCost + HasUseTax * ExtractCost;
      if (ScalarCost <= VecCost)
        return false;
    }

    // bo (extelt X, C), (extelt Y, C) --> extelt (bo X, Y), C
    ++NumVecBO;
    IRBuilder<> Builder(&I);
    Value *NewBO = Builder.CreateBinOp(BOpcode, V0, V1);
    if (auto *VecBOInst = dyn_cast<Instruction>(NewBO)) {
      // All IR flags are safe to back-propagate because any potential poison
      // created in unused vector elements is discarded by the extract.
      VecBOInst->copyIRFlags(&I);
    }
    Value *Extract = Builder.CreateExtractElement(NewBO, Ext0->getOperand(1));
    I.replaceAllUsesWith(Extract);
    return true;
  }

  // TODO: Handle C0 != C1 by shuffling 1 of the operands.
  return false;
}

/// Match a compare or binary operator whose operands are both extracted vector
/// elements, and try to convert it to the equivalent vector operation followed
/// by an extract.
static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
  Instruction *Ext0, *Ext1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) &&
      !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) ||
      !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    return foldExtExtCmp(Ext0, V0, C0, Ext1, V1, C1, I, TTI);

  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
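  // For example, a udiv of two extracted elements must not become a vector
  // udiv, because a vector element we were not dividing by might be zero.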
  if (isSafeToSpeculativelyExecute(&I))
    return foldExtExtBinop(Ext0, V0, C0, Ext1, V1, C1, I, TTI);

  return false;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    const DominatorTree &DT) {
  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions under here, as that would invalidate the
    // iterator. Walk the block backwards for efficiency: we're matching a
    // chain of use->defs, so we're more likely to succeed by starting from
    // the bottom.
    // TODO: It could be more efficient to remove dead instructions
    //       iteratively in this loop rather than waiting until the end.
    for (Instruction &I : make_range(BB.rbegin(), BB.rend()))
      MadeChange |= foldExtractExtract(I, TTI);
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)

Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}