1*a17f03bdSSanjay Patel //===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2*a17f03bdSSanjay Patel //
3*a17f03bdSSanjay Patel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*a17f03bdSSanjay Patel // See https://llvm.org/LICENSE.txt for license information.
5*a17f03bdSSanjay Patel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*a17f03bdSSanjay Patel //
7*a17f03bdSSanjay Patel //===----------------------------------------------------------------------===//
8*a17f03bdSSanjay Patel //
9*a17f03bdSSanjay Patel // This pass optimizes scalar/vector interactions using target cost models. The
10*a17f03bdSSanjay Patel // transforms implemented here may not fit in traditional loop-based or SLP
11*a17f03bdSSanjay Patel // vectorization passes.
12*a17f03bdSSanjay Patel //
13*a17f03bdSSanjay Patel //===----------------------------------------------------------------------===//
14*a17f03bdSSanjay Patel 
15*a17f03bdSSanjay Patel #include "llvm/Transforms/Vectorize/VectorCombine.h"
16*a17f03bdSSanjay Patel #include "llvm/ADT/Statistic.h"
17*a17f03bdSSanjay Patel #include "llvm/Analysis/GlobalsModRef.h"
18*a17f03bdSSanjay Patel #include "llvm/Analysis/TargetTransformInfo.h"
19*a17f03bdSSanjay Patel #include "llvm/IR/Dominators.h"
20*a17f03bdSSanjay Patel #include "llvm/IR/Function.h"
21*a17f03bdSSanjay Patel #include "llvm/IR/IRBuilder.h"
22*a17f03bdSSanjay Patel #include "llvm/IR/PatternMatch.h"
23*a17f03bdSSanjay Patel #include "llvm/InitializePasses.h"
24*a17f03bdSSanjay Patel #include "llvm/Pass.h"
25*a17f03bdSSanjay Patel #include "llvm/Support/DebugCounter.h"
26*a17f03bdSSanjay Patel #include "llvm/Transforms/Vectorize.h"
27*a17f03bdSSanjay Patel #include "llvm/Transforms/Utils/Local.h"
28*a17f03bdSSanjay Patel 
29*a17f03bdSSanjay Patel using namespace llvm;
30*a17f03bdSSanjay Patel using namespace llvm::PatternMatch;
31*a17f03bdSSanjay Patel 
// Identifier string for this pass; presumably picked up by the STATISTIC and
// debug-output macros that follow -- confirm against the LLVM macro
// definitions.
#define DEBUG_TYPE "vector-combine"
// Incremented by foldExtractCmp each time it replaces a scalar compare of two
// extracted elements with a vector compare followed by a single extract.
STATISTIC(NumVecCmp, "Number of vector compares formed");
// NOTE(review): this counter is declared here but is not consulted anywhere in
// the visible part of this file, so "vector-combine-transform" currently has no
// effect on which transforms run -- confirm whether a shouldExecute check is
// intended before each transform.
DEBUG_COUNTER(VecCombineCounter, "vector-combine-transform",
              "Controls transformations in vector-combine pass");
