//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//
14a17f03bdSSanjay Patel 
15a17f03bdSSanjay Patel #include "llvm/Transforms/Vectorize/VectorCombine.h"
16a17f03bdSSanjay Patel #include "llvm/ADT/Statistic.h"
17a17f03bdSSanjay Patel #include "llvm/Analysis/GlobalsModRef.h"
18a17f03bdSSanjay Patel #include "llvm/Analysis/TargetTransformInfo.h"
1919b62b79SSanjay Patel #include "llvm/Analysis/ValueTracking.h"
20a17f03bdSSanjay Patel #include "llvm/IR/Dominators.h"
21a17f03bdSSanjay Patel #include "llvm/IR/Function.h"
22a17f03bdSSanjay Patel #include "llvm/IR/IRBuilder.h"
23a17f03bdSSanjay Patel #include "llvm/IR/PatternMatch.h"
24a17f03bdSSanjay Patel #include "llvm/InitializePasses.h"
25a17f03bdSSanjay Patel #include "llvm/Pass.h"
26a17f03bdSSanjay Patel #include "llvm/Transforms/Vectorize.h"
27a17f03bdSSanjay Patel #include "llvm/Transforms/Utils/Local.h"
28a17f03bdSSanjay Patel 
29a17f03bdSSanjay Patel using namespace llvm;
30a17f03bdSSanjay Patel using namespace llvm::PatternMatch;
31a17f03bdSSanjay Patel 
32a17f03bdSSanjay Patel #define DEBUG_TYPE "vector-combine"
33a17f03bdSSanjay Patel STATISTIC(NumVecCmp, "Number of vector compares formed");
3419b62b79SSanjay Patel STATISTIC(NumVecBO, "Number of vector binops formed");
35a17f03bdSSanjay Patel 
3634e34855SSanjay Patel /// Compare the relative costs of extracts followed by scalar operation vs.
3734e34855SSanjay Patel /// vector operation followed by extract:
3834e34855SSanjay Patel /// opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
3934e34855SSanjay Patel /// Unless the vector op is much more expensive than the scalar op, this
4034e34855SSanjay Patel /// eliminates an extract.
4134e34855SSanjay Patel static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
4234e34855SSanjay Patel                                   unsigned Opcode,
4334e34855SSanjay Patel                                   const TargetTransformInfo &TTI) {
4434e34855SSanjay Patel   assert(Ext0->getOperand(1) == Ext1->getOperand(1) &&
4534e34855SSanjay Patel          isa<ConstantInt>(Ext0->getOperand(1)) &&
4634e34855SSanjay Patel          "Expected same constant extract index");
4734e34855SSanjay Patel 
4834e34855SSanjay Patel   Type *ScalarTy = Ext0->getType();
4934e34855SSanjay Patel   Type *VecTy = Ext0->getOperand(0)->getType();
5034e34855SSanjay Patel   int ScalarOpCost, VectorOpCost;
5134e34855SSanjay Patel 
5234e34855SSanjay Patel   // Get cost estimates for scalar and vector versions of the operation.
5334e34855SSanjay Patel   bool IsBinOp = Instruction::isBinaryOp(Opcode);
5434e34855SSanjay Patel   if (IsBinOp) {
5534e34855SSanjay Patel     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
5634e34855SSanjay Patel     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
5734e34855SSanjay Patel   } else {
5834e34855SSanjay Patel     assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
5934e34855SSanjay Patel            "Expected a compare");
6034e34855SSanjay Patel     ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
6134e34855SSanjay Patel                                           CmpInst::makeCmpResultType(ScalarTy));
6234e34855SSanjay Patel     VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
6334e34855SSanjay Patel                                           CmpInst::makeCmpResultType(VecTy));
6434e34855SSanjay Patel   }
6534e34855SSanjay Patel 
6634e34855SSanjay Patel   // Get cost estimate for the extract element. This cost will factor into
6734e34855SSanjay Patel   // both sequences.
6834e34855SSanjay Patel   unsigned ExtIndex = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
6934e34855SSanjay Patel   int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement,
7034e34855SSanjay Patel                                            VecTy, ExtIndex);
7134e34855SSanjay Patel 
7234e34855SSanjay Patel   // Extra uses of the extracts mean that we include those costs in the
7334e34855SSanjay Patel   // vector total because those instructions will not be eliminated.
74e9c79a7aSSanjay Patel   int OldCost, NewCost;
7534e34855SSanjay Patel   if (Ext0->getOperand(0) == Ext1->getOperand(0)) {
7634e34855SSanjay Patel     // Handle a special case. If the 2 operands are identical, adjust the
7734e34855SSanjay Patel     // formulas to account for that. The extra use charge allows for either the
7834e34855SSanjay Patel     // CSE'd pattern or an unoptimized form with identical values:
7934e34855SSanjay Patel     // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
8034e34855SSanjay Patel     bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
8134e34855SSanjay Patel                                   : !Ext0->hasOneUse() || !Ext1->hasOneUse();
82e9c79a7aSSanjay Patel     OldCost = ExtractCost + ScalarOpCost;
83e9c79a7aSSanjay Patel     NewCost = VectorOpCost + ExtractCost + HasUseTax * ExtractCost;
8434e34855SSanjay Patel   } else {
8534e34855SSanjay Patel     // Handle the general case. Each extract is actually a different value:
8634e34855SSanjay Patel     // opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
87e9c79a7aSSanjay Patel     OldCost = 2 * ExtractCost + ScalarOpCost;
88e9c79a7aSSanjay Patel     NewCost = VectorOpCost + ExtractCost + !Ext0->hasOneUse() * ExtractCost +
8934e34855SSanjay Patel               !Ext1->hasOneUse() * ExtractCost;
9034e34855SSanjay Patel   }
91*10ea01d8SSanjay Patel   // Aggressively form a vector op if the cost is equal because the transform
92*10ea01d8SSanjay Patel   // may enable further optimization.
93*10ea01d8SSanjay Patel   // Codegen can reverse this transform (scalarize) if it was not profitable.
94*10ea01d8SSanjay Patel   return OldCost < NewCost;
9534e34855SSanjay Patel }
9634e34855SSanjay Patel 
97fc445589SSanjay Patel /// Try to reduce extract element costs by converting scalar compares to vector
98fc445589SSanjay Patel /// compares followed by extract.
99e9c79a7aSSanjay Patel /// cmp (ext0 V0, C), (ext1 V1, C)
100e9c79a7aSSanjay Patel static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1,
101fc445589SSanjay Patel                           Instruction &I, const TargetTransformInfo &TTI) {
102fc445589SSanjay Patel   assert(isa<CmpInst>(&I) && "Expected a compare");
103a17f03bdSSanjay Patel 
104a17f03bdSSanjay Patel   // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
105a17f03bdSSanjay Patel   ++NumVecCmp;
106a17f03bdSSanjay Patel   IRBuilder<> Builder(&I);
107fc445589SSanjay Patel   CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
108e9c79a7aSSanjay Patel   Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
10934e34855SSanjay Patel   Value *VecCmp =
11034e34855SSanjay Patel       Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1)
111a17f03bdSSanjay Patel                                            : Builder.CreateICmp(Pred, V0, V1);
112fc445589SSanjay Patel   Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
113fc445589SSanjay Patel   I.replaceAllUsesWith(Extract);
114a17f03bdSSanjay Patel }
115a17f03bdSSanjay Patel 
11619b62b79SSanjay Patel /// Try to reduce extract element costs by converting scalar binops to vector
11719b62b79SSanjay Patel /// binops followed by extract.
118e9c79a7aSSanjay Patel /// bo (ext0 V0, C), (ext1 V1, C)
119e9c79a7aSSanjay Patel static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1,
120fc445589SSanjay Patel                             Instruction &I, const TargetTransformInfo &TTI) {
121fc445589SSanjay Patel   assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
12219b62b79SSanjay Patel 
12334e34855SSanjay Patel   // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
12419b62b79SSanjay Patel   ++NumVecBO;
12519b62b79SSanjay Patel   IRBuilder<> Builder(&I);
126e9c79a7aSSanjay Patel   Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
127e9c79a7aSSanjay Patel   Value *VecBO =
12834e34855SSanjay Patel       Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);
129e9c79a7aSSanjay Patel 
13019b62b79SSanjay Patel   // All IR flags are safe to back-propagate because any potential poison
13119b62b79SSanjay Patel   // created in unused vector elements is discarded by the extract.
132e9c79a7aSSanjay Patel   if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
13319b62b79SSanjay Patel     VecBOInst->copyIRFlags(&I);
134e9c79a7aSSanjay Patel 
135e9c79a7aSSanjay Patel   Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1));
13619b62b79SSanjay Patel   I.replaceAllUsesWith(Extract);
13719b62b79SSanjay Patel }
13819b62b79SSanjay Patel 
139fc445589SSanjay Patel /// Match an instruction with extracted vector operands.
140fc445589SSanjay Patel static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
141e9c79a7aSSanjay Patel   // It is not safe to transform things like div, urem, etc. because we may
142e9c79a7aSSanjay Patel   // create undefined behavior when executing those on unknown vector elements.
143e9c79a7aSSanjay Patel   if (!isSafeToSpeculativelyExecute(&I))
144e9c79a7aSSanjay Patel     return false;
145e9c79a7aSSanjay Patel 
146fc445589SSanjay Patel   Instruction *Ext0, *Ext1;
147fc445589SSanjay Patel   CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
148fc445589SSanjay Patel   if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) &&
149fc445589SSanjay Patel       !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
150fc445589SSanjay Patel     return false;
151fc445589SSanjay Patel 
152fc445589SSanjay Patel   Value *V0, *V1;
153fc445589SSanjay Patel   uint64_t C0, C1;
154fc445589SSanjay Patel   if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) ||
155fc445589SSanjay Patel       !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) ||
156fc445589SSanjay Patel       V0->getType() != V1->getType())
157fc445589SSanjay Patel     return false;
158fc445589SSanjay Patel 
159e9c79a7aSSanjay Patel   // TODO: Handle C0 != C1 by shuffling 1 of the operands.
160e9c79a7aSSanjay Patel   if (C0 != C1)
161fc445589SSanjay Patel     return false;
162e9c79a7aSSanjay Patel 
163e9c79a7aSSanjay Patel   if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI))
164e9c79a7aSSanjay Patel     return false;
165e9c79a7aSSanjay Patel 
166e9c79a7aSSanjay Patel   if (Pred != CmpInst::BAD_ICMP_PREDICATE)
167e9c79a7aSSanjay Patel     foldExtExtCmp(Ext0, Ext1, I, TTI);
168e9c79a7aSSanjay Patel   else
169e9c79a7aSSanjay Patel     foldExtExtBinop(Ext0, Ext1, I, TTI);
170e9c79a7aSSanjay Patel 
171e9c79a7aSSanjay Patel   return true;
172fc445589SSanjay Patel }
173fc445589SSanjay Patel 
174a17f03bdSSanjay Patel /// This is the entry point for all transforms. Pass manager differences are
175a17f03bdSSanjay Patel /// handled in the callers of this function.
176a17f03bdSSanjay Patel static bool runImpl(Function &F, const TargetTransformInfo &TTI,
177a17f03bdSSanjay Patel                     const DominatorTree &DT) {
178a17f03bdSSanjay Patel   bool MadeChange = false;
179a17f03bdSSanjay Patel   for (BasicBlock &BB : F) {
180a17f03bdSSanjay Patel     // Ignore unreachable basic blocks.
181a17f03bdSSanjay Patel     if (!DT.isReachableFromEntry(&BB))
182a17f03bdSSanjay Patel       continue;
183a17f03bdSSanjay Patel     // Do not delete instructions under here and invalidate the iterator.
184a17f03bdSSanjay Patel     // Walk the block backwards for efficiency. We're matching a chain of
185a17f03bdSSanjay Patel     // use->defs, so we're more likely to succeed by starting from the bottom.
186a17f03bdSSanjay Patel     // TODO: It could be more efficient to remove dead instructions
187a17f03bdSSanjay Patel     //       iteratively in this loop rather than waiting until the end.
188fc445589SSanjay Patel     for (Instruction &I : make_range(BB.rbegin(), BB.rend()))
189fc445589SSanjay Patel       MadeChange |= foldExtractExtract(I, TTI);
190a17f03bdSSanjay Patel   }
191a17f03bdSSanjay Patel 
192a17f03bdSSanjay Patel   // We're done with transforms, so remove dead instructions.
193a17f03bdSSanjay Patel   if (MadeChange)
194a17f03bdSSanjay Patel     for (BasicBlock &BB : F)
195a17f03bdSSanjay Patel       SimplifyInstructionsInBlock(&BB);
196a17f03bdSSanjay Patel 
197a17f03bdSSanjay Patel   return MadeChange;
198a17f03bdSSanjay Patel }
199a17f03bdSSanjay Patel 
// Pass manager boilerplate below here.
201a17f03bdSSanjay Patel 
202a17f03bdSSanjay Patel namespace {
203a17f03bdSSanjay Patel class VectorCombineLegacyPass : public FunctionPass {
204a17f03bdSSanjay Patel public:
205a17f03bdSSanjay Patel   static char ID;
206a17f03bdSSanjay Patel   VectorCombineLegacyPass() : FunctionPass(ID) {
207a17f03bdSSanjay Patel     initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
208a17f03bdSSanjay Patel   }
209a17f03bdSSanjay Patel 
210a17f03bdSSanjay Patel   void getAnalysisUsage(AnalysisUsage &AU) const override {
211a17f03bdSSanjay Patel     AU.addRequired<DominatorTreeWrapperPass>();
212a17f03bdSSanjay Patel     AU.addRequired<TargetTransformInfoWrapperPass>();
213a17f03bdSSanjay Patel     AU.setPreservesCFG();
214a17f03bdSSanjay Patel     AU.addPreserved<DominatorTreeWrapperPass>();
215a17f03bdSSanjay Patel     AU.addPreserved<GlobalsAAWrapperPass>();
216a17f03bdSSanjay Patel     FunctionPass::getAnalysisUsage(AU);
217a17f03bdSSanjay Patel   }
218a17f03bdSSanjay Patel 
219a17f03bdSSanjay Patel   bool runOnFunction(Function &F) override {
220a17f03bdSSanjay Patel     if (skipFunction(F))
221a17f03bdSSanjay Patel       return false;
222a17f03bdSSanjay Patel     auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
223a17f03bdSSanjay Patel     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
224a17f03bdSSanjay Patel     return runImpl(F, TTI, DT);
225a17f03bdSSanjay Patel   }
226a17f03bdSSanjay Patel };
227a17f03bdSSanjay Patel } // namespace
228a17f03bdSSanjay Patel 
char VectorCombineLegacyPass::ID = 0;

// Register the legacy pass and its analysis dependency with the pass registry.
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)

/// Factory for the legacy pass manager; caller owns the returned pass.
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}
239a17f03bdSSanjay Patel 
240a17f03bdSSanjay Patel PreservedAnalyses VectorCombinePass::run(Function &F,
241a17f03bdSSanjay Patel                                          FunctionAnalysisManager &FAM) {
242a17f03bdSSanjay Patel   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
243a17f03bdSSanjay Patel   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
244a17f03bdSSanjay Patel   if (!runImpl(F, TTI, DT))
245a17f03bdSSanjay Patel     return PreservedAnalyses::all();
246a17f03bdSSanjay Patel   PreservedAnalyses PA;
247a17f03bdSSanjay Patel   PA.preserveSet<CFGAnalyses>();
248a17f03bdSSanjay Patel   PA.preserve<GlobalsAA>();
249a17f03bdSSanjay Patel   return PA;
250a17f03bdSSanjay Patel }
251