1 //===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass optimizes scalar/vector interactions using target cost models. The
10 // transforms implemented here may not fit in traditional loop-based or SLP
11 // vectorization passes.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/Vectorize/VectorCombine.h"
16 #include "llvm/ADT/Statistic.h"
17 #include "llvm/Analysis/GlobalsModRef.h"
18 #include "llvm/Analysis/TargetTransformInfo.h"
19 #include "llvm/IR/Dominators.h"
20 #include "llvm/IR/Function.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/PatternMatch.h"
23 #include "llvm/InitializePasses.h"
24 #include "llvm/Pass.h"
25 #include "llvm/Transforms/Vectorize.h"
26 #include "llvm/Transforms/Utils/Local.h"
27 
28 using namespace llvm;
29 using namespace llvm::PatternMatch;
30 
#define DEBUG_TYPE "vector-combine"
// Incremented once for every vector compare created by foldExtractCmp.
STATISTIC(NumVecCmp, "Number of vector compares formed");
33 
34 static bool foldExtractCmp(Instruction &I, const TargetTransformInfo &TTI) {
35   // Match a cmp with extracted vector operands.
36   CmpInst::Predicate Pred;
37   Instruction *Ext0, *Ext1;
38   if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))))
39     return false;
40 
41   Value *V0, *V1;
42   ConstantInt *C;
43   if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C))) ||
44       !match(Ext1, m_ExtractElement(m_Value(V1), m_Specific(C))) ||
45       V0->getType() != V1->getType())
46     return false;
47 
48   Type *ScalarTy = Ext0->getType();
49   Type *VecTy = V0->getType();
50   bool IsFP = ScalarTy->isFloatingPointTy();
51   unsigned CmpOpcode = IsFP ? Instruction::FCmp : Instruction::ICmp;
52 
53   // Check if the existing scalar code or the vector alternative is cheaper.
54   // Extra uses of the extracts mean that we include those costs in the
55   // vector total because those instructions will not be eliminated.
56   // ((2 * extract) + scalar cmp) < (vector cmp + extract) ?
57   int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement,
58                                            VecTy, C->getZExtValue());
59   int ScalarCmpCost = TTI.getOperationCost(CmpOpcode, ScalarTy);
60   int VecCmpCost = TTI.getOperationCost(CmpOpcode, VecTy);
61 
62   int ScalarCost = 2 * ExtractCost + ScalarCmpCost;
63   int VecCost = VecCmpCost + ExtractCost +
64                 !Ext0->hasOneUse() * ExtractCost +
65                 !Ext1->hasOneUse() * ExtractCost;
66   if (ScalarCost < VecCost)
67     return false;
68 
69   // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
70   ++NumVecCmp;
71   IRBuilder<> Builder(&I);
72   Value *VecCmp = IsFP ? Builder.CreateFCmp(Pred, V0, V1)
73                        : Builder.CreateICmp(Pred, V0, V1);
74   Value *Ext = Builder.CreateExtractElement(VecCmp, C);
75   I.replaceAllUsesWith(Ext);
76   return true;
77 }
78 
79 /// This is the entry point for all transforms. Pass manager differences are
80 /// handled in the callers of this function.
81 static bool runImpl(Function &F, const TargetTransformInfo &TTI,
82                     const DominatorTree &DT) {
83   bool MadeChange = false;
84   for (BasicBlock &BB : F) {
85     // Ignore unreachable basic blocks.
86     if (!DT.isReachableFromEntry(&BB))
87       continue;
88     // Do not delete instructions under here and invalidate the iterator.
89     // Walk the block backwards for efficiency. We're matching a chain of
90     // use->defs, so we're more likely to succeed by starting from the bottom.
91     // TODO: It could be more efficient to remove dead instructions
92     //       iteratively in this loop rather than waiting until the end.
93     for (Instruction &I : make_range(BB.rbegin(), BB.rend())) {
94       MadeChange |= foldExtractCmp(I, TTI);
95       // TODO: More transforms go here.
96     }
97   }
98 
99   // We're done with transforms, so remove dead instructions.
100   if (MadeChange)
101     for (BasicBlock &BB : F)
102       SimplifyInstructionsInBlock(&BB);
103 
104   return MadeChange;
105 }
106 
107 // Pass manager boilerplate below here.
108 
109 namespace {
110 class VectorCombineLegacyPass : public FunctionPass {
111 public:
112   static char ID;
113   VectorCombineLegacyPass() : FunctionPass(ID) {
114     initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
115   }
116 
117   void getAnalysisUsage(AnalysisUsage &AU) const override {
118     AU.addRequired<DominatorTreeWrapperPass>();
119     AU.addRequired<TargetTransformInfoWrapperPass>();
120     AU.setPreservesCFG();
121     AU.addPreserved<DominatorTreeWrapperPass>();
122     AU.addPreserved<GlobalsAAWrapperPass>();
123     FunctionPass::getAnalysisUsage(AU);
124   }
125 
126   bool runOnFunction(Function &F) override {
127     if (skipFunction(F))
128       return false;
129     auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
130     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
131     return runImpl(F, TTI, DT);
132   }
133 };
134 } // namespace
135 
136 char VectorCombineLegacyPass::ID = 0;
137 INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
138                       "Optimize scalar/vector ops", false,
139                       false)
140 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
141 INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
142                     "Optimize scalar/vector ops", false, false)
143 Pass *llvm::createVectorCombinePass() {
144   return new VectorCombineLegacyPass();
145 }
146 
147 PreservedAnalyses VectorCombinePass::run(Function &F,
148                                          FunctionAnalysisManager &FAM) {
149   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
150   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
151   if (!runImpl(F, TTI, DT))
152     return PreservedAnalyses::all();
153   PreservedAnalyses PA;
154   PA.preserveSet<CFGAnalyses>();
155   PA.preserve<GlobalsAA>();
156   return PA;
157 }
158