1a17f03bdSSanjay Patel //===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2a17f03bdSSanjay Patel //
3a17f03bdSSanjay Patel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a17f03bdSSanjay Patel // See https://llvm.org/LICENSE.txt for license information.
5a17f03bdSSanjay Patel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a17f03bdSSanjay Patel //
7a17f03bdSSanjay Patel //===----------------------------------------------------------------------===//
8a17f03bdSSanjay Patel //
9a17f03bdSSanjay Patel // This pass optimizes scalar/vector interactions using target cost models. The
10a17f03bdSSanjay Patel // transforms implemented here may not fit in traditional loop-based or SLP
11a17f03bdSSanjay Patel // vectorization passes.
12a17f03bdSSanjay Patel //
13a17f03bdSSanjay Patel //===----------------------------------------------------------------------===//
14a17f03bdSSanjay Patel 
15a17f03bdSSanjay Patel #include "llvm/Transforms/Vectorize/VectorCombine.h"
16a17f03bdSSanjay Patel #include "llvm/ADT/Statistic.h"
17a17f03bdSSanjay Patel #include "llvm/Analysis/GlobalsModRef.h"
18a17f03bdSSanjay Patel #include "llvm/Analysis/TargetTransformInfo.h"
1919b62b79SSanjay Patel #include "llvm/Analysis/ValueTracking.h"
20a17f03bdSSanjay Patel #include "llvm/IR/Dominators.h"
21a17f03bdSSanjay Patel #include "llvm/IR/Function.h"
22a17f03bdSSanjay Patel #include "llvm/IR/IRBuilder.h"
23a17f03bdSSanjay Patel #include "llvm/IR/PatternMatch.h"
24a17f03bdSSanjay Patel #include "llvm/InitializePasses.h"
25a17f03bdSSanjay Patel #include "llvm/Pass.h"
2625c6544fSSanjay Patel #include "llvm/Support/CommandLine.h"
27a17f03bdSSanjay Patel #include "llvm/Transforms/Vectorize.h"
28a17f03bdSSanjay Patel #include "llvm/Transforms/Utils/Local.h"
29a17f03bdSSanjay Patel 
30a17f03bdSSanjay Patel using namespace llvm;
31a17f03bdSSanjay Patel using namespace llvm::PatternMatch;
32a17f03bdSSanjay Patel 
33a17f03bdSSanjay Patel #define DEBUG_TYPE "vector-combine"
34a17f03bdSSanjay Patel STATISTIC(NumVecCmp, "Number of vector compares formed");
3519b62b79SSanjay Patel STATISTIC(NumVecBO, "Number of vector binops formed");
36a17f03bdSSanjay Patel 
3725c6544fSSanjay Patel static cl::opt<bool> DisableVectorCombine(
3825c6544fSSanjay Patel     "disable-vector-combine", cl::init(false), cl::Hidden,
3925c6544fSSanjay Patel     cl::desc("Disable all vector combine transforms"));
4025c6544fSSanjay Patel 
4134e34855SSanjay Patel /// Compare the relative costs of extracts followed by scalar operation vs.
4234e34855SSanjay Patel /// vector operation followed by extract:
4334e34855SSanjay Patel /// opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
4434e34855SSanjay Patel /// Unless the vector op is much more expensive than the scalar op, this
4534e34855SSanjay Patel /// eliminates an extract.
4634e34855SSanjay Patel static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
4734e34855SSanjay Patel                                   unsigned Opcode,
4834e34855SSanjay Patel                                   const TargetTransformInfo &TTI) {
49*4fa63fd4SAustin Kerbow   assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
50*4fa63fd4SAustin Kerbow          (cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue() ==
51*4fa63fd4SAustin Kerbow           cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue()) &&
5234e34855SSanjay Patel          "Expected same constant extract index");
5334e34855SSanjay Patel 
5434e34855SSanjay Patel   Type *ScalarTy = Ext0->getType();
5534e34855SSanjay Patel   Type *VecTy = Ext0->getOperand(0)->getType();
5634e34855SSanjay Patel   int ScalarOpCost, VectorOpCost;
5734e34855SSanjay Patel 
5834e34855SSanjay Patel   // Get cost estimates for scalar and vector versions of the operation.
5934e34855SSanjay Patel   bool IsBinOp = Instruction::isBinaryOp(Opcode);
6034e34855SSanjay Patel   if (IsBinOp) {
6134e34855SSanjay Patel     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
6234e34855SSanjay Patel     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
6334e34855SSanjay Patel   } else {
6434e34855SSanjay Patel     assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
6534e34855SSanjay Patel            "Expected a compare");
6634e34855SSanjay Patel     ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
6734e34855SSanjay Patel                                           CmpInst::makeCmpResultType(ScalarTy));
6834e34855SSanjay Patel     VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
6934e34855SSanjay Patel                                           CmpInst::makeCmpResultType(VecTy));
7034e34855SSanjay Patel   }
7134e34855SSanjay Patel 
7234e34855SSanjay Patel   // Get cost estimate for the extract element. This cost will factor into
7334e34855SSanjay Patel   // both sequences.
7434e34855SSanjay Patel   unsigned ExtIndex = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
7534e34855SSanjay Patel   int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement,
7634e34855SSanjay Patel                                            VecTy, ExtIndex);
7734e34855SSanjay Patel 
7834e34855SSanjay Patel   // Extra uses of the extracts mean that we include those costs in the
7934e34855SSanjay Patel   // vector total because those instructions will not be eliminated.
80e9c79a7aSSanjay Patel   int OldCost, NewCost;
8134e34855SSanjay Patel   if (Ext0->getOperand(0) == Ext1->getOperand(0)) {
8234e34855SSanjay Patel     // Handle a special case. If the 2 operands are identical, adjust the
8334e34855SSanjay Patel     // formulas to account for that. The extra use charge allows for either the
8434e34855SSanjay Patel     // CSE'd pattern or an unoptimized form with identical values:
8534e34855SSanjay Patel     // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
8634e34855SSanjay Patel     bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
8734e34855SSanjay Patel                                   : !Ext0->hasOneUse() || !Ext1->hasOneUse();
88e9c79a7aSSanjay Patel     OldCost = ExtractCost + ScalarOpCost;
89e9c79a7aSSanjay Patel     NewCost = VectorOpCost + ExtractCost + HasUseTax * ExtractCost;
9034e34855SSanjay Patel   } else {
9134e34855SSanjay Patel     // Handle the general case. Each extract is actually a different value:
9234e34855SSanjay Patel     // opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
93e9c79a7aSSanjay Patel     OldCost = 2 * ExtractCost + ScalarOpCost;
94e9c79a7aSSanjay Patel     NewCost = VectorOpCost + ExtractCost + !Ext0->hasOneUse() * ExtractCost +
9534e34855SSanjay Patel               !Ext1->hasOneUse() * ExtractCost;
9634e34855SSanjay Patel   }
9710ea01d8SSanjay Patel   // Aggressively form a vector op if the cost is equal because the transform
9810ea01d8SSanjay Patel   // may enable further optimization.
9910ea01d8SSanjay Patel   // Codegen can reverse this transform (scalarize) if it was not profitable.
10010ea01d8SSanjay Patel   return OldCost < NewCost;
10134e34855SSanjay Patel }
10234e34855SSanjay Patel 
103fc445589SSanjay Patel /// Try to reduce extract element costs by converting scalar compares to vector
104fc445589SSanjay Patel /// compares followed by extract.
105e9c79a7aSSanjay Patel /// cmp (ext0 V0, C), (ext1 V1, C)
106e9c79a7aSSanjay Patel static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1,
107fc445589SSanjay Patel                           Instruction &I, const TargetTransformInfo &TTI) {
108fc445589SSanjay Patel   assert(isa<CmpInst>(&I) && "Expected a compare");
109a17f03bdSSanjay Patel 
110a17f03bdSSanjay Patel   // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
111a17f03bdSSanjay Patel   ++NumVecCmp;
112a17f03bdSSanjay Patel   IRBuilder<> Builder(&I);
113fc445589SSanjay Patel   CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
114e9c79a7aSSanjay Patel   Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
11534e34855SSanjay Patel   Value *VecCmp =
11634e34855SSanjay Patel       Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1)
117a17f03bdSSanjay Patel                                            : Builder.CreateICmp(Pred, V0, V1);
118fc445589SSanjay Patel   Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1));
119fc445589SSanjay Patel   I.replaceAllUsesWith(Extract);
120a17f03bdSSanjay Patel }
121a17f03bdSSanjay Patel 
12219b62b79SSanjay Patel /// Try to reduce extract element costs by converting scalar binops to vector
12319b62b79SSanjay Patel /// binops followed by extract.
124e9c79a7aSSanjay Patel /// bo (ext0 V0, C), (ext1 V1, C)
125e9c79a7aSSanjay Patel static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1,
126fc445589SSanjay Patel                             Instruction &I, const TargetTransformInfo &TTI) {
127fc445589SSanjay Patel   assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
12819b62b79SSanjay Patel 
12934e34855SSanjay Patel   // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
13019b62b79SSanjay Patel   ++NumVecBO;
13119b62b79SSanjay Patel   IRBuilder<> Builder(&I);
132e9c79a7aSSanjay Patel   Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0);
133e9c79a7aSSanjay Patel   Value *VecBO =
13434e34855SSanjay Patel       Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);
135e9c79a7aSSanjay Patel 
13619b62b79SSanjay Patel   // All IR flags are safe to back-propagate because any potential poison
13719b62b79SSanjay Patel   // created in unused vector elements is discarded by the extract.
138e9c79a7aSSanjay Patel   if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
13919b62b79SSanjay Patel     VecBOInst->copyIRFlags(&I);
140e9c79a7aSSanjay Patel 
141e9c79a7aSSanjay Patel   Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1));
14219b62b79SSanjay Patel   I.replaceAllUsesWith(Extract);
14319b62b79SSanjay Patel }
14419b62b79SSanjay Patel 
145fc445589SSanjay Patel /// Match an instruction with extracted vector operands.
146fc445589SSanjay Patel static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
147e9c79a7aSSanjay Patel   // It is not safe to transform things like div, urem, etc. because we may
148e9c79a7aSSanjay Patel   // create undefined behavior when executing those on unknown vector elements.
149e9c79a7aSSanjay Patel   if (!isSafeToSpeculativelyExecute(&I))
150e9c79a7aSSanjay Patel     return false;
151e9c79a7aSSanjay Patel 
152fc445589SSanjay Patel   Instruction *Ext0, *Ext1;
153fc445589SSanjay Patel   CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
154fc445589SSanjay Patel   if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) &&
155fc445589SSanjay Patel       !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1))))
156fc445589SSanjay Patel     return false;
157fc445589SSanjay Patel 
158fc445589SSanjay Patel   Value *V0, *V1;
159fc445589SSanjay Patel   uint64_t C0, C1;
160fc445589SSanjay Patel   if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) ||
161fc445589SSanjay Patel       !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) ||
162fc445589SSanjay Patel       V0->getType() != V1->getType())
163fc445589SSanjay Patel     return false;
164fc445589SSanjay Patel 
165e9c79a7aSSanjay Patel   // TODO: Handle C0 != C1 by shuffling 1 of the operands.
166e9c79a7aSSanjay Patel   if (C0 != C1)
167fc445589SSanjay Patel     return false;
168e9c79a7aSSanjay Patel 
169e9c79a7aSSanjay Patel   if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI))
170e9c79a7aSSanjay Patel     return false;
171e9c79a7aSSanjay Patel 
172e9c79a7aSSanjay Patel   if (Pred != CmpInst::BAD_ICMP_PREDICATE)
173e9c79a7aSSanjay Patel     foldExtExtCmp(Ext0, Ext1, I, TTI);
174e9c79a7aSSanjay Patel   else
175e9c79a7aSSanjay Patel     foldExtExtBinop(Ext0, Ext1, I, TTI);
176e9c79a7aSSanjay Patel 
177e9c79a7aSSanjay Patel   return true;
178fc445589SSanjay Patel }
179fc445589SSanjay Patel 
180a17f03bdSSanjay Patel /// This is the entry point for all transforms. Pass manager differences are
181a17f03bdSSanjay Patel /// handled in the callers of this function.
182a17f03bdSSanjay Patel static bool runImpl(Function &F, const TargetTransformInfo &TTI,
183a17f03bdSSanjay Patel                     const DominatorTree &DT) {
18425c6544fSSanjay Patel   if (DisableVectorCombine)
18525c6544fSSanjay Patel     return false;
18625c6544fSSanjay Patel 
187a17f03bdSSanjay Patel   bool MadeChange = false;
188a17f03bdSSanjay Patel   for (BasicBlock &BB : F) {
189a17f03bdSSanjay Patel     // Ignore unreachable basic blocks.
190a17f03bdSSanjay Patel     if (!DT.isReachableFromEntry(&BB))
191a17f03bdSSanjay Patel       continue;
192a17f03bdSSanjay Patel     // Do not delete instructions under here and invalidate the iterator.
193a17f03bdSSanjay Patel     // Walk the block backwards for efficiency. We're matching a chain of
194a17f03bdSSanjay Patel     // use->defs, so we're more likely to succeed by starting from the bottom.
195a17f03bdSSanjay Patel     // TODO: It could be more efficient to remove dead instructions
196a17f03bdSSanjay Patel     //       iteratively in this loop rather than waiting until the end.
197fc445589SSanjay Patel     for (Instruction &I : make_range(BB.rbegin(), BB.rend()))
198fc445589SSanjay Patel       MadeChange |= foldExtractExtract(I, TTI);
199a17f03bdSSanjay Patel   }
200a17f03bdSSanjay Patel 
201a17f03bdSSanjay Patel   // We're done with transforms, so remove dead instructions.
202a17f03bdSSanjay Patel   if (MadeChange)
203a17f03bdSSanjay Patel     for (BasicBlock &BB : F)
204a17f03bdSSanjay Patel       SimplifyInstructionsInBlock(&BB);
205a17f03bdSSanjay Patel 
206a17f03bdSSanjay Patel   return MadeChange;
207a17f03bdSSanjay Patel }
208a17f03bdSSanjay Patel 
209a17f03bdSSanjay Patel // Pass manager boilerplate below here.
210a17f03bdSSanjay Patel 
211a17f03bdSSanjay Patel namespace {
212a17f03bdSSanjay Patel class VectorCombineLegacyPass : public FunctionPass {
213a17f03bdSSanjay Patel public:
214a17f03bdSSanjay Patel   static char ID;
215a17f03bdSSanjay Patel   VectorCombineLegacyPass() : FunctionPass(ID) {
216a17f03bdSSanjay Patel     initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
217a17f03bdSSanjay Patel   }
218a17f03bdSSanjay Patel 
219a17f03bdSSanjay Patel   void getAnalysisUsage(AnalysisUsage &AU) const override {
220a17f03bdSSanjay Patel     AU.addRequired<DominatorTreeWrapperPass>();
221a17f03bdSSanjay Patel     AU.addRequired<TargetTransformInfoWrapperPass>();
222a17f03bdSSanjay Patel     AU.setPreservesCFG();
223a17f03bdSSanjay Patel     AU.addPreserved<DominatorTreeWrapperPass>();
224a17f03bdSSanjay Patel     AU.addPreserved<GlobalsAAWrapperPass>();
225a17f03bdSSanjay Patel     FunctionPass::getAnalysisUsage(AU);
226a17f03bdSSanjay Patel   }
227a17f03bdSSanjay Patel 
228a17f03bdSSanjay Patel   bool runOnFunction(Function &F) override {
229a17f03bdSSanjay Patel     if (skipFunction(F))
230a17f03bdSSanjay Patel       return false;
231a17f03bdSSanjay Patel     auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
232a17f03bdSSanjay Patel     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
233a17f03bdSSanjay Patel     return runImpl(F, TTI, DT);
234a17f03bdSSanjay Patel   }
235a17f03bdSSanjay Patel };
236a17f03bdSSanjay Patel } // namespace
237a17f03bdSSanjay Patel 
238a17f03bdSSanjay Patel char VectorCombineLegacyPass::ID = 0;
239a17f03bdSSanjay Patel INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
240a17f03bdSSanjay Patel                       "Optimize scalar/vector ops", false,
241a17f03bdSSanjay Patel                       false)
242a17f03bdSSanjay Patel INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
243a17f03bdSSanjay Patel INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
244a17f03bdSSanjay Patel                     "Optimize scalar/vector ops", false, false)
245a17f03bdSSanjay Patel Pass *llvm::createVectorCombinePass() {
246a17f03bdSSanjay Patel   return new VectorCombineLegacyPass();
247a17f03bdSSanjay Patel }
248a17f03bdSSanjay Patel 
249a17f03bdSSanjay Patel PreservedAnalyses VectorCombinePass::run(Function &F,
250a17f03bdSSanjay Patel                                          FunctionAnalysisManager &FAM) {
251a17f03bdSSanjay Patel   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
252a17f03bdSSanjay Patel   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
253a17f03bdSSanjay Patel   if (!runImpl(F, TTI, DT))
254a17f03bdSSanjay Patel     return PreservedAnalyses::all();
255a17f03bdSSanjay Patel   PreservedAnalyses PA;
256a17f03bdSSanjay Patel   PA.preserveSet<CFGAnalyses>();
257a17f03bdSSanjay Patel   PA.preserve<GlobalsAA>();
258a17f03bdSSanjay Patel   return PA;
259a17f03bdSSanjay Patel }
260