1a17f03bdSSanjay Patel //===------- VectorCombine.cpp - Optimize partial vector operations -------===// 2a17f03bdSSanjay Patel // 3a17f03bdSSanjay Patel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a17f03bdSSanjay Patel // See https://llvm.org/LICENSE.txt for license information. 5a17f03bdSSanjay Patel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a17f03bdSSanjay Patel // 7a17f03bdSSanjay Patel //===----------------------------------------------------------------------===// 8a17f03bdSSanjay Patel // 9a17f03bdSSanjay Patel // This pass optimizes scalar/vector interactions using target cost models. The 10a17f03bdSSanjay Patel // transforms implemented here may not fit in traditional loop-based or SLP 11a17f03bdSSanjay Patel // vectorization passes. 12a17f03bdSSanjay Patel // 13a17f03bdSSanjay Patel //===----------------------------------------------------------------------===// 14a17f03bdSSanjay Patel 15a17f03bdSSanjay Patel #include "llvm/Transforms/Vectorize/VectorCombine.h" 16a17f03bdSSanjay Patel #include "llvm/ADT/Statistic.h" 17a17f03bdSSanjay Patel #include "llvm/Analysis/GlobalsModRef.h" 18a17f03bdSSanjay Patel #include "llvm/Analysis/TargetTransformInfo.h" 1919b62b79SSanjay Patel #include "llvm/Analysis/ValueTracking.h" 20a17f03bdSSanjay Patel #include "llvm/IR/Dominators.h" 21a17f03bdSSanjay Patel #include "llvm/IR/Function.h" 22a17f03bdSSanjay Patel #include "llvm/IR/IRBuilder.h" 23a17f03bdSSanjay Patel #include "llvm/IR/PatternMatch.h" 24a17f03bdSSanjay Patel #include "llvm/InitializePasses.h" 25a17f03bdSSanjay Patel #include "llvm/Pass.h" 26a17f03bdSSanjay Patel #include "llvm/Transforms/Vectorize.h" 27a17f03bdSSanjay Patel #include "llvm/Transforms/Utils/Local.h" 28a17f03bdSSanjay Patel 29a17f03bdSSanjay Patel using namespace llvm; 30a17f03bdSSanjay Patel using namespace llvm::PatternMatch; 31a17f03bdSSanjay Patel 32a17f03bdSSanjay Patel #define DEBUG_TYPE "vector-combine" 33a17f03bdSSanjay Patel 
STATISTIC(NumVecCmp, "Number of vector compares formed"); 3419b62b79SSanjay Patel STATISTIC(NumVecBO, "Number of vector binops formed"); 35a17f03bdSSanjay Patel 3634e34855SSanjay Patel /// Compare the relative costs of extracts followed by scalar operation vs. 3734e34855SSanjay Patel /// vector operation followed by extract: 3834e34855SSanjay Patel /// opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C 3934e34855SSanjay Patel /// Unless the vector op is much more expensive than the scalar op, this 4034e34855SSanjay Patel /// eliminates an extract. 4134e34855SSanjay Patel static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1, 4234e34855SSanjay Patel unsigned Opcode, 4334e34855SSanjay Patel const TargetTransformInfo &TTI) { 4434e34855SSanjay Patel assert(Ext0->getOperand(1) == Ext1->getOperand(1) && 4534e34855SSanjay Patel isa<ConstantInt>(Ext0->getOperand(1)) && 4634e34855SSanjay Patel "Expected same constant extract index"); 4734e34855SSanjay Patel 4834e34855SSanjay Patel Type *ScalarTy = Ext0->getType(); 4934e34855SSanjay Patel Type *VecTy = Ext0->getOperand(0)->getType(); 5034e34855SSanjay Patel int ScalarOpCost, VectorOpCost; 5134e34855SSanjay Patel 5234e34855SSanjay Patel // Get cost estimates for scalar and vector versions of the operation. 
5334e34855SSanjay Patel bool IsBinOp = Instruction::isBinaryOp(Opcode); 5434e34855SSanjay Patel if (IsBinOp) { 5534e34855SSanjay Patel ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy); 5634e34855SSanjay Patel VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy); 5734e34855SSanjay Patel } else { 5834e34855SSanjay Patel assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) && 5934e34855SSanjay Patel "Expected a compare"); 6034e34855SSanjay Patel ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy, 6134e34855SSanjay Patel CmpInst::makeCmpResultType(ScalarTy)); 6234e34855SSanjay Patel VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy, 6334e34855SSanjay Patel CmpInst::makeCmpResultType(VecTy)); 6434e34855SSanjay Patel } 6534e34855SSanjay Patel 6634e34855SSanjay Patel // Get cost estimate for the extract element. This cost will factor into 6734e34855SSanjay Patel // both sequences. 6834e34855SSanjay Patel unsigned ExtIndex = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue(); 6934e34855SSanjay Patel int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement, 7034e34855SSanjay Patel VecTy, ExtIndex); 7134e34855SSanjay Patel 7234e34855SSanjay Patel // Extra uses of the extracts mean that we include those costs in the 7334e34855SSanjay Patel // vector total because those instructions will not be eliminated. 74e9c79a7aSSanjay Patel int OldCost, NewCost; 7534e34855SSanjay Patel if (Ext0->getOperand(0) == Ext1->getOperand(0)) { 7634e34855SSanjay Patel // Handle a special case. If the 2 operands are identical, adjust the 7734e34855SSanjay Patel // formulas to account for that. The extra use charge allows for either the 7834e34855SSanjay Patel // CSE'd pattern or an unoptimized form with identical values: 7934e34855SSanjay Patel // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C 8034e34855SSanjay Patel bool HasUseTax = Ext0 == Ext1 ? 
!Ext0->hasNUses(2) 8134e34855SSanjay Patel : !Ext0->hasOneUse() || !Ext1->hasOneUse(); 82e9c79a7aSSanjay Patel OldCost = ExtractCost + ScalarOpCost; 83e9c79a7aSSanjay Patel NewCost = VectorOpCost + ExtractCost + HasUseTax * ExtractCost; 8434e34855SSanjay Patel } else { 8534e34855SSanjay Patel // Handle the general case. Each extract is actually a different value: 8634e34855SSanjay Patel // opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C 87e9c79a7aSSanjay Patel OldCost = 2 * ExtractCost + ScalarOpCost; 88e9c79a7aSSanjay Patel NewCost = VectorOpCost + ExtractCost + !Ext0->hasOneUse() * ExtractCost + 8934e34855SSanjay Patel !Ext1->hasOneUse() * ExtractCost; 9034e34855SSanjay Patel } 91*10ea01d8SSanjay Patel // Aggressively form a vector op if the cost is equal because the transform 92*10ea01d8SSanjay Patel // may enable further optimization. 93*10ea01d8SSanjay Patel // Codegen can reverse this transform (scalarize) if it was not profitable. 94*10ea01d8SSanjay Patel return OldCost < NewCost; 9534e34855SSanjay Patel } 9634e34855SSanjay Patel 97fc445589SSanjay Patel /// Try to reduce extract element costs by converting scalar compares to vector 98fc445589SSanjay Patel /// compares followed by extract. 
99e9c79a7aSSanjay Patel /// cmp (ext0 V0, C), (ext1 V1, C) 100e9c79a7aSSanjay Patel static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1, 101fc445589SSanjay Patel Instruction &I, const TargetTransformInfo &TTI) { 102fc445589SSanjay Patel assert(isa<CmpInst>(&I) && "Expected a compare"); 103a17f03bdSSanjay Patel 104a17f03bdSSanjay Patel // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C 105a17f03bdSSanjay Patel ++NumVecCmp; 106a17f03bdSSanjay Patel IRBuilder<> Builder(&I); 107fc445589SSanjay Patel CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate(); 108e9c79a7aSSanjay Patel Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0); 10934e34855SSanjay Patel Value *VecCmp = 11034e34855SSanjay Patel Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1) 111a17f03bdSSanjay Patel : Builder.CreateICmp(Pred, V0, V1); 112fc445589SSanjay Patel Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1)); 113fc445589SSanjay Patel I.replaceAllUsesWith(Extract); 114a17f03bdSSanjay Patel } 115a17f03bdSSanjay Patel 11619b62b79SSanjay Patel /// Try to reduce extract element costs by converting scalar binops to vector 11719b62b79SSanjay Patel /// binops followed by extract. 
118e9c79a7aSSanjay Patel /// bo (ext0 V0, C), (ext1 V1, C) 119e9c79a7aSSanjay Patel static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1, 120fc445589SSanjay Patel Instruction &I, const TargetTransformInfo &TTI) { 121fc445589SSanjay Patel assert(isa<BinaryOperator>(&I) && "Expected a binary operator"); 12219b62b79SSanjay Patel 12334e34855SSanjay Patel // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C 12419b62b79SSanjay Patel ++NumVecBO; 12519b62b79SSanjay Patel IRBuilder<> Builder(&I); 126e9c79a7aSSanjay Patel Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0); 127e9c79a7aSSanjay Patel Value *VecBO = 12834e34855SSanjay Patel Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1); 129e9c79a7aSSanjay Patel 13019b62b79SSanjay Patel // All IR flags are safe to back-propagate because any potential poison 13119b62b79SSanjay Patel // created in unused vector elements is discarded by the extract. 132e9c79a7aSSanjay Patel if (auto *VecBOInst = dyn_cast<Instruction>(VecBO)) 13319b62b79SSanjay Patel VecBOInst->copyIRFlags(&I); 134e9c79a7aSSanjay Patel 135e9c79a7aSSanjay Patel Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1)); 13619b62b79SSanjay Patel I.replaceAllUsesWith(Extract); 13719b62b79SSanjay Patel } 13819b62b79SSanjay Patel 139fc445589SSanjay Patel /// Match an instruction with extracted vector operands. 140fc445589SSanjay Patel static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) { 141e9c79a7aSSanjay Patel // It is not safe to transform things like div, urem, etc. because we may 142e9c79a7aSSanjay Patel // create undefined behavior when executing those on unknown vector elements. 
143e9c79a7aSSanjay Patel if (!isSafeToSpeculativelyExecute(&I)) 144e9c79a7aSSanjay Patel return false; 145e9c79a7aSSanjay Patel 146fc445589SSanjay Patel Instruction *Ext0, *Ext1; 147fc445589SSanjay Patel CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE; 148fc445589SSanjay Patel if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) && 149fc445589SSanjay Patel !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1)))) 150fc445589SSanjay Patel return false; 151fc445589SSanjay Patel 152fc445589SSanjay Patel Value *V0, *V1; 153fc445589SSanjay Patel uint64_t C0, C1; 154fc445589SSanjay Patel if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) || 155fc445589SSanjay Patel !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) || 156fc445589SSanjay Patel V0->getType() != V1->getType()) 157fc445589SSanjay Patel return false; 158fc445589SSanjay Patel 159e9c79a7aSSanjay Patel // TODO: Handle C0 != C1 by shuffling 1 of the operands. 160e9c79a7aSSanjay Patel if (C0 != C1) 161fc445589SSanjay Patel return false; 162e9c79a7aSSanjay Patel 163e9c79a7aSSanjay Patel if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI)) 164e9c79a7aSSanjay Patel return false; 165e9c79a7aSSanjay Patel 166e9c79a7aSSanjay Patel if (Pred != CmpInst::BAD_ICMP_PREDICATE) 167e9c79a7aSSanjay Patel foldExtExtCmp(Ext0, Ext1, I, TTI); 168e9c79a7aSSanjay Patel else 169e9c79a7aSSanjay Patel foldExtExtBinop(Ext0, Ext1, I, TTI); 170e9c79a7aSSanjay Patel 171e9c79a7aSSanjay Patel return true; 172fc445589SSanjay Patel } 173fc445589SSanjay Patel 174a17f03bdSSanjay Patel /// This is the entry point for all transforms. Pass manager differences are 175a17f03bdSSanjay Patel /// handled in the callers of this function. 
176a17f03bdSSanjay Patel static bool runImpl(Function &F, const TargetTransformInfo &TTI, 177a17f03bdSSanjay Patel const DominatorTree &DT) { 178a17f03bdSSanjay Patel bool MadeChange = false; 179a17f03bdSSanjay Patel for (BasicBlock &BB : F) { 180a17f03bdSSanjay Patel // Ignore unreachable basic blocks. 181a17f03bdSSanjay Patel if (!DT.isReachableFromEntry(&BB)) 182a17f03bdSSanjay Patel continue; 183a17f03bdSSanjay Patel // Do not delete instructions under here and invalidate the iterator. 184a17f03bdSSanjay Patel // Walk the block backwards for efficiency. We're matching a chain of 185a17f03bdSSanjay Patel // use->defs, so we're more likely to succeed by starting from the bottom. 186a17f03bdSSanjay Patel // TODO: It could be more efficient to remove dead instructions 187a17f03bdSSanjay Patel // iteratively in this loop rather than waiting until the end. 188fc445589SSanjay Patel for (Instruction &I : make_range(BB.rbegin(), BB.rend())) 189fc445589SSanjay Patel MadeChange |= foldExtractExtract(I, TTI); 190a17f03bdSSanjay Patel } 191a17f03bdSSanjay Patel 192a17f03bdSSanjay Patel // We're done with transforms, so remove dead instructions. 193a17f03bdSSanjay Patel if (MadeChange) 194a17f03bdSSanjay Patel for (BasicBlock &BB : F) 195a17f03bdSSanjay Patel SimplifyInstructionsInBlock(&BB); 196a17f03bdSSanjay Patel 197a17f03bdSSanjay Patel return MadeChange; 198a17f03bdSSanjay Patel } 199a17f03bdSSanjay Patel 200a17f03bdSSanjay Patel // Pass manager boilerplate below here. 

namespace {
/// Legacy pass manager adapter: forwards to the shared runImpl() above.
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid.
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // runImpl() consumes the dominator tree (to skip unreachable blocks) and
    // the target cost model; the transforms never touch the CFG, so dominator
    // and globals-AA results remain valid afterward.
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    // Honor optnone / opt-bisect skipping before doing any work.
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
// Factory used by the legacy pass pipeline builders.
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

// New pass manager entry point.
PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  // Only instruction-level changes were made; the CFG and globals alias
  // results are still valid.
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}