1a17f03bdSSanjay Patel //===------- VectorCombine.cpp - Optimize partial vector operations -------===// 2a17f03bdSSanjay Patel // 3a17f03bdSSanjay Patel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a17f03bdSSanjay Patel // See https://llvm.org/LICENSE.txt for license information. 5a17f03bdSSanjay Patel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a17f03bdSSanjay Patel // 7a17f03bdSSanjay Patel //===----------------------------------------------------------------------===// 8a17f03bdSSanjay Patel // 9a17f03bdSSanjay Patel // This pass optimizes scalar/vector interactions using target cost models. The 10a17f03bdSSanjay Patel // transforms implemented here may not fit in traditional loop-based or SLP 11a17f03bdSSanjay Patel // vectorization passes. 12a17f03bdSSanjay Patel // 13a17f03bdSSanjay Patel //===----------------------------------------------------------------------===// 14a17f03bdSSanjay Patel 15a17f03bdSSanjay Patel #include "llvm/Transforms/Vectorize/VectorCombine.h" 16a17f03bdSSanjay Patel #include "llvm/ADT/Statistic.h" 17a17f03bdSSanjay Patel #include "llvm/Analysis/GlobalsModRef.h" 18a17f03bdSSanjay Patel #include "llvm/Analysis/TargetTransformInfo.h" 1919b62b79SSanjay Patel #include "llvm/Analysis/ValueTracking.h" 20a17f03bdSSanjay Patel #include "llvm/IR/Dominators.h" 21a17f03bdSSanjay Patel #include "llvm/IR/Function.h" 22a17f03bdSSanjay Patel #include "llvm/IR/IRBuilder.h" 23a17f03bdSSanjay Patel #include "llvm/IR/PatternMatch.h" 24a17f03bdSSanjay Patel #include "llvm/InitializePasses.h" 25a17f03bdSSanjay Patel #include "llvm/Pass.h" 2625c6544fSSanjay Patel #include "llvm/Support/CommandLine.h" 27a17f03bdSSanjay Patel #include "llvm/Transforms/Vectorize.h" 28a17f03bdSSanjay Patel #include "llvm/Transforms/Utils/Local.h" 29a17f03bdSSanjay Patel 30a17f03bdSSanjay Patel using namespace llvm; 31a17f03bdSSanjay Patel using namespace llvm::PatternMatch; 32a17f03bdSSanjay Patel 33a17f03bdSSanjay Patel 
#define DEBUG_TYPE "vector-combine" 34a17f03bdSSanjay Patel STATISTIC(NumVecCmp, "Number of vector compares formed"); 3519b62b79SSanjay Patel STATISTIC(NumVecBO, "Number of vector binops formed"); 36a17f03bdSSanjay Patel 3725c6544fSSanjay Patel static cl::opt<bool> DisableVectorCombine( 3825c6544fSSanjay Patel "disable-vector-combine", cl::init(false), cl::Hidden, 3925c6544fSSanjay Patel cl::desc("Disable all vector combine transforms")); 4025c6544fSSanjay Patel 4134e34855SSanjay Patel /// Compare the relative costs of extracts followed by scalar operation vs. 4234e34855SSanjay Patel /// vector operation followed by extract: 4334e34855SSanjay Patel /// opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C 4434e34855SSanjay Patel /// Unless the vector op is much more expensive than the scalar op, this 4534e34855SSanjay Patel /// eliminates an extract. 4634e34855SSanjay Patel static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1, 4734e34855SSanjay Patel unsigned Opcode, 4834e34855SSanjay Patel const TargetTransformInfo &TTI) { 49*4fa63fd4SAustin Kerbow assert(isa<ConstantInt>(Ext0->getOperand(1)) && 50*4fa63fd4SAustin Kerbow (cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue() == 51*4fa63fd4SAustin Kerbow cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue()) && 5234e34855SSanjay Patel "Expected same constant extract index"); 5334e34855SSanjay Patel 5434e34855SSanjay Patel Type *ScalarTy = Ext0->getType(); 5534e34855SSanjay Patel Type *VecTy = Ext0->getOperand(0)->getType(); 5634e34855SSanjay Patel int ScalarOpCost, VectorOpCost; 5734e34855SSanjay Patel 5834e34855SSanjay Patel // Get cost estimates for scalar and vector versions of the operation. 
5934e34855SSanjay Patel bool IsBinOp = Instruction::isBinaryOp(Opcode); 6034e34855SSanjay Patel if (IsBinOp) { 6134e34855SSanjay Patel ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy); 6234e34855SSanjay Patel VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy); 6334e34855SSanjay Patel } else { 6434e34855SSanjay Patel assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) && 6534e34855SSanjay Patel "Expected a compare"); 6634e34855SSanjay Patel ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy, 6734e34855SSanjay Patel CmpInst::makeCmpResultType(ScalarTy)); 6834e34855SSanjay Patel VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy, 6934e34855SSanjay Patel CmpInst::makeCmpResultType(VecTy)); 7034e34855SSanjay Patel } 7134e34855SSanjay Patel 7234e34855SSanjay Patel // Get cost estimate for the extract element. This cost will factor into 7334e34855SSanjay Patel // both sequences. 7434e34855SSanjay Patel unsigned ExtIndex = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue(); 7534e34855SSanjay Patel int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement, 7634e34855SSanjay Patel VecTy, ExtIndex); 7734e34855SSanjay Patel 7834e34855SSanjay Patel // Extra uses of the extracts mean that we include those costs in the 7934e34855SSanjay Patel // vector total because those instructions will not be eliminated. 80e9c79a7aSSanjay Patel int OldCost, NewCost; 8134e34855SSanjay Patel if (Ext0->getOperand(0) == Ext1->getOperand(0)) { 8234e34855SSanjay Patel // Handle a special case. If the 2 operands are identical, adjust the 8334e34855SSanjay Patel // formulas to account for that. The extra use charge allows for either the 8434e34855SSanjay Patel // CSE'd pattern or an unoptimized form with identical values: 8534e34855SSanjay Patel // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C 8634e34855SSanjay Patel bool HasUseTax = Ext0 == Ext1 ? 
!Ext0->hasNUses(2) 8734e34855SSanjay Patel : !Ext0->hasOneUse() || !Ext1->hasOneUse(); 88e9c79a7aSSanjay Patel OldCost = ExtractCost + ScalarOpCost; 89e9c79a7aSSanjay Patel NewCost = VectorOpCost + ExtractCost + HasUseTax * ExtractCost; 9034e34855SSanjay Patel } else { 9134e34855SSanjay Patel // Handle the general case. Each extract is actually a different value: 9234e34855SSanjay Patel // opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C 93e9c79a7aSSanjay Patel OldCost = 2 * ExtractCost + ScalarOpCost; 94e9c79a7aSSanjay Patel NewCost = VectorOpCost + ExtractCost + !Ext0->hasOneUse() * ExtractCost + 9534e34855SSanjay Patel !Ext1->hasOneUse() * ExtractCost; 9634e34855SSanjay Patel } 9710ea01d8SSanjay Patel // Aggressively form a vector op if the cost is equal because the transform 9810ea01d8SSanjay Patel // may enable further optimization. 9910ea01d8SSanjay Patel // Codegen can reverse this transform (scalarize) if it was not profitable. 10010ea01d8SSanjay Patel return OldCost < NewCost; 10134e34855SSanjay Patel } 10234e34855SSanjay Patel 103fc445589SSanjay Patel /// Try to reduce extract element costs by converting scalar compares to vector 104fc445589SSanjay Patel /// compares followed by extract. 
105e9c79a7aSSanjay Patel /// cmp (ext0 V0, C), (ext1 V1, C) 106e9c79a7aSSanjay Patel static void foldExtExtCmp(Instruction *Ext0, Instruction *Ext1, 107fc445589SSanjay Patel Instruction &I, const TargetTransformInfo &TTI) { 108fc445589SSanjay Patel assert(isa<CmpInst>(&I) && "Expected a compare"); 109a17f03bdSSanjay Patel 110a17f03bdSSanjay Patel // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C 111a17f03bdSSanjay Patel ++NumVecCmp; 112a17f03bdSSanjay Patel IRBuilder<> Builder(&I); 113fc445589SSanjay Patel CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate(); 114e9c79a7aSSanjay Patel Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0); 11534e34855SSanjay Patel Value *VecCmp = 11634e34855SSanjay Patel Ext0->getType()->isFloatingPointTy() ? Builder.CreateFCmp(Pred, V0, V1) 117a17f03bdSSanjay Patel : Builder.CreateICmp(Pred, V0, V1); 118fc445589SSanjay Patel Value *Extract = Builder.CreateExtractElement(VecCmp, Ext0->getOperand(1)); 119fc445589SSanjay Patel I.replaceAllUsesWith(Extract); 120a17f03bdSSanjay Patel } 121a17f03bdSSanjay Patel 12219b62b79SSanjay Patel /// Try to reduce extract element costs by converting scalar binops to vector 12319b62b79SSanjay Patel /// binops followed by extract. 
124e9c79a7aSSanjay Patel /// bo (ext0 V0, C), (ext1 V1, C) 125e9c79a7aSSanjay Patel static void foldExtExtBinop(Instruction *Ext0, Instruction *Ext1, 126fc445589SSanjay Patel Instruction &I, const TargetTransformInfo &TTI) { 127fc445589SSanjay Patel assert(isa<BinaryOperator>(&I) && "Expected a binary operator"); 12819b62b79SSanjay Patel 12934e34855SSanjay Patel // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C 13019b62b79SSanjay Patel ++NumVecBO; 13119b62b79SSanjay Patel IRBuilder<> Builder(&I); 132e9c79a7aSSanjay Patel Value *V0 = Ext0->getOperand(0), *V1 = Ext1->getOperand(0); 133e9c79a7aSSanjay Patel Value *VecBO = 13434e34855SSanjay Patel Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1); 135e9c79a7aSSanjay Patel 13619b62b79SSanjay Patel // All IR flags are safe to back-propagate because any potential poison 13719b62b79SSanjay Patel // created in unused vector elements is discarded by the extract. 138e9c79a7aSSanjay Patel if (auto *VecBOInst = dyn_cast<Instruction>(VecBO)) 13919b62b79SSanjay Patel VecBOInst->copyIRFlags(&I); 140e9c79a7aSSanjay Patel 141e9c79a7aSSanjay Patel Value *Extract = Builder.CreateExtractElement(VecBO, Ext0->getOperand(1)); 14219b62b79SSanjay Patel I.replaceAllUsesWith(Extract); 14319b62b79SSanjay Patel } 14419b62b79SSanjay Patel 145fc445589SSanjay Patel /// Match an instruction with extracted vector operands. 146fc445589SSanjay Patel static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) { 147e9c79a7aSSanjay Patel // It is not safe to transform things like div, urem, etc. because we may 148e9c79a7aSSanjay Patel // create undefined behavior when executing those on unknown vector elements. 
149e9c79a7aSSanjay Patel if (!isSafeToSpeculativelyExecute(&I)) 150e9c79a7aSSanjay Patel return false; 151e9c79a7aSSanjay Patel 152fc445589SSanjay Patel Instruction *Ext0, *Ext1; 153fc445589SSanjay Patel CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE; 154fc445589SSanjay Patel if (!match(&I, m_Cmp(Pred, m_Instruction(Ext0), m_Instruction(Ext1))) && 155fc445589SSanjay Patel !match(&I, m_BinOp(m_Instruction(Ext0), m_Instruction(Ext1)))) 156fc445589SSanjay Patel return false; 157fc445589SSanjay Patel 158fc445589SSanjay Patel Value *V0, *V1; 159fc445589SSanjay Patel uint64_t C0, C1; 160fc445589SSanjay Patel if (!match(Ext0, m_ExtractElement(m_Value(V0), m_ConstantInt(C0))) || 161fc445589SSanjay Patel !match(Ext1, m_ExtractElement(m_Value(V1), m_ConstantInt(C1))) || 162fc445589SSanjay Patel V0->getType() != V1->getType()) 163fc445589SSanjay Patel return false; 164fc445589SSanjay Patel 165e9c79a7aSSanjay Patel // TODO: Handle C0 != C1 by shuffling 1 of the operands. 166e9c79a7aSSanjay Patel if (C0 != C1) 167fc445589SSanjay Patel return false; 168e9c79a7aSSanjay Patel 169e9c79a7aSSanjay Patel if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI)) 170e9c79a7aSSanjay Patel return false; 171e9c79a7aSSanjay Patel 172e9c79a7aSSanjay Patel if (Pred != CmpInst::BAD_ICMP_PREDICATE) 173e9c79a7aSSanjay Patel foldExtExtCmp(Ext0, Ext1, I, TTI); 174e9c79a7aSSanjay Patel else 175e9c79a7aSSanjay Patel foldExtExtBinop(Ext0, Ext1, I, TTI); 176e9c79a7aSSanjay Patel 177e9c79a7aSSanjay Patel return true; 178fc445589SSanjay Patel } 179fc445589SSanjay Patel 180a17f03bdSSanjay Patel /// This is the entry point for all transforms. Pass manager differences are 181a17f03bdSSanjay Patel /// handled in the callers of this function. 
182a17f03bdSSanjay Patel static bool runImpl(Function &F, const TargetTransformInfo &TTI, 183a17f03bdSSanjay Patel const DominatorTree &DT) { 18425c6544fSSanjay Patel if (DisableVectorCombine) 18525c6544fSSanjay Patel return false; 18625c6544fSSanjay Patel 187a17f03bdSSanjay Patel bool MadeChange = false; 188a17f03bdSSanjay Patel for (BasicBlock &BB : F) { 189a17f03bdSSanjay Patel // Ignore unreachable basic blocks. 190a17f03bdSSanjay Patel if (!DT.isReachableFromEntry(&BB)) 191a17f03bdSSanjay Patel continue; 192a17f03bdSSanjay Patel // Do not delete instructions under here and invalidate the iterator. 193a17f03bdSSanjay Patel // Walk the block backwards for efficiency. We're matching a chain of 194a17f03bdSSanjay Patel // use->defs, so we're more likely to succeed by starting from the bottom. 195a17f03bdSSanjay Patel // TODO: It could be more efficient to remove dead instructions 196a17f03bdSSanjay Patel // iteratively in this loop rather than waiting until the end. 197fc445589SSanjay Patel for (Instruction &I : make_range(BB.rbegin(), BB.rend())) 198fc445589SSanjay Patel MadeChange |= foldExtractExtract(I, TTI); 199a17f03bdSSanjay Patel } 200a17f03bdSSanjay Patel 201a17f03bdSSanjay Patel // We're done with transforms, so remove dead instructions. 202a17f03bdSSanjay Patel if (MadeChange) 203a17f03bdSSanjay Patel for (BasicBlock &BB : F) 204a17f03bdSSanjay Patel SimplifyInstructionsInBlock(&BB); 205a17f03bdSSanjay Patel 206a17f03bdSSanjay Patel return MadeChange; 207a17f03bdSSanjay Patel } 208a17f03bdSSanjay Patel 209a17f03bdSSanjay Patel // Pass manager boilerplate below here. 

namespace {
/// Legacy pass manager wrapper: delegates the real work to runImpl().
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Required inputs for runImpl().
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    // The transforms only rewrite instructions within a block, so the CFG
    // (and analyses derived from it) survive this pass.
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    // Respect -O0 / optnone and bisection skipping.
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    return runImpl(F, TTI, DT);
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
/// Factory used by the legacy pass pipeline builder.
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

/// New pass manager entry point: fetch analyses from the manager and
/// delegate to runImpl(), reporting which analyses remain valid.
PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  // No change => everything is preserved.
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  return PA;
}