//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"

STATISTIC(NumVecCmp, "Number of vector compares formed"); 3419b62b79SSanjay Patel STATISTIC(NumVecBO, "Number of vector binops formed"); 35a17f03bdSSanjay Patel 36*34e34855SSanjay Patel /// Compare the relative costs of extracts followed by scalar operation vs. 37*34e34855SSanjay Patel /// vector operation followed by extract: 38*34e34855SSanjay Patel /// opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C 39*34e34855SSanjay Patel /// Unless the vector op is much more expensive than the scalar op, this 40*34e34855SSanjay Patel /// eliminates an extract. 41*34e34855SSanjay Patel static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1, 42*34e34855SSanjay Patel unsigned Opcode, 43*34e34855SSanjay Patel const TargetTransformInfo &TTI) { 44*34e34855SSanjay Patel assert(Ext0->getOperand(1) == Ext1->getOperand(1) && 45*34e34855SSanjay Patel isa<ConstantInt>(Ext0->getOperand(1)) && 46*34e34855SSanjay Patel "Expected same constant extract index"); 47*34e34855SSanjay Patel 48*34e34855SSanjay Patel Type *ScalarTy = Ext0->getType(); 49*34e34855SSanjay Patel Type *VecTy = Ext0->getOperand(0)->getType(); 50*34e34855SSanjay Patel int ScalarOpCost, VectorOpCost; 51*34e34855SSanjay Patel 52*34e34855SSanjay Patel // Get cost estimates for scalar and vector versions of the operation. 
53*34e34855SSanjay Patel bool IsBinOp = Instruction::isBinaryOp(Opcode); 54*34e34855SSanjay Patel if (IsBinOp) { 55*34e34855SSanjay Patel ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy); 56*34e34855SSanjay Patel VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy); 57*34e34855SSanjay Patel } else { 58*34e34855SSanjay Patel assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) && 59*34e34855SSanjay Patel "Expected a compare"); 60*34e34855SSanjay Patel ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy, 61*34e34855SSanjay Patel CmpInst::makeCmpResultType(ScalarTy)); 62*34e34855SSanjay Patel VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy, 63*34e34855SSanjay Patel CmpInst::makeCmpResultType(VecTy)); 64*34e34855SSanjay Patel } 65*34e34855SSanjay Patel 66*34e34855SSanjay Patel // Get cost estimate for the extract element. This cost will factor into 67*34e34855SSanjay Patel // both sequences. 68*34e34855SSanjay Patel unsigned ExtIndex = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue(); 69*34e34855SSanjay Patel int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement, 70*34e34855SSanjay Patel VecTy, ExtIndex); 71*34e34855SSanjay Patel 72*34e34855SSanjay Patel // Extra uses of the extracts mean that we include those costs in the 73*34e34855SSanjay Patel // vector total because those instructions will not be eliminated. 74*34e34855SSanjay Patel int ScalarCost, VectorCost; 75*34e34855SSanjay Patel if (Ext0->getOperand(0) == Ext1->getOperand(0)) { 76*34e34855SSanjay Patel // Handle a special case. If the 2 operands are identical, adjust the 77*34e34855SSanjay Patel // formulas to account for that. The extra use charge allows for either the 78*34e34855SSanjay Patel // CSE'd pattern or an unoptimized form with identical values: 79*34e34855SSanjay Patel // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C 80*34e34855SSanjay Patel bool HasUseTax = Ext0 == Ext1 ? 
!Ext0->hasNUses(2) 81*34e34855SSanjay Patel : !Ext0->hasOneUse() || !Ext1->hasOneUse(); 82*34e34855SSanjay Patel ScalarCost = ExtractCost + ScalarOpCost; 83*34e34855SSanjay Patel VectorCost = VectorOpCost + ExtractCost + HasUseTax * ExtractCost; 84*34e34855SSanjay Patel } else { 85*34e34855SSanjay Patel // Handle the general case. Each extract is actually a different value: 86*34e34855SSanjay Patel // opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C 87*34e34855SSanjay Patel ScalarCost = 2 * ExtractCost + ScalarOpCost; 88*34e34855SSanjay Patel VectorCost = VectorOpCost + ExtractCost + 89*34e34855SSanjay Patel !Ext0->hasOneUse() * ExtractCost + 90*34e34855SSanjay Patel !Ext1->hasOneUse() * ExtractCost; 91*34e34855SSanjay Patel } 92*34e34855SSanjay Patel // TODO: The cost comparison should not differ based on opcode. Either we 93*34e34855SSanjay Patel // want to be uniformly more or less aggressive in deciding if a vector 94*34e34855SSanjay Patel // operation should replace the scalar operation. 95*34e34855SSanjay Patel return IsBinOp ? ScalarCost <= VectorCost : ScalarCost < VectorCost; 96*34e34855SSanjay Patel } 97*34e34855SSanjay Patel 98fc445589SSanjay Patel /// Try to reduce extract element costs by converting scalar compares to vector 99fc445589SSanjay Patel /// compares followed by extract. 100fc445589SSanjay Patel /// cmp (ext0 V0, C0), (ext1 V1, C1) 101fc445589SSanjay Patel static bool foldExtExtCmp(Instruction *Ext0, Value *V0, uint64_t C0, 102fc445589SSanjay Patel Instruction *Ext1, Value *V1, uint64_t C1, 103fc445589SSanjay Patel Instruction &I, const TargetTransformInfo &TTI) { 104fc445589SSanjay Patel assert(isa<CmpInst>(&I) && "Expected a compare"); 105a17f03bdSSanjay Patel 106fc445589SSanjay Patel // TODO: Handle C0 != C1 by shuffling 1 of the operands. 
107fc445589SSanjay Patel if (C0 != C1) 108fc445589SSanjay Patel return false; 109fc445589SSanjay Patel 110*34e34855SSanjay Patel if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI)) 111a17f03bdSSanjay Patel return false; 112a17f03bdSSanjay Patel 113a17f03bdSSanjay Patel // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C 114a17f03bdSSanjay Patel ++NumVecCmp; 115a17f03bdSSanjay Patel IRBuilder<> Builder(&I); 116fc445589SSanjay Patel CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate(); 117*34e34855SSanjay Patel Value *VecCmp = 118*34e34855SSanjay Patel Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1) 119a17f03bdSSanjay Patel : Builder.CreateICmp(Pred, V0, V1); 120fc445589SSanjay Patel Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1)); 121fc445589SSanjay Patel I.replaceAllUsesWith(Extract); 122a17f03bdSSanjay Patel return true; 123a17f03bdSSanjay Patel } 124a17f03bdSSanjay Patel 12519b62b79SSanjay Patel /// Try to reduce extract element costs by converting scalar binops to vector 12619b62b79SSanjay Patel /// binops followed by extract. 127fc445589SSanjay Patel /// bo (ext0 V0, C0), (ext1 V1, C1) 128fc445589SSanjay Patel static bool foldExtExtBinop(Instruction *Ext0, Value *V0, uint64_t C0, 129fc445589SSanjay Patel Instruction *Ext1, Value *V1, uint64_t C1, 130fc445589SSanjay Patel Instruction &I, const TargetTransformInfo &TTI) { 131fc445589SSanjay Patel assert(isa<BinaryOperator>(&I) && "Expected a binary operator"); 13219b62b79SSanjay Patel 133*34e34855SSanjay Patel // TODO: Handle C0 != C1 by shuffling 1 of the operands. 
134*34e34855SSanjay Patel if (C0 != C1) 13519b62b79SSanjay Patel return false; 13619b62b79SSanjay Patel 137*34e34855SSanjay Patel if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI)) 138*34e34855SSanjay Patel return false; 139*34e34855SSanjay Patel 140*34e34855SSanjay Patel // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C 14119b62b79SSanjay Patel ++NumVecBO; 14219b62b79SSanjay Patel IRBuilder<> Builder(&I); 143*34e34855SSanjay Patel Value *NewBO = 144*34e34855SSanjay Patel Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1); 14519b62b79SSanjay Patel if (auto *VecBOInst = dyn_cast<Instruction>(NewBO)) { 14619b62b79SSanjay Patel // All IR flags are safe to back-propagate because any potential poison 14719b62b79SSanjay Patel // created in unused vector elements is discarded by the extract. 14819b62b79SSanjay Patel VecBOInst->copyIRFlags(&I); 14919b62b79SSanjay Patel } 15019b62b79SSanjay Patel Value *Extract = Builder.CreateExtractElement(NewBO, Ext0->getOperand(1)); 15119b62b79SSanjay Patel I.replaceAllUsesWith(Extract); 15219b62b79SSanjay Patel return true; 15319b62b79SSanjay Patel } 15419b62b79SSanjay Patel 155fc445589SSanjay Patel /// Match an instruction with extracted vector operands. 
156fc445589SSanjay Patel static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) { 157fc445589SSanjay Patel Instruction *Ext0, *Ext1; 158fc445589SSanjay Patel CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE; 159fc445589SSanjay Patel if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) && 160fc445589SSanjay Patel !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1)))) 161fc445589SSanjay Patel return false; 162fc445589SSanjay Patel 163fc445589SSanjay Patel Value *V0, *V1; 164fc445589SSanjay Patel uint64_t C0, C1; 165fc445589SSanjay Patel if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) || 166fc445589SSanjay Patel !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) || 167fc445589SSanjay Patel V0->getType() != V1->getType()) 168fc445589SSanjay Patel return false; 169fc445589SSanjay Patel 170fc445589SSanjay Patel if (Pred != CmpInst::BAD_ICMP_PREDICATE) 171fc445589SSanjay Patel return foldExtExtCmp(Ext0, V0, C0, Ext1, V1, C1, I, TTI); 172fc445589SSanjay Patel 173fc445589SSanjay Patel // It is not safe to transform things like div, urem, etc. because we may 174fc445589SSanjay Patel // create undefined behavior when executing those on unknown vector elements. 175fc445589SSanjay Patel if (isSafeToSpeculativelyExecute(&I)) 176fc445589SSanjay Patel return foldExtExtBinop(Ext0, V0, C0, Ext1, V1, C1, I, TTI); 177fc445589SSanjay Patel 178fc445589SSanjay Patel return false; 179fc445589SSanjay Patel } 180fc445589SSanjay Patel 181a17f03bdSSanjay Patel /// This is the entry point for all transforms. Pass manager differences are 182a17f03bdSSanjay Patel /// handled in the callers of this function. 183a17f03bdSSanjay Patel static bool runImpl(Function &F, const TargetTransformInfo &TTI, 184a17f03bdSSanjay Patel const DominatorTree &DT) { 185a17f03bdSSanjay Patel bool MadeChange = false; 186a17f03bdSSanjay Patel for (BasicBlock &BB : F) { 187a17f03bdSSanjay Patel // Ignore unreachable basic blocks. 
188a17f03bdSSanjay Patel if (!DT.isReachableFromEntry(&BB)) 189a17f03bdSSanjay Patel continue; 190a17f03bdSSanjay Patel // Do not delete instructions under here and invalidate the iterator. 191a17f03bdSSanjay Patel // Walk the block backwards for efficiency. We're matching a chain of 192a17f03bdSSanjay Patel // use->defs, so we're more likely to succeed by starting from the bottom. 193a17f03bdSSanjay Patel // TODO: It could be more efficient to remove dead instructions 194a17f03bdSSanjay Patel // iteratively in this loop rather than waiting until the end. 195fc445589SSanjay Patel for (Instruction &I : make_range(BB.rbegin(), BB.rend())) 196fc445589SSanjay Patel MadeChange |= foldExtractExtract(I, TTI); 197a17f03bdSSanjay Patel } 198a17f03bdSSanjay Patel 199a17f03bdSSanjay Patel // We're done with transforms, so remove dead instructions. 200a17f03bdSSanjay Patel if (MadeChange) 201a17f03bdSSanjay Patel for (BasicBlock &BB : F) 202a17f03bdSSanjay Patel SimplifyInstructionsInBlock(&BB); 203a17f03bdSSanjay Patel 204a17f03bdSSanjay Patel return MadeChange; 205a17f03bdSSanjay Patel } 206a17f03bdSSanjay Patel 207a17f03bdSSanjay Patel // Pass manager boilerplate below here. 
208a17f03bdSSanjay Patel 209a17f03bdSSanjay Patel namespace { 210a17f03bdSSanjay Patel class VectorCombineLegacyPass : public FunctionPass { 211a17f03bdSSanjay Patel public: 212a17f03bdSSanjay Patel static char ID; 213a17f03bdSSanjay Patel VectorCombineLegacyPass() : FunctionPass(ID) { 214a17f03bdSSanjay Patel initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry()); 215a17f03bdSSanjay Patel } 216a17f03bdSSanjay Patel 217a17f03bdSSanjay Patel void getAnalysisUsage(AnalysisUsage &AU) const override { 218a17f03bdSSanjay Patel AU.addRequired<DominatorTreeWrapperPass>(); 219a17f03bdSSanjay Patel AU.addRequired<TargetTransformInfoWrapperPass>(); 220a17f03bdSSanjay Patel AU.setPreservesCFG(); 221a17f03bdSSanjay Patel AU.addPreserved<DominatorTreeWrapperPass>(); 222a17f03bdSSanjay Patel AU.addPreserved<GlobalsAAWrapperPass>(); 223a17f03bdSSanjay Patel FunctionPass::getAnalysisUsage(AU); 224a17f03bdSSanjay Patel } 225a17f03bdSSanjay Patel 226a17f03bdSSanjay Patel bool runOnFunction(Function &F) override { 227a17f03bdSSanjay Patel if (skipFunction(F)) 228a17f03bdSSanjay Patel return false; 229a17f03bdSSanjay Patel auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 230a17f03bdSSanjay Patel auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 231a17f03bdSSanjay Patel return runImpl(F, TTI, DT); 232a17f03bdSSanjay Patel } 233a17f03bdSSanjay Patel }; 234a17f03bdSSanjay Patel } // namespace 235a17f03bdSSanjay Patel 236a17f03bdSSanjay Patel char VectorCombineLegacyPass::ID = 0; 237a17f03bdSSanjay Patel INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine", 238a17f03bdSSanjay Patel "Optimize scalar/vector ops", false, 239a17f03bdSSanjay Patel false) 240a17f03bdSSanjay Patel INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 241a17f03bdSSanjay Patel INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine", 242a17f03bdSSanjay Patel "Optimize scalar/vector ops", false, false) 243a17f03bdSSanjay Patel Pass 
*llvm::createVectorCombinePass() { 244a17f03bdSSanjay Patel return new VectorCombineLegacyPass(); 245a17f03bdSSanjay Patel } 246a17f03bdSSanjay Patel 247a17f03bdSSanjay Patel PreservedAnalyses VectorCombinePass::run(Function &F, 248a17f03bdSSanjay Patel FunctionAnalysisManager &FAM) { 249a17f03bdSSanjay Patel TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F); 250a17f03bdSSanjay Patel DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); 251a17f03bdSSanjay Patel if (!runImpl(F, TTI, DT)) 252a17f03bdSSanjay Patel return PreservedAnalyses::all(); 253a17f03bdSSanjay Patel PreservedAnalyses PA; 254a17f03bdSSanjay Patel PA.preserveSet<CFGAnalyses>(); 255a17f03bdSSanjay Patel PA.preserve<GlobalsAA>(); 256a17f03bdSSanjay Patel return PA; 257a17f03bdSSanjay Patel } 258