//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//
14a17f03bdSSanjay Patel 
15a17f03bdSSanjay Patel #include "llvm/Transforms/Vectorize/VectorCombine.h"
16a17f03bdSSanjay Patel #include "llvm/ADT/Statistic.h"
17a17f03bdSSanjay Patel #include "llvm/Analysis/GlobalsModRef.h"
18a17f03bdSSanjay Patel #include "llvm/Analysis/TargetTransformInfo.h"
1919b62b79SSanjay Patel #include "llvm/Analysis/ValueTracking.h"
20a17f03bdSSanjay Patel #include "llvm/IR/Dominators.h"
21a17f03bdSSanjay Patel #include "llvm/IR/Function.h"
22a17f03bdSSanjay Patel #include "llvm/IR/IRBuilder.h"
23a17f03bdSSanjay Patel #include "llvm/IR/PatternMatch.h"
24a17f03bdSSanjay Patel #include "llvm/InitializePasses.h"
25a17f03bdSSanjay Patel #include "llvm/Pass.h"
26a17f03bdSSanjay Patel #include "llvm/Transforms/Vectorize.h"
27a17f03bdSSanjay Patel #include "llvm/Transforms/Utils/Local.h"
28a17f03bdSSanjay Patel 
29a17f03bdSSanjay Patel using namespace llvm;
30a17f03bdSSanjay Patel using namespace llvm::PatternMatch;
31a17f03bdSSanjay Patel 
32a17f03bdSSanjay Patel #define DEBUG_TYPE "vector-combine"
33a17f03bdSSanjay Patel STATISTIC(NumVecCmp, "Number of vector compares formed");
3419b62b79SSanjay Patel STATISTIC(NumVecBO, "Number of vector binops formed");
35a17f03bdSSanjay Patel 
36a17f03bdSSanjay Patel static bool foldExtractCmp(Instruction &I, const TargetTransformInfo &TTI) {
37a17f03bdSSanjay Patel   // Match a cmp with extracted vector operands.
38a17f03bdSSanjay Patel   CmpInst::Predicate Pred;
39a17f03bdSSanjay Patel   Instruction *Ext0, *Ext1;
40a17f03bdSSanjay Patel   if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))))
41a17f03bdSSanjay Patel     return false;
42a17f03bdSSanjay Patel 
43a17f03bdSSanjay Patel   Value *V0, *V1;
44a17f03bdSSanjay Patel   ConstantInt *C;
45a17f03bdSSanjay Patel   if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C))) ||
46a17f03bdSSanjay Patel       !match(Ext1, m_ExtractElement(m_Value(V1), m_Specific(C))) ||
47a17f03bdSSanjay Patel       V0->getType() != V1->getType())
48a17f03bdSSanjay Patel     return false;
49a17f03bdSSanjay Patel 
50a17f03bdSSanjay Patel   Type *ScalarTy = Ext0->getType();
51a17f03bdSSanjay Patel   Type *VecTy = V0->getType();
52a17f03bdSSanjay Patel   bool IsFP = ScalarTy->isFloatingPointTy();
53a17f03bdSSanjay Patel   unsigned CmpOpcode = IsFP ? Instruction::FCmp : Instruction::ICmp;
54a17f03bdSSanjay Patel 
55a17f03bdSSanjay Patel   // Check if the existing scalar code or the vector alternative is cheaper.
56a17f03bdSSanjay Patel   // Extra uses of the extracts mean that we include those costs in the
57a17f03bdSSanjay Patel   // vector total because those instructions will not be eliminated.
58a17f03bdSSanjay Patel   // ((2 * extract) + scalar cmp) < (vector cmp + extract) ?
59a17f03bdSSanjay Patel   int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement,
60a17f03bdSSanjay Patel                                            VecTy, C->getZExtValue());
61*62dd44d7SSanjay Patel   int ScalarCmpCost = TTI.getCmpSelInstrCost(CmpOpcode, ScalarTy, I.getType());
62*62dd44d7SSanjay Patel   int VecCmpCost = TTI.getCmpSelInstrCost(CmpOpcode, VecTy,
63*62dd44d7SSanjay Patel                                           CmpInst::makeCmpResultType(VecTy));
64a17f03bdSSanjay Patel 
65a17f03bdSSanjay Patel   int ScalarCost = 2 * ExtractCost + ScalarCmpCost;
66a17f03bdSSanjay Patel   int VecCost = VecCmpCost + ExtractCost +
67a17f03bdSSanjay Patel                 !Ext0->hasOneUse() * ExtractCost +
68a17f03bdSSanjay Patel                 !Ext1->hasOneUse() * ExtractCost;
69a17f03bdSSanjay Patel   if (ScalarCost < VecCost)
70a17f03bdSSanjay Patel     return false;
71a17f03bdSSanjay Patel 
72a17f03bdSSanjay Patel   // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
73a17f03bdSSanjay Patel   ++NumVecCmp;
74a17f03bdSSanjay Patel   IRBuilder<> Builder(&I);
75a17f03bdSSanjay Patel   Value *VecCmp = IsFP ? Builder.CreateFCmp(Pred, V0, V1)
76a17f03bdSSanjay Patel                        : Builder.CreateICmp(Pred, V0, V1);
77a17f03bdSSanjay Patel   Value *Ext = Builder.CreateExtractElement(VecCmp, C);
78a17f03bdSSanjay Patel   I.replaceAllUsesWith(Ext);
79a17f03bdSSanjay Patel   return true;
80a17f03bdSSanjay Patel }
81a17f03bdSSanjay Patel 
/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// Pattern: bo (extelt X, C0), (extelt Y, C1) --> extelt (bo X, Y), C
/// (currently only handled when C0 == C1). Returns true if the IR was changed;
/// the dead scalar instructions are left for later cleanup by the caller.
static bool foldExtractBinop(Instruction &I, const TargetTransformInfo &TTI) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  // Match a scalar binop with extracted vector operands:
  // bo (extelt X, C0), (extelt Y, C1)
  Instruction *Ext0, *Ext1;
  if (!match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
    return false;

  // Both operands must be extracts with constant indexes, and the source
  // vectors must have the same type.
  Value *X, *Y;
  uint64_t C0, C1;
  if (!match(Ext0, m_ExtractElement(m_Value(X), m_ConstantInt(C0))) ||
      !match(Ext1, m_ExtractElement(m_Value(Y), m_ConstantInt(C1))) ||
      X->getType() != Y->getType())
    return false;

  // Check if using a vector binop would be cheaper.
  Instruction::BinaryOps BOpcode = cast<BinaryOperator>(I).getOpcode();
  Type *ScalarTy = I.getType();
  Type *VecTy = X->getType();
  int ScalarBOCost = TTI.getArithmeticInstrCost(BOpcode, ScalarTy);
  int VecBOCost = TTI.getArithmeticInstrCost(BOpcode, VecTy);
  int Extract0Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
                                            VecTy, C0);

  // Handle a special case - if the extract indexes are the same, the
  // replacement sequence does not require a shuffle. Unless the vector binop is
  // much more expensive than the scalar binop, this eliminates an extract.
  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  if (C0 == C1) {
    assert(Extract0Cost ==
               TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, C1) &&
           "Different costs for same extract?");
    int ExtractCost = Extract0Cost;
    if (X != Y) {
      // Two distinct source vectors:
      //   scalar form: 2 extracts + scalar binop
      //   vector form: vector binop + 1 extract, plus an extra extract charge
      //   for each original extract that has other users and must remain.
      int ScalarCost = ExtractCost + ExtractCost + ScalarBOCost;
      int VecCost = VecBOCost + ExtractCost +
                    !Ext0->hasOneUse() * ExtractCost +
                    !Ext1->hasOneUse() * ExtractCost;
      // Only transform when the vector form is strictly cheaper.
      if (ScalarCost <= VecCost)
        return false;
    } else {
      // Handle an extra-special case. If the 2 binop operands are identical,
      // adjust the formulas to account for that:
      // bo (extelt X, C), (extelt X, C) --> extelt (bo X, X), C
      // The extra use charge allows for either the CSE'd pattern or an
      // unoptimized form with identical values.
      bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                    : !Ext0->hasOneUse() || !Ext1->hasOneUse();
      int ScalarCost = ExtractCost + ScalarBOCost;
      int VecCost = VecBOCost + ExtractCost + HasUseTax * ExtractCost;
      if (ScalarCost <= VecCost)
        return false;
    }

    // bo (extelt X, C), (extelt Y, C) --> extelt (bo X, Y), C
    ++NumVecBO;
    IRBuilder<> Builder(&I);
    Value *NewBO = Builder.CreateBinOp(BOpcode, X, Y);
    if (auto *VecBOInst = dyn_cast<Instruction>(NewBO)) {
      // All IR flags are safe to back-propagate because any potential poison
      // created in unused vector elements is discarded by the extract.
      VecBOInst->copyIRFlags(&I);
    }
    // Reuse Ext0's index operand (a constant equal to C0/C1).
    Value *Extract = Builder.CreateExtractElement(NewBO, Ext0->getOperand(1));
    I.replaceAllUsesWith(Extract);
    return true;
  }

  // TODO: Handle C0 != C1 by shuffling 1 of the operands.
  return false;
}
16019b62b79SSanjay Patel 
161a17f03bdSSanjay Patel /// This is the entry point for all transforms. Pass manager differences are
162a17f03bdSSanjay Patel /// handled in the callers of this function.
163a17f03bdSSanjay Patel static bool runImpl(Function &F, const TargetTransformInfo &TTI,
164a17f03bdSSanjay Patel                     const DominatorTree &DT) {
165a17f03bdSSanjay Patel   bool MadeChange = false;
166a17f03bdSSanjay Patel   for (BasicBlock &BB : F) {
167a17f03bdSSanjay Patel     // Ignore unreachable basic blocks.
168a17f03bdSSanjay Patel     if (!DT.isReachableFromEntry(&BB))
169a17f03bdSSanjay Patel       continue;
170a17f03bdSSanjay Patel     // Do not delete instructions under here and invalidate the iterator.
171a17f03bdSSanjay Patel     // Walk the block backwards for efficiency. We're matching a chain of
172a17f03bdSSanjay Patel     // use->defs, so we're more likely to succeed by starting from the bottom.
173a17f03bdSSanjay Patel     // TODO: It could be more efficient to remove dead instructions
174a17f03bdSSanjay Patel     //       iteratively in this loop rather than waiting until the end.
175a17f03bdSSanjay Patel     for (Instruction &I : make_range(BB.rbegin(), BB.rend())) {
176a17f03bdSSanjay Patel       MadeChange |= foldExtractCmp(I, TTI);
17719b62b79SSanjay Patel       MadeChange |= foldExtractBinop(I, TTI);
178a17f03bdSSanjay Patel     }
179a17f03bdSSanjay Patel   }
180a17f03bdSSanjay Patel 
181a17f03bdSSanjay Patel   // We're done with transforms, so remove dead instructions.
182a17f03bdSSanjay Patel   if (MadeChange)
183a17f03bdSSanjay Patel     for (BasicBlock &BB : F)
184a17f03bdSSanjay Patel       SimplifyInstructionsInBlock(&BB);
185a17f03bdSSanjay Patel 
186a17f03bdSSanjay Patel   return MadeChange;
187a17f03bdSSanjay Patel }
188a17f03bdSSanjay Patel 
189a17f03bdSSanjay Patel // Pass manager boilerplate below here.
190a17f03bdSSanjay Patel 
191a17f03bdSSanjay Patel namespace {
192a17f03bdSSanjay Patel class VectorCombineLegacyPass : public FunctionPass {
193a17f03bdSSanjay Patel public:
194a17f03bdSSanjay Patel   static char ID;
195a17f03bdSSanjay Patel   VectorCombineLegacyPass() : FunctionPass(ID) {
196a17f03bdSSanjay Patel     initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
197a17f03bdSSanjay Patel   }
198a17f03bdSSanjay Patel 
199a17f03bdSSanjay Patel   void getAnalysisUsage(AnalysisUsage &AU) const override {
200a17f03bdSSanjay Patel     AU.addRequired<DominatorTreeWrapperPass>();
201a17f03bdSSanjay Patel     AU.addRequired<TargetTransformInfoWrapperPass>();
202a17f03bdSSanjay Patel     AU.setPreservesCFG();
203a17f03bdSSanjay Patel     AU.addPreserved<DominatorTreeWrapperPass>();
204a17f03bdSSanjay Patel     AU.addPreserved<GlobalsAAWrapperPass>();
205a17f03bdSSanjay Patel     FunctionPass::getAnalysisUsage(AU);
206a17f03bdSSanjay Patel   }
207a17f03bdSSanjay Patel 
208a17f03bdSSanjay Patel   bool runOnFunction(Function &F) override {
209a17f03bdSSanjay Patel     if (skipFunction(F))
210a17f03bdSSanjay Patel       return false;
211a17f03bdSSanjay Patel     auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
212a17f03bdSSanjay Patel     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
213a17f03bdSSanjay Patel     return runImpl(F, TTI, DT);
214a17f03bdSSanjay Patel   }
215a17f03bdSSanjay Patel };
216a17f03bdSSanjay Patel } // namespace
217a17f03bdSSanjay Patel 
char VectorCombineLegacyPass::ID = 0;
// Register the pass and its dominator-tree dependency with the legacy pass
// manager ("-vector-combine" on the opt command line).
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
// Factory used by the legacy pipeline builder; caller owns the returned pass.
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}
228a17f03bdSSanjay Patel 
229a17f03bdSSanjay Patel PreservedAnalyses VectorCombinePass::run(Function &F,
230a17f03bdSSanjay Patel                                          FunctionAnalysisManager &FAM) {
231a17f03bdSSanjay Patel   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
232a17f03bdSSanjay Patel   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
233a17f03bdSSanjay Patel   if (!runImpl(F, TTI, DT))
234a17f03bdSSanjay Patel     return PreservedAnalyses::all();
235a17f03bdSSanjay Patel   PreservedAnalyses PA;
236a17f03bdSSanjay Patel   PA.preserveSet<CFGAnalyses>();
237a17f03bdSSanjay Patel   PA.preserve<GlobalsAA>();
238a17f03bdSSanjay Patel   return PA;
239a17f03bdSSanjay Patel }
240