1 //===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass optimizes scalar/vector interactions using target cost models. The
10 // transforms implemented here may not fit in traditional loop-based or SLP
11 // vectorization passes.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/Vectorize/VectorCombine.h"
16 #include "llvm/ADT/Statistic.h"
17 #include "llvm/Analysis/GlobalsModRef.h"
18 #include "llvm/Analysis/TargetTransformInfo.h"
19 #include "llvm/IR/Dominators.h"
20 #include "llvm/IR/Function.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/PatternMatch.h"
23 #include "llvm/InitializePasses.h"
24 #include "llvm/Pass.h"
25 #include "llvm/Support/DebugCounter.h"
26 #include "llvm/Transforms/Vectorize.h"
27 #include "llvm/Transforms/Utils/Local.h"
28 
29 using namespace llvm;
30 using namespace llvm::PatternMatch;
31 
32 #define DEBUG_TYPE "vector-combine"
33 STATISTIC(NumVecCmp, "Number of vector compares formed");
34 DEBUG_COUNTER(VecCombineCounter, "vector-combine-transform",
35               "Controls transformations in vector-combine pass");
36 
37 static bool foldExtractCmp(Instruction &I, const TargetTransformInfo &TTI) {
38   // Match a cmp with extracted vector operands.
39   CmpInst::Predicate Pred;
40   Instruction *Ext0, *Ext1;
41   if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))))
42     return false;
43 
44   Value *V0, *V1;
45   ConstantInt *C;
46   if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C))) ||
47       !match(Ext1, m_ExtractElement(m_Value(V1), m_Specific(C))) ||
48       V0->getType() != V1->getType())
49     return false;
50 
51   Type *ScalarTy = Ext0->getType();
52   Type *VecTy = V0->getType();
53   bool IsFP = ScalarTy->isFloatingPointTy();
54   unsigned CmpOpcode = IsFP ? Instruction::FCmp : Instruction::ICmp;
55 
56   // Check if the existing scalar code or the vector alternative is cheaper.
57   // Extra uses of the extracts mean that we include those costs in the
58   // vector total because those instructions will not be eliminated.
59   // ((2 * extract) + scalar cmp) < (vector cmp + extract) ?
60   int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement,
61                                            VecTy, C->getZExtValue());
62   int ScalarCmpCost = TTI.getOperationCost(CmpOpcode, ScalarTy);
63   int VecCmpCost = TTI.getOperationCost(CmpOpcode, VecTy);
64 
65   int ScalarCost = 2 * ExtractCost + ScalarCmpCost;
66   int VecCost = VecCmpCost + ExtractCost +
67                 !Ext0->hasOneUse() * ExtractCost +
68                 !Ext1->hasOneUse() * ExtractCost;
69   if (ScalarCost < VecCost)
70     return false;
71 
72   // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
73   ++NumVecCmp;
74   IRBuilder<> Builder(&I);
75   Value *VecCmp = IsFP ? Builder.CreateFCmp(Pred, V0, V1)
76                        : Builder.CreateICmp(Pred, V0, V1);
77   Value *Ext = Builder.CreateExtractElement(VecCmp, C);
78   I.replaceAllUsesWith(Ext);
79   return true;
80 }
81 
82 /// This is the entry point for all transforms. Pass manager differences are
83 /// handled in the callers of this function.
84 static bool runImpl(Function &F, const TargetTransformInfo &TTI,
85                     const DominatorTree &DT) {
86   bool MadeChange = false;
87   for (BasicBlock &BB : F) {
88     // Ignore unreachable basic blocks.
89     if (!DT.isReachableFromEntry(&BB))
90       continue;
91     // Do not delete instructions under here and invalidate the iterator.
92     // Walk the block backwards for efficiency. We're matching a chain of
93     // use->defs, so we're more likely to succeed by starting from the bottom.
94     // TODO: It could be more efficient to remove dead instructions
95     //       iteratively in this loop rather than waiting until the end.
96     for (Instruction &I : make_range(BB.rbegin(), BB.rend())) {
97       MadeChange |= foldExtractCmp(I, TTI);
98       // TODO: More transforms go here.
99     }
100   }
101 
102   // We're done with transforms, so remove dead instructions.
103   if (MadeChange)
104     for (BasicBlock &BB : F)
105       SimplifyInstructionsInBlock(&BB);
106 
107   return MadeChange;
108 }
109 
110 // Pass manager boilerplate below here.
111 
112 namespace {
113 class VectorCombineLegacyPass : public FunctionPass {
114 public:
115   static char ID;
116   VectorCombineLegacyPass() : FunctionPass(ID) {
117     initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
118   }
119 
120   void getAnalysisUsage(AnalysisUsage &AU) const override {
121     AU.addRequired<DominatorTreeWrapperPass>();
122     AU.addRequired<TargetTransformInfoWrapperPass>();
123     AU.setPreservesCFG();
124     AU.addPreserved<DominatorTreeWrapperPass>();
125     AU.addPreserved<GlobalsAAWrapperPass>();
126     FunctionPass::getAnalysisUsage(AU);
127   }
128 
129   bool runOnFunction(Function &F) override {
130     if (skipFunction(F))
131       return false;
132     auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
133     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
134     return runImpl(F, TTI, DT);
135   }
136 };
137 } // namespace
138 
139 char VectorCombineLegacyPass::ID = 0;
140 INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
141                       "Optimize scalar/vector ops", false,
142                       false)
143 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
144 INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
145                     "Optimize scalar/vector ops", false, false)
146 Pass *llvm::createVectorCombinePass() {
147   return new VectorCombineLegacyPass();
148 }
149 
150 PreservedAnalyses VectorCombinePass::run(Function &F,
151                                          FunctionAnalysisManager &FAM) {
152   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
153   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
154   if (!runImpl(F, TTI, DT))
155     return PreservedAnalyses::all();
156   PreservedAnalyses PA;
157   PA.preserveSet<CFGAnalyses>();
158   PA.preserve<GlobalsAA>();
159   return PA;
160 }
161