1 //===------- VectorCombine.cpp - Optimize partial vector operations -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass optimizes scalar/vector interactions using target cost models. The 10 // transforms implemented here may not fit in traditional loop-based or SLP 11 // vectorization passes. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "llvm/Transforms/Vectorize/VectorCombine.h" 16 #include "llvm/ADT/Statistic.h" 17 #include "llvm/Analysis/GlobalsModRef.h" 18 #include "llvm/Analysis/TargetTransformInfo.h" 19 #include "llvm/IR/Dominators.h" 20 #include "llvm/IR/Function.h" 21 #include "llvm/IR/IRBuilder.h" 22 #include "llvm/IR/PatternMatch.h" 23 #include "llvm/InitializePasses.h" 24 #include "llvm/Pass.h" 25 #include "llvm/Support/DebugCounter.h" 26 #include "llvm/Transforms/Vectorize.h" 27 #include "llvm/Transforms/Utils/Local.h" 28 29 using namespace llvm; 30 using namespace llvm::PatternMatch; 31 32 #define DEBUG_TYPE "vector-combine" 33 STATISTIC(NumVecCmp, "Number of vector compares formed"); 34 DEBUG_COUNTER(VecCombineCounter, "vector-combine-transform", 35 "Controls transformations in vector-combine pass"); 36 37 static bool foldExtractCmp(Instruction &I, const TargetTransformInfo &TTI) { 38 // Match a cmp with extracted vector operands. 39 CmpInst::Predicate Pred; 40 Instruction *Ext0, *Ext1; 41 if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1)))) 42 return false; 43 44 Value *V0, *V1; 45 ConstantInt *C; 46 if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C))) || 47 !match(Ext1, m_ExtractElement(m_Value(V1), m_Specific(C))) || 48 V0->getType() != V1->getType()) 49 return false; 50 51 Type *ScalarTy = Ext0->getType(); 52 Type *VecTy = V0->getType(); 53 bool IsFP = ScalarTy->isFloatingPointTy(); 54 unsigned CmpOpcode = IsFP ? Instruction::FCmp : Instruction::ICmp; 55 56 // Check if the existing scalar code or the vector alternative is cheaper. 57 // Extra uses of the extracts mean that we include those costs in the 58 // vector total because those instructions will not be eliminated. 59 // ((2 * extract) + scalar cmp) < (vector cmp + extract) ? 60 int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement, 61 VecTy, C->getZExtValue()); 62 int ScalarCmpCost = TTI.getOperationCost(CmpOpcode, ScalarTy); 63 int VecCmpCost = TTI.getOperationCost(CmpOpcode, VecTy); 64 65 int ScalarCost = 2 * ExtractCost + ScalarCmpCost; 66 int VecCost = VecCmpCost + ExtractCost + 67 !Ext0->hasOneUse() * ExtractCost + 68 !Ext1->hasOneUse() * ExtractCost; 69 if (ScalarCost < VecCost) 70 return false; 71 72 // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C 73 ++NumVecCmp; 74 IRBuilder<> Builder(&I); 75 Value *VecCmp = IsFP ? Builder.CreateFCmp(Pred, V0, V1) 76 : Builder.CreateICmp(Pred, V0, V1); 77 Value *Ext = Builder.CreateExtractElement(VecCmp, C); 78 I.replaceAllUsesWith(Ext); 79 return true; 80 } 81 82 /// This is the entry point for all transforms. Pass manager differences are 83 /// handled in the callers of this function. 84 static bool runImpl(Function &F, const TargetTransformInfo &TTI, 85 const DominatorTree &DT) { 86 bool MadeChange = false; 87 for (BasicBlock &BB : F) { 88 // Ignore unreachable basic blocks. 89 if (!DT.isReachableFromEntry(&BB)) 90 continue; 91 // Do not delete instructions under here and invalidate the iterator. 92 // Walk the block backwards for efficiency. We're matching a chain of 93 // use->defs, so we're more likely to succeed by starting from the bottom. 94 // TODO: It could be more efficient to remove dead instructions 95 // iteratively in this loop rather than waiting until the end. 96 for (Instruction &I : make_range(BB.rbegin(), BB.rend())) { 97 MadeChange |= foldExtractCmp(I, TTI); 98 // TODO: More transforms go here. 99 } 100 } 101 102 // We're done with transforms, so remove dead instructions. 103 if (MadeChange) 104 for (BasicBlock &BB : F) 105 SimplifyInstructionsInBlock(&BB); 106 107 return MadeChange; 108 } 109 110 // Pass manager boilerplate below here. 111 112 namespace { 113 class VectorCombineLegacyPass : public FunctionPass { 114 public: 115 static char ID; 116 VectorCombineLegacyPass() : FunctionPass(ID) { 117 initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry()); 118 } 119 120 void getAnalysisUsage(AnalysisUsage &AU) const override { 121 AU.addRequired<DominatorTreeWrapperPass>(); 122 AU.addRequired<TargetTransformInfoWrapperPass>(); 123 AU.setPreservesCFG(); 124 AU.addPreserved<DominatorTreeWrapperPass>(); 125 AU.addPreserved<GlobalsAAWrapperPass>(); 126 FunctionPass::getAnalysisUsage(AU); 127 } 128 129 bool runOnFunction(Function &F) override { 130 if (skipFunction(F)) 131 return false; 132 auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 133 auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 134 return runImpl(F, TTI, DT); 135 } 136 }; 137 } // namespace 138 139 char VectorCombineLegacyPass::ID = 0; 140 INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine", 141 "Optimize scalar/vector ops", false, 142 false) 143 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 144 INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine", 145 "Optimize scalar/vector ops", false, false) 146 Pass *llvm::createVectorCombinePass() { 147 return new VectorCombineLegacyPass(); 148 } 149 150 PreservedAnalyses VectorCombinePass::run(Function &F, 151 FunctionAnalysisManager &FAM) { 152 TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F); 153 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); 154 if (!runImpl(F, TTI, DT)) 155 return PreservedAnalyses::all(); 156 PreservedAnalyses PA; 157 PA.preserveSet<CFGAnalyses>(); 158 PA.preserve<GlobalsAA>(); 159 return PA; 160 } 161