1 //===------- VectorCombine.cpp - Optimize partial vector operations -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass optimizes scalar/vector interactions using target cost models. The 10 // transforms implemented here may not fit in traditional loop-based or SLP 11 // vectorization passes. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "llvm/Transforms/Vectorize/VectorCombine.h" 16 #include "llvm/ADT/Statistic.h" 17 #include "llvm/Analysis/GlobalsModRef.h" 18 #include "llvm/Analysis/TargetTransformInfo.h" 19 #include "llvm/IR/Dominators.h" 20 #include "llvm/IR/Function.h" 21 #include "llvm/IR/IRBuilder.h" 22 #include "llvm/IR/PatternMatch.h" 23 #include "llvm/InitializePasses.h" 24 #include "llvm/Pass.h" 25 #include "llvm/Transforms/Vectorize.h" 26 #include "llvm/Transforms/Utils/Local.h" 27 28 using namespace llvm; 29 using namespace llvm::PatternMatch; 30 31 #define DEBUG_TYPE "vector-combine" 32 STATISTIC(NumVecCmp, "Number of vector compares formed"); 33 34 static bool foldExtractCmp(Instruction &I, const TargetTransformInfo &TTI) { 35 // Match a cmp with extracted vector operands. 36 CmpInst::Predicate Pred; 37 Instruction *Ext0, *Ext1; 38 if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1)))) 39 return false; 40 41 Value *V0, *V1; 42 ConstantInt *C; 43 if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C))) || 44 !match(Ext1, m_ExtractElement(m_Value(V1), m_Specific(C))) || 45 V0->getType() != V1->getType()) 46 return false; 47 48 Type *ScalarTy = Ext0->getType(); 49 Type *VecTy = V0->getType(); 50 bool IsFP = ScalarTy->isFloatingPointTy(); 51 unsigned CmpOpcode = IsFP ? Instruction::FCmp : Instruction::ICmp; 52 53 // Check if the existing scalar code or the vector alternative is cheaper. 54 // Extra uses of the extracts mean that we include those costs in the 55 // vector total because those instructions will not be eliminated. 56 // ((2 * extract) + scalar cmp) < (vector cmp + extract) ? 57 int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement, 58 VecTy, C->getZExtValue()); 59 int ScalarCmpCost = TTI.getOperationCost(CmpOpcode, ScalarTy); 60 int VecCmpCost = TTI.getOperationCost(CmpOpcode, VecTy); 61 62 int ScalarCost = 2 * ExtractCost + ScalarCmpCost; 63 int VecCost = VecCmpCost + ExtractCost + 64 !Ext0->hasOneUse() * ExtractCost + 65 !Ext1->hasOneUse() * ExtractCost; 66 if (ScalarCost < VecCost) 67 return false; 68 69 // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C 70 ++NumVecCmp; 71 IRBuilder<> Builder(&I); 72 Value *VecCmp = IsFP ? Builder.CreateFCmp(Pred, V0, V1) 73 : Builder.CreateICmp(Pred, V0, V1); 74 Value *Ext = Builder.CreateExtractElement(VecCmp, C); 75 I.replaceAllUsesWith(Ext); 76 return true; 77 } 78 79 /// This is the entry point for all transforms. Pass manager differences are 80 /// handled in the callers of this function. 81 static bool runImpl(Function &F, const TargetTransformInfo &TTI, 82 const DominatorTree &DT) { 83 bool MadeChange = false; 84 for (BasicBlock &BB : F) { 85 // Ignore unreachable basic blocks. 86 if (!DT.isReachableFromEntry(&BB)) 87 continue; 88 // Do not delete instructions under here and invalidate the iterator. 89 // Walk the block backwards for efficiency. We're matching a chain of 90 // use->defs, so we're more likely to succeed by starting from the bottom. 91 // TODO: It could be more efficient to remove dead instructions 92 // iteratively in this loop rather than waiting until the end. 93 for (Instruction &I : make_range(BB.rbegin(), BB.rend())) { 94 MadeChange |= foldExtractCmp(I, TTI); 95 // TODO: More transforms go here. 96 } 97 } 98 99 // We're done with transforms, so remove dead instructions. 100 if (MadeChange) 101 for (BasicBlock &BB : F) 102 SimplifyInstructionsInBlock(&BB); 103 104 return MadeChange; 105 } 106 107 // Pass manager boilerplate below here. 108 109 namespace { 110 class VectorCombineLegacyPass : public FunctionPass { 111 public: 112 static char ID; 113 VectorCombineLegacyPass() : FunctionPass(ID) { 114 initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry()); 115 } 116 117 void getAnalysisUsage(AnalysisUsage &AU) const override { 118 AU.addRequired<DominatorTreeWrapperPass>(); 119 AU.addRequired<TargetTransformInfoWrapperPass>(); 120 AU.setPreservesCFG(); 121 AU.addPreserved<DominatorTreeWrapperPass>(); 122 AU.addPreserved<GlobalsAAWrapperPass>(); 123 FunctionPass::getAnalysisUsage(AU); 124 } 125 126 bool runOnFunction(Function &F) override { 127 if (skipFunction(F)) 128 return false; 129 auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 130 auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 131 return runImpl(F, TTI, DT); 132 } 133 }; 134 } // namespace 135 136 char VectorCombineLegacyPass::ID = 0; 137 INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine", 138 "Optimize scalar/vector ops", false, 139 false) 140 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 141 INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine", 142 "Optimize scalar/vector ops", false, false) 143 Pass *llvm::createVectorCombinePass() { 144 return new VectorCombineLegacyPass(); 145 } 146 147 PreservedAnalyses VectorCombinePass::run(Function &F, 148 FunctionAnalysisManager &FAM) { 149 TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F); 150 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); 151 if (!runImpl(F, TTI, DT)) 152 return PreservedAnalyses::all(); 153 PreservedAnalyses PA; 154 PA.preserveSet<CFGAnalyses>(); 155 PA.preserve<GlobalsAA>(); 156 return PA; 157 } 158