1 //===-- X86PartialReduction.cpp -------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass looks for add instructions used by a horizontal reduction to see 10 // if we might be able to use pmaddwd or psadbw. Some cases of this require 11 // cross basic block knowledge and can't be done in SelectionDAG. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "X86.h" 16 #include "llvm/Analysis/ValueTracking.h" 17 #include "llvm/CodeGen/TargetPassConfig.h" 18 #include "llvm/IR/Constants.h" 19 #include "llvm/IR/Instructions.h" 20 #include "llvm/IR/IntrinsicsX86.h" 21 #include "llvm/IR/IRBuilder.h" 22 #include "llvm/IR/Operator.h" 23 #include "llvm/Pass.h" 24 #include "X86TargetMachine.h" 25 26 using namespace llvm; 27 28 #define DEBUG_TYPE "x86-partial-reduction" 29 30 namespace { 31 32 class X86PartialReduction : public FunctionPass { 33 const DataLayout *DL; 34 const X86Subtarget *ST; 35 36 public: 37 static char ID; // Pass identification, replacement for typeid. 38 39 X86PartialReduction() : FunctionPass(ID) { } 40 41 bool runOnFunction(Function &Fn) override; 42 43 void getAnalysisUsage(AnalysisUsage &AU) const override { 44 AU.setPreservesCFG(); 45 } 46 47 StringRef getPassName() const override { 48 return "X86 Partial Reduction"; 49 } 50 51 private: 52 bool tryMAddPattern(BinaryOperator *BO); 53 bool tryMAddReplacement(Value *Op, BinaryOperator *Add); 54 55 bool trySADPattern(BinaryOperator *BO); 56 bool trySADReplacement(Value *Op, BinaryOperator *Add); 57 }; 58 } 59 60 FunctionPass *llvm::createX86PartialReductionPass() { 61 return new X86PartialReduction(); 62 } 63 64 char X86PartialReduction::ID = 0; 65 66 INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE, 67 "X86 Partial Reduction", false, false) 68 69 static bool isVectorReductionOp(const BinaryOperator &BO) { 70 if (!BO.getType()->isVectorTy()) 71 return false; 72 73 unsigned Opcode = BO.getOpcode(); 74 75 switch (Opcode) { 76 case Instruction::Add: 77 case Instruction::Mul: 78 case Instruction::And: 79 case Instruction::Or: 80 case Instruction::Xor: 81 break; 82 case Instruction::FAdd: 83 case Instruction::FMul: 84 if (auto *FPOp = dyn_cast<FPMathOperator>(&BO)) 85 if (FPOp->getFastMathFlags().isFast()) 86 break; 87 LLVM_FALLTHROUGH; 88 default: 89 return false; 90 } 91 92 unsigned ElemNum = BO.getType()->getVectorNumElements(); 93 // Ensure the reduction size is a power of 2. 94 if (!isPowerOf2_32(ElemNum)) 95 return false; 96 97 unsigned ElemNumToReduce = ElemNum; 98 99 // Do DFS search on the def-use chain from the given instruction. We only 100 // allow four kinds of operations during the search until we reach the 101 // instruction that extracts the first element from the vector: 102 // 103 // 1. The reduction operation of the same opcode as the given instruction. 104 // 105 // 2. PHI node. 106 // 107 // 3. ShuffleVector instruction together with a reduction operation that 108 // does a partial reduction. 109 // 110 // 4. ExtractElement that extracts the first element from the vector, and we 111 // stop searching the def-use chain here. 112 // 113 // 3 & 4 above perform a reduction on all elements of the vector. We push defs 114 // from 1-3 to the stack to continue the DFS. The given instruction is not 115 // a reduction operation if we meet any other instructions other than those 116 // listed above. 117 118 SmallVector<const User *, 16> UsersToVisit{&BO}; 119 SmallPtrSet<const User *, 16> Visited; 120 bool ReduxExtracted = false; 121 122 while (!UsersToVisit.empty()) { 123 auto User = UsersToVisit.back(); 124 UsersToVisit.pop_back(); 125 if (!Visited.insert(User).second) 126 continue; 127 128 for (const auto *U : User->users()) { 129 auto *Inst = dyn_cast<Instruction>(U); 130 if (!Inst) 131 return false; 132 133 if (Inst->getOpcode() == Opcode || isa<PHINode>(U)) { 134 if (auto *FPOp = dyn_cast<FPMathOperator>(Inst)) 135 if (!isa<PHINode>(FPOp) && !FPOp->getFastMathFlags().isFast()) 136 return false; 137 UsersToVisit.push_back(U); 138 } else if (auto *ShufInst = dyn_cast<ShuffleVectorInst>(U)) { 139 // Detect the following pattern: A ShuffleVector instruction together 140 // with a reduction that do partial reduction on the first and second 141 // ElemNumToReduce / 2 elements, and store the result in 142 // ElemNumToReduce / 2 elements in another vector. 143 144 unsigned ResultElements = ShufInst->getType()->getVectorNumElements(); 145 if (ResultElements < ElemNum) 146 return false; 147 148 if (ElemNumToReduce == 1) 149 return false; 150 if (!isa<UndefValue>(U->getOperand(1))) 151 return false; 152 for (unsigned i = 0; i < ElemNumToReduce / 2; ++i) 153 if (ShufInst->getMaskValue(i) != int(i + ElemNumToReduce / 2)) 154 return false; 155 for (unsigned i = ElemNumToReduce / 2; i < ElemNum; ++i) 156 if (ShufInst->getMaskValue(i) != -1) 157 return false; 158 159 // There is only one user of this ShuffleVector instruction, which 160 // must be a reduction operation. 161 if (!U->hasOneUse()) 162 return false; 163 164 auto *U2 = dyn_cast<BinaryOperator>(*U->user_begin()); 165 if (!U2 || U2->getOpcode() != Opcode) 166 return false; 167 168 // Check operands of the reduction operation. 169 if ((U2->getOperand(0) == U->getOperand(0) && U2->getOperand(1) == U) || 170 (U2->getOperand(1) == U->getOperand(0) && U2->getOperand(0) == U)) { 171 UsersToVisit.push_back(U2); 172 ElemNumToReduce /= 2; 173 } else 174 return false; 175 } else if (isa<ExtractElementInst>(U)) { 176 // At this moment we should have reduced all elements in the vector. 177 if (ElemNumToReduce != 1) 178 return false; 179 180 auto *Val = dyn_cast<ConstantInt>(U->getOperand(1)); 181 if (!Val || !Val->isZero()) 182 return false; 183 184 ReduxExtracted = true; 185 } else 186 return false; 187 } 188 } 189 return ReduxExtracted; 190 } 191 192 bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) { 193 BasicBlock *BB = Add->getParent(); 194 195 auto *BO = dyn_cast<BinaryOperator>(Op); 196 if (!BO || BO->getOpcode() != Instruction::Mul || !BO->hasOneUse() || 197 BO->getParent() != BB) 198 return false; 199 200 Value *LHS = BO->getOperand(0); 201 Value *RHS = BO->getOperand(1); 202 203 // LHS and RHS should be only used once or if they are the same then only 204 // used twice. Only check this when SSE4.1 is enabled and we have zext/sext 205 // instructions, otherwise we use punpck to emulate zero extend in stages. The 206 // trunc/ we need to do likely won't introduce new instructions in that case. 207 if (ST->hasSSE41()) { 208 if (LHS == RHS) { 209 if (!isa<Constant>(LHS) && !LHS->hasNUses(2)) 210 return false; 211 } else { 212 if (!isa<Constant>(LHS) && !LHS->hasOneUse()) 213 return false; 214 if (!isa<Constant>(RHS) && !RHS->hasOneUse()) 215 return false; 216 } 217 } 218 219 auto canShrinkOp = [&](Value *Op) { 220 if (isa<Constant>(Op) && ComputeNumSignBits(Op, *DL, 0, nullptr, BO) > 16) 221 return true; 222 if (auto *Cast = dyn_cast<CastInst>(Op)) { 223 if (Cast->getParent() == BB && 224 (Cast->getOpcode() == Instruction::SExt || 225 Cast->getOpcode() == Instruction::ZExt) && 226 ComputeNumSignBits(Op, *DL, 0, nullptr, BO) > 16) 227 return true; 228 } 229 230 return false; 231 }; 232 233 // Both Ops need to be shrinkable. 234 if (!canShrinkOp(LHS) && !canShrinkOp(RHS)) 235 return false; 236 237 IRBuilder<> Builder(Add); 238 239 Type *MulTy = Op->getType(); 240 unsigned NumElts = MulTy->getVectorNumElements(); 241 242 // Extract even elements and odd elements and add them together. This will 243 // be pattern matched by SelectionDAG to pmaddwd. This instruction will be 244 // half the original width. 245 SmallVector<uint32_t, 16> EvenMask(NumElts / 2); 246 SmallVector<uint32_t, 16> OddMask(NumElts / 2); 247 for (int i = 0, e = NumElts / 2; i != e; ++i) { 248 EvenMask[i] = i * 2; 249 OddMask[i] = i * 2 + 1; 250 } 251 Value *EvenElts = Builder.CreateShuffleVector(BO, BO, EvenMask); 252 Value *OddElts = Builder.CreateShuffleVector(BO, BO, OddMask); 253 Value *MAdd = Builder.CreateAdd(EvenElts, OddElts); 254 255 // Concatenate zeroes to extend back to the original type. 256 SmallVector<uint32_t, 32> ConcatMask(NumElts); 257 std::iota(ConcatMask.begin(), ConcatMask.end(), 0); 258 Value *Zero = Constant::getNullValue(MAdd->getType()); 259 Value *Concat = Builder.CreateShuffleVector(MAdd, Zero, ConcatMask); 260 261 // Replaces the use of mul in the original Add with the pmaddwd and zeroes. 262 Add->replaceUsesOfWith(BO, Concat); 263 Add->setHasNoSignedWrap(false); 264 Add->setHasNoUnsignedWrap(false); 265 266 return true; 267 } 268 269 // Try to replace operans of this add with pmaddwd patterns. 270 bool X86PartialReduction::tryMAddPattern(BinaryOperator *BO) { 271 if (!ST->hasSSE2()) 272 return false; 273 274 // Need at least 8 elements. 275 if (BO->getType()->getVectorNumElements() < 8) 276 return false; 277 278 // Element type should be i32. 279 if (!BO->getType()->getVectorElementType()->isIntegerTy(32)) 280 return false; 281 282 bool Changed = false; 283 Changed |= tryMAddReplacement(BO->getOperand(0), BO); 284 Changed |= tryMAddReplacement(BO->getOperand(1), BO); 285 return Changed; 286 } 287 288 bool X86PartialReduction::trySADReplacement(Value *Op, BinaryOperator *Add) { 289 // Operand should be a select. 290 auto *SI = dyn_cast<SelectInst>(Op); 291 if (!SI) 292 return false; 293 294 // Select needs to implement absolute value. 295 Value *LHS, *RHS; 296 auto SPR = matchSelectPattern(SI, LHS, RHS); 297 if (SPR.Flavor != SPF_ABS) 298 return false; 299 300 // Need a subtract of two values. 301 auto *Sub = dyn_cast<BinaryOperator>(LHS); 302 if (!Sub || Sub->getOpcode() != Instruction::Sub) 303 return false; 304 305 // Look for zero extend from i8. 306 auto getZeroExtendedVal = [](Value *Op) -> Value * { 307 if (auto *ZExt = dyn_cast<ZExtInst>(Op)) 308 if (ZExt->getOperand(0)->getType()->getVectorElementType()->isIntegerTy(8)) 309 return ZExt->getOperand(0); 310 311 return nullptr; 312 }; 313 314 // Both operands of the subtract should be extends from vXi8. 315 Value *Op0 = getZeroExtendedVal(Sub->getOperand(0)); 316 Value *Op1 = getZeroExtendedVal(Sub->getOperand(1)); 317 if (!Op0 || !Op1) 318 return false; 319 320 IRBuilder<> Builder(Add); 321 322 Type *OpTy = Op->getType(); 323 unsigned NumElts = OpTy->getVectorNumElements(); 324 325 unsigned IntrinsicNumElts; 326 Intrinsic::ID IID; 327 if (ST->hasBWI() && NumElts >= 64) { 328 IID = Intrinsic::x86_avx512_psad_bw_512; 329 IntrinsicNumElts = 64; 330 } else if (ST->hasAVX2() && NumElts >= 32) { 331 IID = Intrinsic::x86_avx2_psad_bw; 332 IntrinsicNumElts = 32; 333 } else { 334 IID = Intrinsic::x86_sse2_psad_bw; 335 IntrinsicNumElts = 16; 336 } 337 338 Function *PSADBWFn = Intrinsic::getDeclaration(Add->getModule(), IID); 339 340 if (NumElts < 16) { 341 // Pad input with zeroes. 342 SmallVector<uint32_t, 32> ConcatMask(16); 343 for (unsigned i = 0; i != NumElts; ++i) 344 ConcatMask[i] = i; 345 for (unsigned i = NumElts; i != 16; ++i) 346 ConcatMask[i] = (i % NumElts) + NumElts; 347 348 Value *Zero = Constant::getNullValue(Op0->getType()); 349 Op0 = Builder.CreateShuffleVector(Op0, Zero, ConcatMask); 350 Op1 = Builder.CreateShuffleVector(Op1, Zero, ConcatMask); 351 NumElts = 16; 352 } 353 354 // Intrinsics produce vXi64 and need to be casted to vXi32. 355 Type *I32Ty = VectorType::get(Builder.getInt32Ty(), IntrinsicNumElts / 4); 356 357 assert(NumElts % IntrinsicNumElts == 0 && "Unexpected number of elements!"); 358 unsigned NumSplits = NumElts / IntrinsicNumElts; 359 360 // First collect the pieces we need. 361 SmallVector<Value *, 4> Ops(NumSplits); 362 for (unsigned i = 0; i != NumSplits; ++i) { 363 SmallVector<uint32_t, 64> ExtractMask(IntrinsicNumElts); 364 std::iota(ExtractMask.begin(), ExtractMask.end(), i * IntrinsicNumElts); 365 Value *ExtractOp0 = Builder.CreateShuffleVector(Op0, Op0, ExtractMask); 366 Value *ExtractOp1 = Builder.CreateShuffleVector(Op1, Op0, ExtractMask); 367 Ops[i] = Builder.CreateCall(PSADBWFn, {ExtractOp0, ExtractOp1}); 368 Ops[i] = Builder.CreateBitCast(Ops[i], I32Ty); 369 } 370 371 assert(isPowerOf2_32(NumSplits) && "Expected power of 2 splits"); 372 unsigned Stages = Log2_32(NumSplits); 373 for (unsigned s = Stages; s > 0; --s) { 374 unsigned NumConcatElts = Ops[0]->getType()->getVectorNumElements() * 2; 375 for (unsigned i = 0; i != 1U << (s - 1); ++i) { 376 SmallVector<uint32_t, 64> ConcatMask(NumConcatElts); 377 std::iota(ConcatMask.begin(), ConcatMask.end(), 0); 378 Ops[i] = Builder.CreateShuffleVector(Ops[i*2], Ops[i*2+1], ConcatMask); 379 } 380 } 381 382 // At this point the final value should be in Ops[0]. Now we need to adjust 383 // it to the final original type. 384 NumElts = OpTy->getVectorNumElements(); 385 if (NumElts == 2) { 386 // Extract down to 2 elements. 387 Ops[0] = Builder.CreateShuffleVector(Ops[0], Ops[0], ArrayRef<int>{0, 1}); 388 } else if (NumElts >= 8) { 389 SmallVector<uint32_t, 32> ConcatMask(NumElts); 390 unsigned SubElts = Ops[0]->getType()->getVectorNumElements(); 391 for (unsigned i = 0; i != SubElts; ++i) 392 ConcatMask[i] = i; 393 for (unsigned i = SubElts; i != NumElts; ++i) 394 ConcatMask[i] = (i % SubElts) + SubElts; 395 396 Value *Zero = Constant::getNullValue(Ops[0]->getType()); 397 Ops[0] = Builder.CreateShuffleVector(Ops[0], Zero, ConcatMask); 398 } 399 400 // Replaces the uses of Op in Add with the new sequence. 401 Add->replaceUsesOfWith(Op, Ops[0]); 402 Add->setHasNoSignedWrap(false); 403 Add->setHasNoUnsignedWrap(false); 404 405 return false; 406 } 407 408 bool X86PartialReduction::trySADPattern(BinaryOperator *BO) { 409 if (!ST->hasSSE2()) 410 return false; 411 412 // TODO: There's nothing special about i32, any integer type above i16 should 413 // work just as well. 414 if (!BO->getType()->getVectorElementType()->isIntegerTy(32)) 415 return false; 416 417 bool Changed = false; 418 Changed |= trySADReplacement(BO->getOperand(0), BO); 419 Changed |= trySADReplacement(BO->getOperand(1), BO); 420 return Changed; 421 } 422 423 bool X86PartialReduction::runOnFunction(Function &F) { 424 if (skipFunction(F)) 425 return false; 426 427 auto *TPC = getAnalysisIfAvailable<TargetPassConfig>(); 428 if (!TPC) 429 return false; 430 431 auto &TM = TPC->getTM<X86TargetMachine>(); 432 ST = TM.getSubtargetImpl(F); 433 434 DL = &F.getParent()->getDataLayout(); 435 436 bool MadeChange = false; 437 for (auto &BB : F) { 438 for (auto &I : BB) { 439 auto *BO = dyn_cast<BinaryOperator>(&I); 440 if (!BO) 441 continue; 442 443 if (!isVectorReductionOp(*BO)) 444 continue; 445 446 if (BO->getOpcode() == Instruction::Add) { 447 if (tryMAddPattern(BO)) { 448 MadeChange = true; 449 continue; 450 } 451 if (trySADPattern(BO)) { 452 MadeChange = true; 453 continue; 454 } 455 } 456 } 457 } 458 459 return MadeChange; 460 } 461