//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/DebugCounter.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"
33*a17f03bdSSanjay Patel STATISTIC(NumVecCmp, "Number of vector compares formed"); 34*a17f03bdSSanjay Patel DEBUG_COUNTER(VecCombineCounter, "vector-combine-transform", 35*a17f03bdSSanjay Patel "Controls transformations in vector-combine pass"); 36*a17f03bdSSanjay Patel 37*a17f03bdSSanjay Patel static bool foldExtractCmp(Instruction &I, const TargetTransformInfo &TTI) { 38*a17f03bdSSanjay Patel // Match a cmp with extracted vector operands. 39*a17f03bdSSanjay Patel CmpInst::Predicate Pred; 40*a17f03bdSSanjay Patel Instruction *Ext0, *Ext1; 41*a17f03bdSSanjay Patel if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1)))) 42*a17f03bdSSanjay Patel return false; 43*a17f03bdSSanjay Patel 44*a17f03bdSSanjay Patel Value *V0, *V1; 45*a17f03bdSSanjay Patel ConstantInt *C; 46*a17f03bdSSanjay Patel if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C))) || 47*a17f03bdSSanjay Patel !match(Ext1, m_ExtractElement(m_Value(V1), m_Specific(C))) || 48*a17f03bdSSanjay Patel V0->getType() != V1->getType()) 49*a17f03bdSSanjay Patel return false; 50*a17f03bdSSanjay Patel 51*a17f03bdSSanjay Patel Type *ScalarTy = Ext0->getType(); 52*a17f03bdSSanjay Patel Type *VecTy = V0->getType(); 53*a17f03bdSSanjay Patel bool IsFP = ScalarTy->isFloatingPointTy(); 54*a17f03bdSSanjay Patel unsigned CmpOpcode = IsFP ? Instruction::FCmp : Instruction::ICmp; 55*a17f03bdSSanjay Patel 56*a17f03bdSSanjay Patel // Check if the existing scalar code or the vector alternative is cheaper. 57*a17f03bdSSanjay Patel // Extra uses of the extracts mean that we include those costs in the 58*a17f03bdSSanjay Patel // vector total because those instructions will not be eliminated. 59*a17f03bdSSanjay Patel // ((2 * extract) + scalar cmp) < (vector cmp + extract) ? 
60*a17f03bdSSanjay Patel int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement, 61*a17f03bdSSanjay Patel VecTy, C->getZExtValue()); 62*a17f03bdSSanjay Patel int ScalarCmpCost = TTI.getOperationCost(CmpOpcode, ScalarTy); 63*a17f03bdSSanjay Patel int VecCmpCost = TTI.getOperationCost(CmpOpcode, VecTy); 64*a17f03bdSSanjay Patel 65*a17f03bdSSanjay Patel int ScalarCost = 2 * ExtractCost + ScalarCmpCost; 66*a17f03bdSSanjay Patel int VecCost = VecCmpCost + ExtractCost + 67*a17f03bdSSanjay Patel !Ext0->hasOneUse() * ExtractCost + 68*a17f03bdSSanjay Patel !Ext1->hasOneUse() * ExtractCost; 69*a17f03bdSSanjay Patel if (ScalarCost < VecCost) 70*a17f03bdSSanjay Patel return false; 71*a17f03bdSSanjay Patel 72*a17f03bdSSanjay Patel // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C 73*a17f03bdSSanjay Patel ++NumVecCmp; 74*a17f03bdSSanjay Patel IRBuilder<> Builder(&I); 75*a17f03bdSSanjay Patel Value *VecCmp = IsFP ? Builder.CreateFCmp(Pred, V0, V1) 76*a17f03bdSSanjay Patel : Builder.CreateICmp(Pred, V0, V1); 77*a17f03bdSSanjay Patel Value *Ext = Builder.CreateExtractElement(VecCmp, C); 78*a17f03bdSSanjay Patel I.replaceAllUsesWith(Ext); 79*a17f03bdSSanjay Patel return true; 80*a17f03bdSSanjay Patel } 81*a17f03bdSSanjay Patel 82*a17f03bdSSanjay Patel /// This is the entry point for all transforms. Pass manager differences are 83*a17f03bdSSanjay Patel /// handled in the callers of this function. 84*a17f03bdSSanjay Patel static bool runImpl(Function &F, const TargetTransformInfo &TTI, 85*a17f03bdSSanjay Patel const DominatorTree &DT) { 86*a17f03bdSSanjay Patel bool MadeChange = false; 87*a17f03bdSSanjay Patel for (BasicBlock &BB : F) { 88*a17f03bdSSanjay Patel // Ignore unreachable basic blocks. 89*a17f03bdSSanjay Patel if (!DT.isReachableFromEntry(&BB)) 90*a17f03bdSSanjay Patel continue; 91*a17f03bdSSanjay Patel // Do not delete instructions under here and invalidate the iterator. 
92*a17f03bdSSanjay Patel // Walk the block backwards for efficiency. We're matching a chain of 93*a17f03bdSSanjay Patel // use->defs, so we're more likely to succeed by starting from the bottom. 94*a17f03bdSSanjay Patel // TODO: It could be more efficient to remove dead instructions 95*a17f03bdSSanjay Patel // iteratively in this loop rather than waiting until the end. 96*a17f03bdSSanjay Patel for (Instruction &I : make_range(BB.rbegin(), BB.rend())) { 97*a17f03bdSSanjay Patel MadeChange |= foldExtractCmp(I, TTI); 98*a17f03bdSSanjay Patel // TODO: More transforms go here. 99*a17f03bdSSanjay Patel } 100*a17f03bdSSanjay Patel } 101*a17f03bdSSanjay Patel 102*a17f03bdSSanjay Patel // We're done with transforms, so remove dead instructions. 103*a17f03bdSSanjay Patel if (MadeChange) 104*a17f03bdSSanjay Patel for (BasicBlock &BB : F) 105*a17f03bdSSanjay Patel SimplifyInstructionsInBlock(&BB); 106*a17f03bdSSanjay Patel 107*a17f03bdSSanjay Patel return MadeChange; 108*a17f03bdSSanjay Patel } 109*a17f03bdSSanjay Patel 110*a17f03bdSSanjay Patel // Pass manager boilerplate below here. 
111*a17f03bdSSanjay Patel 112*a17f03bdSSanjay Patel namespace { 113*a17f03bdSSanjay Patel class VectorCombineLegacyPass : public FunctionPass { 114*a17f03bdSSanjay Patel public: 115*a17f03bdSSanjay Patel static char ID; 116*a17f03bdSSanjay Patel VectorCombineLegacyPass() : FunctionPass(ID) { 117*a17f03bdSSanjay Patel initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry()); 118*a17f03bdSSanjay Patel } 119*a17f03bdSSanjay Patel 120*a17f03bdSSanjay Patel void getAnalysisUsage(AnalysisUsage &AU) const override { 121*a17f03bdSSanjay Patel AU.addRequired<DominatorTreeWrapperPass>(); 122*a17f03bdSSanjay Patel AU.addRequired<TargetTransformInfoWrapperPass>(); 123*a17f03bdSSanjay Patel AU.setPreservesCFG(); 124*a17f03bdSSanjay Patel AU.addPreserved<DominatorTreeWrapperPass>(); 125*a17f03bdSSanjay Patel AU.addPreserved<GlobalsAAWrapperPass>(); 126*a17f03bdSSanjay Patel FunctionPass::getAnalysisUsage(AU); 127*a17f03bdSSanjay Patel } 128*a17f03bdSSanjay Patel 129*a17f03bdSSanjay Patel bool runOnFunction(Function &F) override { 130*a17f03bdSSanjay Patel if (skipFunction(F)) 131*a17f03bdSSanjay Patel return false; 132*a17f03bdSSanjay Patel auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 133*a17f03bdSSanjay Patel auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 134*a17f03bdSSanjay Patel return runImpl(F, TTI, DT); 135*a17f03bdSSanjay Patel } 136*a17f03bdSSanjay Patel }; 137*a17f03bdSSanjay Patel } // namespace 138*a17f03bdSSanjay Patel 139*a17f03bdSSanjay Patel char VectorCombineLegacyPass::ID = 0; 140*a17f03bdSSanjay Patel INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine", 141*a17f03bdSSanjay Patel "Optimize scalar/vector ops", false, 142*a17f03bdSSanjay Patel false) 143*a17f03bdSSanjay Patel INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 144*a17f03bdSSanjay Patel INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine", 145*a17f03bdSSanjay Patel "Optimize scalar/vector ops", false, false) 
146*a17f03bdSSanjay Patel Pass *llvm::createVectorCombinePass() { 147*a17f03bdSSanjay Patel return new VectorCombineLegacyPass(); 148*a17f03bdSSanjay Patel } 149*a17f03bdSSanjay Patel 150*a17f03bdSSanjay Patel PreservedAnalyses VectorCombinePass::run(Function &F, 151*a17f03bdSSanjay Patel FunctionAnalysisManager &FAM) { 152*a17f03bdSSanjay Patel TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F); 153*a17f03bdSSanjay Patel DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); 154*a17f03bdSSanjay Patel if (!runImpl(F, TTI, DT)) 155*a17f03bdSSanjay Patel return PreservedAnalyses::all(); 156*a17f03bdSSanjay Patel PreservedAnalyses PA; 157*a17f03bdSSanjay Patel PA.preserveSet<CFGAnalyses>(); 158*a17f03bdSSanjay Patel PA.preserve<GlobalsAA>(); 159*a17f03bdSSanjay Patel return PA; 160*a17f03bdSSanjay Patel } 161