1a17f03bdSSanjay Patel //===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2a17f03bdSSanjay Patel //
3a17f03bdSSanjay Patel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a17f03bdSSanjay Patel // See https://llvm.org/LICENSE.txt for license information.
5a17f03bdSSanjay Patel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a17f03bdSSanjay Patel //
7a17f03bdSSanjay Patel //===----------------------------------------------------------------------===//
8a17f03bdSSanjay Patel //
9a17f03bdSSanjay Patel // This pass optimizes scalar/vector interactions using target cost models. The
10a17f03bdSSanjay Patel // transforms implemented here may not fit in traditional loop-based or SLP
11a17f03bdSSanjay Patel // vectorization passes.
12a17f03bdSSanjay Patel //
13a17f03bdSSanjay Patel //===----------------------------------------------------------------------===//
14a17f03bdSSanjay Patel 
15a17f03bdSSanjay Patel #include "llvm/Transforms/Vectorize/VectorCombine.h"
16a17f03bdSSanjay Patel #include "llvm/ADT/Statistic.h"
17a17f03bdSSanjay Patel #include "llvm/Analysis/GlobalsModRef.h"
18a17f03bdSSanjay Patel #include "llvm/Analysis/TargetTransformInfo.h"
1919b62b79SSanjay Patel #include "llvm/Analysis/ValueTracking.h"
20a17f03bdSSanjay Patel #include "llvm/IR/Dominators.h"
21a17f03bdSSanjay Patel #include "llvm/IR/Function.h"
22a17f03bdSSanjay Patel #include "llvm/IR/IRBuilder.h"
23a17f03bdSSanjay Patel #include "llvm/IR/PatternMatch.h"
24a17f03bdSSanjay Patel #include "llvm/InitializePasses.h"
25a17f03bdSSanjay Patel #include "llvm/Pass.h"
26a17f03bdSSanjay Patel #include "llvm/Transforms/Vectorize.h"
27a17f03bdSSanjay Patel #include "llvm/Transforms/Utils/Local.h"
28a17f03bdSSanjay Patel 
29a17f03bdSSanjay Patel using namespace llvm;
30a17f03bdSSanjay Patel using namespace llvm::PatternMatch;
31a17f03bdSSanjay Patel 
32a17f03bdSSanjay Patel #define DEBUG_TYPE "vector-combine"
33a17f03bdSSanjay Patel STATISTIC(NumVecCmp, "Number of vector compares formed");
3419b62b79SSanjay Patel STATISTIC(NumVecBO, "Number of vector binops formed");
35a17f03bdSSanjay Patel 
36*34e34855SSanjay Patel /// Compare the relative costs of extracts followed by scalar operation vs.
37*34e34855SSanjay Patel /// vector operation followed by extract:
38*34e34855SSanjay Patel /// opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
39*34e34855SSanjay Patel /// Unless the vector op is much more expensive than the scalar op, this
40*34e34855SSanjay Patel /// eliminates an extract.
41*34e34855SSanjay Patel static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
42*34e34855SSanjay Patel                                   unsigned Opcode,
43*34e34855SSanjay Patel                                   const TargetTransformInfo &TTI) {
44*34e34855SSanjay Patel   assert(Ext0->getOperand(1) == Ext1->getOperand(1) &&
45*34e34855SSanjay Patel          isa<ConstantInt>(Ext0->getOperand(1)) &&
46*34e34855SSanjay Patel          "Expected same constant extract index");
47*34e34855SSanjay Patel 
48*34e34855SSanjay Patel   Type *ScalarTy = Ext0->getType();
49*34e34855SSanjay Patel   Type *VecTy = Ext0->getOperand(0)->getType();
50*34e34855SSanjay Patel   int ScalarOpCost, VectorOpCost;
51*34e34855SSanjay Patel 
52*34e34855SSanjay Patel   // Get cost estimates for scalar and vector versions of the operation.
53*34e34855SSanjay Patel   bool IsBinOp = Instruction::isBinaryOp(Opcode);
54*34e34855SSanjay Patel   if (IsBinOp) {
55*34e34855SSanjay Patel     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
56*34e34855SSanjay Patel     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
57*34e34855SSanjay Patel   } else {
58*34e34855SSanjay Patel     assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
59*34e34855SSanjay Patel            "Expected a compare");
60*34e34855SSanjay Patel     ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
61*34e34855SSanjay Patel                                           CmpInst::makeCmpResultType(ScalarTy));
62*34e34855SSanjay Patel     VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
63*34e34855SSanjay Patel                                           CmpInst::makeCmpResultType(VecTy));
64*34e34855SSanjay Patel   }
65*34e34855SSanjay Patel 
66*34e34855SSanjay Patel   // Get cost estimate for the extract element. This cost will factor into
67*34e34855SSanjay Patel   // both sequences.
68*34e34855SSanjay Patel   unsigned ExtIndex = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
69*34e34855SSanjay Patel   int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement,
70*34e34855SSanjay Patel                                            VecTy, ExtIndex);
71*34e34855SSanjay Patel 
72*34e34855SSanjay Patel   // Extra uses of the extracts mean that we include those costs in the
73*34e34855SSanjay Patel   // vector total because those instructions will not be eliminated.
74*34e34855SSanjay Patel   int ScalarCost, VectorCost;
75*34e34855SSanjay Patel   if (Ext0->getOperand(0) == Ext1->getOperand(0)) {
76*34e34855SSanjay Patel     // Handle a special case. If the 2 operands are identical, adjust the
77*34e34855SSanjay Patel     // formulas to account for that. The extra use charge allows for either the
78*34e34855SSanjay Patel     // CSE'd pattern or an unoptimized form with identical values:
79*34e34855SSanjay Patel     // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
80*34e34855SSanjay Patel     bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
81*34e34855SSanjay Patel                                   : !Ext0->hasOneUse() || !Ext1->hasOneUse();
82*34e34855SSanjay Patel     ScalarCost = ExtractCost + ScalarOpCost;
83*34e34855SSanjay Patel     VectorCost = VectorOpCost + ExtractCost + HasUseTax * ExtractCost;
84*34e34855SSanjay Patel   } else {
85*34e34855SSanjay Patel     // Handle the general case. Each extract is actually a different value:
86*34e34855SSanjay Patel     // opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
87*34e34855SSanjay Patel     ScalarCost = 2 * ExtractCost + ScalarOpCost;
88*34e34855SSanjay Patel     VectorCost = VectorOpCost + ExtractCost +
89*34e34855SSanjay Patel                  !Ext0->hasOneUse() * ExtractCost +
90*34e34855SSanjay Patel                  !Ext1->hasOneUse() * ExtractCost;
91*34e34855SSanjay Patel   }
92*34e34855SSanjay Patel   // TODO: The cost comparison should not differ based on opcode. Either we
93*34e34855SSanjay Patel   //       want to be uniformly more or less aggressive in deciding if a vector
94*34e34855SSanjay Patel   //       operation should replace the scalar operation.
95*34e34855SSanjay Patel   return IsBinOp ? ScalarCost <= VectorCost : ScalarCost < VectorCost;
96*34e34855SSanjay Patel }
97*34e34855SSanjay Patel 
98fc445589SSanjay Patel /// Try to reduce extract element costs by converting scalar compares to vector
99fc445589SSanjay Patel /// compares followed by extract.
100fc445589SSanjay Patel /// cmp (ext0 V0, C0), (ext1 V1, C1)
101fc445589SSanjay Patel static bool foldExtExtCmp(Instruction *Ext0, Value *V0, uint64_t C0,
102fc445589SSanjay Patel                           Instruction *Ext1, Value *V1, uint64_t C1,
103fc445589SSanjay Patel                           Instruction &I, const TargetTransformInfo &TTI) {
104fc445589SSanjay Patel   assert(isa<CmpInst>(&I) && "Expected a compare");
105a17f03bdSSanjay Patel 
106fc445589SSanjay Patel   // TODO: Handle C0 != C1 by shuffling 1 of the operands.
107fc445589SSanjay Patel   if (C0 != C1)
108fc445589SSanjay Patel     return false;
109fc445589SSanjay Patel 
110*34e34855SSanjay Patel   if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI))
111a17f03bdSSanjay Patel     return false;
112a17f03bdSSanjay Patel 
113a17f03bdSSanjay Patel   // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
114a17f03bdSSanjay Patel   ++NumVecCmp;
115a17f03bdSSanjay Patel   IRBuilder<> Builder(&I);
116fc445589SSanjay Patel   CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
117*34e34855SSanjay Patel   Value *VecCmp =
118*34e34855SSanjay Patel       Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1)
119a17f03bdSSanjay Patel                                            : Builder.CreateICmp(Pred, V0, V1);
120fc445589SSanjay Patel   Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
121fc445589SSanjay Patel   I.replaceAllUsesWith(Extract);
122a17f03bdSSanjay Patel   return true;
123a17f03bdSSanjay Patel }
124a17f03bdSSanjay Patel 
12519b62b79SSanjay Patel /// Try to reduce extract element costs by converting scalar binops to vector
12619b62b79SSanjay Patel /// binops followed by extract.
127fc445589SSanjay Patel /// bo (ext0 V0, C0), (ext1 V1, C1)
128fc445589SSanjay Patel static bool foldExtExtBinop(Instruction *Ext0, Value *V0, uint64_t C0,
129fc445589SSanjay Patel                             Instruction *Ext1, Value *V1, uint64_t C1,
130fc445589SSanjay Patel                             Instruction &I, const TargetTransformInfo &TTI) {
131fc445589SSanjay Patel   assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
13219b62b79SSanjay Patel 
133*34e34855SSanjay Patel   // TODO: Handle C0 != C1 by shuffling 1 of the operands.
134*34e34855SSanjay Patel   if (C0 != C1)
13519b62b79SSanjay Patel     return false;
13619b62b79SSanjay Patel 
137*34e34855SSanjay Patel   if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI))
138*34e34855SSanjay Patel     return false;
139*34e34855SSanjay Patel 
140*34e34855SSanjay Patel   // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
14119b62b79SSanjay Patel   ++NumVecBO;
14219b62b79SSanjay Patel   IRBuilder<> Builder(&I);
143*34e34855SSanjay Patel   Value *NewBO =
144*34e34855SSanjay Patel       Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);
14519b62b79SSanjay Patel   if (auto *VecBOInst = dyn_cast<Instruction>(NewBO)) {
14619b62b79SSanjay Patel     // All IR flags are safe to back-propagate because any potential poison
14719b62b79SSanjay Patel     // created in unused vector elements is discarded by the extract.
14819b62b79SSanjay Patel     VecBOInst->copyIRFlags(&I);
14919b62b79SSanjay Patel   }
15019b62b79SSanjay Patel   Value *Extract = Builder.CreateExtractElement(NewBO, Ext0->getOperand(1));
15119b62b79SSanjay Patel   I.replaceAllUsesWith(Extract);
15219b62b79SSanjay Patel   return true;
15319b62b79SSanjay Patel }
15419b62b79SSanjay Patel 
155fc445589SSanjay Patel /// Match an instruction with extracted vector operands.
156fc445589SSanjay Patel static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
157fc445589SSanjay Patel   Instruction *Ext0, *Ext1;
158fc445589SSanjay Patel   CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
159fc445589SSanjay Patel   if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) &&
160fc445589SSanjay Patel       !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
161fc445589SSanjay Patel     return false;
162fc445589SSanjay Patel 
163fc445589SSanjay Patel   Value *V0, *V1;
164fc445589SSanjay Patel   uint64_t C0, C1;
165fc445589SSanjay Patel   if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) ||
166fc445589SSanjay Patel       !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) ||
167fc445589SSanjay Patel       V0->getType() != V1->getType())
168fc445589SSanjay Patel     return false;
169fc445589SSanjay Patel 
170fc445589SSanjay Patel   if (Pred != CmpInst::BAD_ICMP_PREDICATE)
171fc445589SSanjay Patel     return foldExtExtCmp(Ext0, V0, C0, Ext1, V1, C1, I, TTI);
172fc445589SSanjay Patel 
173fc445589SSanjay Patel   // It is not safe to transform things like div, urem, etc. because we may
174fc445589SSanjay Patel   // create undefined behavior when executing those on unknown vector elements.
175fc445589SSanjay Patel   if (isSafeToSpeculativelyExecute(&I))
176fc445589SSanjay Patel     return foldExtExtBinop(Ext0, V0, C0, Ext1, V1, C1, I, TTI);
177fc445589SSanjay Patel 
178fc445589SSanjay Patel   return false;
179fc445589SSanjay Patel }
180fc445589SSanjay Patel 
181a17f03bdSSanjay Patel /// This is the entry point for all transforms. Pass manager differences are
182a17f03bdSSanjay Patel /// handled in the callers of this function.
183a17f03bdSSanjay Patel static bool runImpl(Function &F, const TargetTransformInfo &TTI,
184a17f03bdSSanjay Patel                     const DominatorTree &DT) {
185a17f03bdSSanjay Patel   bool MadeChange = false;
186a17f03bdSSanjay Patel   for (BasicBlock &BB : F) {
187a17f03bdSSanjay Patel     // Ignore unreachable basic blocks.
188a17f03bdSSanjay Patel     if (!DT.isReachableFromEntry(&BB))
189a17f03bdSSanjay Patel       continue;
190a17f03bdSSanjay Patel     // Do not delete instructions under here and invalidate the iterator.
191a17f03bdSSanjay Patel     // Walk the block backwards for efficiency. We're matching a chain of
192a17f03bdSSanjay Patel     // use->defs, so we're more likely to succeed by starting from the bottom.
193a17f03bdSSanjay Patel     // TODO: It could be more efficient to remove dead instructions
194a17f03bdSSanjay Patel     //       iteratively in this loop rather than waiting until the end.
195fc445589SSanjay Patel     for (Instruction &I : make_range(BB.rbegin(), BB.rend()))
196fc445589SSanjay Patel       MadeChange |= foldExtractExtract(I, TTI);
197a17f03bdSSanjay Patel   }
198a17f03bdSSanjay Patel 
199a17f03bdSSanjay Patel   // We're done with transforms, so remove dead instructions.
200a17f03bdSSanjay Patel   if (MadeChange)
201a17f03bdSSanjay Patel     for (BasicBlock &BB : F)
202a17f03bdSSanjay Patel       SimplifyInstructionsInBlock(&BB);
203a17f03bdSSanjay Patel 
204a17f03bdSSanjay Patel   return MadeChange;
205a17f03bdSSanjay Patel }
206a17f03bdSSanjay Patel 
207a17f03bdSSanjay Patel // Pass manager boilerplate below here.
208a17f03bdSSanjay Patel 
209a17f03bdSSanjay Patel namespace {
210a17f03bdSSanjay Patel class VectorCombineLegacyPass : public FunctionPass {
211a17f03bdSSanjay Patel public:
212a17f03bdSSanjay Patel   static char ID;
213a17f03bdSSanjay Patel   VectorCombineLegacyPass() : FunctionPass(ID) {
214a17f03bdSSanjay Patel     initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
215a17f03bdSSanjay Patel   }
216a17f03bdSSanjay Patel 
217a17f03bdSSanjay Patel   void getAnalysisUsage(AnalysisUsage &AU) const override {
218a17f03bdSSanjay Patel     AU.addRequired<DominatorTreeWrapperPass>();
219a17f03bdSSanjay Patel     AU.addRequired<TargetTransformInfoWrapperPass>();
220a17f03bdSSanjay Patel     AU.setPreservesCFG();
221a17f03bdSSanjay Patel     AU.addPreserved<DominatorTreeWrapperPass>();
222a17f03bdSSanjay Patel     AU.addPreserved<GlobalsAAWrapperPass>();
223a17f03bdSSanjay Patel     FunctionPass::getAnalysisUsage(AU);
224a17f03bdSSanjay Patel   }
225a17f03bdSSanjay Patel 
226a17f03bdSSanjay Patel   bool runOnFunction(Function &F) override {
227a17f03bdSSanjay Patel     if (skipFunction(F))
228a17f03bdSSanjay Patel       return false;
229a17f03bdSSanjay Patel     auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
230a17f03bdSSanjay Patel     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
231a17f03bdSSanjay Patel     return runImpl(F, TTI, DT);
232a17f03bdSSanjay Patel   }
233a17f03bdSSanjay Patel };
234a17f03bdSSanjay Patel } // namespace
235a17f03bdSSanjay Patel 
236a17f03bdSSanjay Patel char VectorCombineLegacyPass::ID = 0;
237a17f03bdSSanjay Patel INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
238a17f03bdSSanjay Patel                       "Optimize scalar/vector ops", false,
239a17f03bdSSanjay Patel                       false)
240a17f03bdSSanjay Patel INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
241a17f03bdSSanjay Patel INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
242a17f03bdSSanjay Patel                     "Optimize scalar/vector ops", false, false)
243a17f03bdSSanjay Patel Pass *llvm::createVectorCombinePass() {
244a17f03bdSSanjay Patel   return new VectorCombineLegacyPass();
245a17f03bdSSanjay Patel }
246a17f03bdSSanjay Patel 
247a17f03bdSSanjay Patel PreservedAnalyses VectorCombinePass::run(Function &F,
248a17f03bdSSanjay Patel                                          FunctionAnalysisManager &FAM) {
249a17f03bdSSanjay Patel   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
250a17f03bdSSanjay Patel   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
251a17f03bdSSanjay Patel   if (!runImpl(F, TTI, DT))
252a17f03bdSSanjay Patel     return PreservedAnalyses::all();
253a17f03bdSSanjay Patel   PreservedAnalyses PA;
254a17f03bdSSanjay Patel   PA.preserveSet<CFGAnalyses>();
255a17f03bdSSanjay Patel   PA.preserve<GlobalsAA>();
256a17f03bdSSanjay Patel   return PA;
257a17f03bdSSanjay Patel }
258