36*a17f03bdSSanjay Patel 
37*a17f03bdSSanjay Patel static bool foldExtractCmp(Instruction &I, const TargetTransformInfo &TTI) {
38*a17f03bdSSanjay Patel   // Match a cmp with extracted vector operands.
39*a17f03bdSSanjay Patel   CmpInst::Predicate Pred;
40*a17f03bdSSanjay Patel   Instruction *Ext0, *Ext1;
41*a17f03bdSSanjay Patel   if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))))
42*a17f03bdSSanjay Patel     return false;
43*a17f03bdSSanjay Patel 
44*a17f03bdSSanjay Patel   Value *V0, *V1;
45*a17f03bdSSanjay Patel   ConstantInt *C;
46*a17f03bdSSanjay Patel   if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C))) ||
47*a17f03bdSSanjay Patel       !match(Ext1, m_ExtractElement(m_Value(V1), m_Specific(C))) ||
48*a17f03bdSSanjay Patel       V0->getType() != V1->getType())
49*a17f03bdSSanjay Patel     return false;
50*a17f03bdSSanjay Patel 
51*a17f03bdSSanjay Patel   Type *ScalarTy = Ext0->getType();
52*a17f03bdSSanjay Patel   Type *VecTy = V0->getType();
53*a17f03bdSSanjay Patel   bool IsFP = ScalarTy->isFloatingPointTy();
54*a17f03bdSSanjay Patel   unsigned CmpOpcode = IsFP ? Instruction::FCmp : Instruction::ICmp;
55*a17f03bdSSanjay Patel 
56*a17f03bdSSanjay Patel   // Check if the existing scalar code or the vector alternative is cheaper.
57*a17f03bdSSanjay Patel   // Extra uses of the extracts mean that we include those costs in the
58*a17f03bdSSanjay Patel   // vector total because those instructions will not be eliminated.
59*a17f03bdSSanjay Patel   // ((2 * extract) + scalar cmp) < (vector cmp + extract) ?
60*a17f03bdSSanjay Patel   int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement,
61*a17f03bdSSanjay Patel                                            VecTy, C->getZExtValue());
62*a17f03bdSSanjay Patel   int ScalarCmpCost = TTI.getOperationCost(CmpOpcode, ScalarTy);
63*a17f03bdSSanjay Patel   int VecCmpCost = TTI.getOperationCost(CmpOpcode, VecTy);
64*a17f03bdSSanjay Patel 
65*a17f03bdSSanjay Patel   int ScalarCost = 2 * ExtractCost + ScalarCmpCost;
66*a17f03bdSSanjay Patel   int VecCost = VecCmpCost + ExtractCost +
67*a17f03bdSSanjay Patel                 !Ext0->hasOneUse() * ExtractCost +
68*a17f03bdSSanjay Patel                 !Ext1->hasOneUse() * ExtractCost;
69*a17f03bdSSanjay Patel   if (ScalarCost < VecCost)
70*a17f03bdSSanjay Patel     return false;
71*a17f03bdSSanjay Patel 
72*a17f03bdSSanjay Patel   // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
73*a17f03bdSSanjay Patel   ++NumVecCmp;
74*a17f03bdSSanjay Patel   IRBuilder<> Builder(&I);
75*a17f03bdSSanjay Patel   Value *VecCmp = IsFP ? Builder.CreateFCmp(Pred, V0, V1)
76*a17f03bdSSanjay Patel                        : Builder.CreateICmp(Pred, V0, V1);
77*a17f03bdSSanjay Patel   Value *Ext = Builder.CreateExtractElement(VecCmp, C);
78*a17f03bdSSanjay Patel   I.replaceAllUsesWith(Ext);
79*a17f03bdSSanjay Patel   return true;
80*a17f03bdSSanjay Patel }
81*a17f03bdSSanjay Patel 
82*a17f03bdSSanjay Patel /// This is the entry point for all transforms. Pass manager differences are
83*a17f03bdSSanjay Patel /// handled in the callers of this function.
84*a17f03bdSSanjay Patel static bool runImpl(Function &F, const TargetTransformInfo &TTI,
85*a17f03bdSSanjay Patel                     const DominatorTree &DT) {
86*a17f03bdSSanjay Patel   bool MadeChange = false;
87*a17f03bdSSanjay Patel   for (BasicBlock &BB : F) {
88*a17f03bdSSanjay Patel     // Ignore unreachable basic blocks.
89*a17f03bdSSanjay Patel     if (!DT.isReachableFromEntry(&BB))
90*a17f03bdSSanjay Patel       continue;
91*a17f03bdSSanjay Patel     // Do not delete instructions under here and invalidate the iterator.
92*a17f03bdSSanjay Patel     // Walk the block backwards for efficiency. We're matching a chain of
93*a17f03bdSSanjay Patel     // use->defs, so we're more likely to succeed by starting from the bottom.
94*a17f03bdSSanjay Patel     // TODO: It could be more efficient to remove dead instructions
95*a17f03bdSSanjay Patel     //       iteratively in this loop rather than waiting until the end.
96*a17f03bdSSanjay Patel     for (Instruction &I : make_range(BB.rbegin(), BB.rend())) {
97*a17f03bdSSanjay Patel       MadeChange |= foldExtractCmp(I, TTI);
98*a17f03bdSSanjay Patel       // TODO: More transforms go here.
99*a17f03bdSSanjay Patel     }
100*a17f03bdSSanjay Patel   }
101*a17f03bdSSanjay Patel 
102*a17f03bdSSanjay Patel   // We're done with transforms, so remove dead instructions.
103*a17f03bdSSanjay Patel   if (MadeChange)
104*a17f03bdSSanjay Patel     for (BasicBlock &BB : F)
105*a17f03bdSSanjay Patel       SimplifyInstructionsInBlock(&BB);
106*a17f03bdSSanjay Patel 
107*a17f03bdSSanjay Patel   return MadeChange;
108*a17f03bdSSanjay Patel }
109*a17f03bdSSanjay Patel 
110*a17f03bdSSanjay Patel // Pass manager boilerplate below here.
111*a17f03bdSSanjay Patel 
112*a17f03bdSSanjay Patel namespace {
113*a17f03bdSSanjay Patel class VectorCombineLegacyPass : public FunctionPass {
114*a17f03bdSSanjay Patel public:
115*a17f03bdSSanjay Patel   static char ID;
116*a17f03bdSSanjay Patel   VectorCombineLegacyPass() : FunctionPass(ID) {
117*a17f03bdSSanjay Patel     initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
118*a17f03bdSSanjay Patel   }
119*a17f03bdSSanjay Patel 
120*a17f03bdSSanjay Patel   void getAnalysisUsage(AnalysisUsage &AU) const override {
121*a17f03bdSSanjay Patel     AU.addRequired<DominatorTreeWrapperPass>();
122*a17f03bdSSanjay Patel     AU.addRequired<TargetTransformInfoWrapperPass>();
123*a17f03bdSSanjay Patel     AU.setPreservesCFG();
124*a17f03bdSSanjay Patel     AU.addPreserved<DominatorTreeWrapperPass>();
125*a17f03bdSSanjay Patel     AU.addPreserved<GlobalsAAWrapperPass>();
126*a17f03bdSSanjay Patel     FunctionPass::getAnalysisUsage(AU);
127*a17f03bdSSanjay Patel   }
128*a17f03bdSSanjay Patel 
129*a17f03bdSSanjay Patel   bool runOnFunction(Function &F) override {
130*a17f03bdSSanjay Patel     if (skipFunction(F))
131*a17f03bdSSanjay Patel       return false;
132*a17f03bdSSanjay Patel     auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
133*a17f03bdSSanjay Patel     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
134*a17f03bdSSanjay Patel     return runImpl(F, TTI, DT);
135*a17f03bdSSanjay Patel   }
136*a17f03bdSSanjay Patel };
137*a17f03bdSSanjay Patel } // namespace
138*a17f03bdSSanjay Patel 
// Definition of the static pass identifier that VectorCombineLegacyPass hands
// to the FunctionPass constructor.
char VectorCombineLegacyPass::ID = 0;
// Register the legacy pass under the name "vector-combine" and declare the
// analysis passes it depends on.
// NOTE(review): getAnalysisUsage also requires TargetTransformInfoWrapperPass,
// but only DominatorTreeWrapperPass is listed as a dependency here -- confirm
// whether that omission is intentional.
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
146*a17f03bdSSanjay Patel Pass *llvm::createVectorCombinePass() {
147*a17f03bdSSanjay Patel   return new VectorCombineLegacyPass();
148*a17f03bdSSanjay Patel }
149*a17f03bdSSanjay Patel 
150*a17f03bdSSanjay Patel PreservedAnalyses VectorCombinePass::run(Function &F,
151*a17f03bdSSanjay Patel                                          FunctionAnalysisManager &FAM) {
152*a17f03bdSSanjay Patel   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
153*a17f03bdSSanjay Patel   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
154*a17f03bdSSanjay Patel   if (!runImpl(F, TTI, DT))
155*a17f03bdSSanjay Patel     return PreservedAnalyses::all();
156*a17f03bdSSanjay Patel   PreservedAnalyses PA;
157*a17f03bdSSanjay Patel   PA.preserveSet<CFGAnalyses>();
158*a17f03bdSSanjay Patel   PA.preserve<GlobalsAA>();
159*a17f03bdSSanjay Patel   return PA;
160*a17f03bdSSanjay Patel }
161