1 //===-- X86PartialReduction.cpp -------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass looks for add instructions used by a horizontal reduction to see 10 // if we might be able to use pmaddwd or psadbw. Some cases of this require 11 // cross basic block knowledge and can't be done in SelectionDAG. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "X86.h" 16 #include "llvm/Analysis/ValueTracking.h" 17 #include "llvm/CodeGen/TargetPassConfig.h" 18 #include "llvm/IR/Constants.h" 19 #include "llvm/IR/Instructions.h" 20 #include "llvm/IR/IntrinsicsX86.h" 21 #include "llvm/IR/IRBuilder.h" 22 #include "llvm/IR/Operator.h" 23 #include "llvm/Pass.h" 24 #include "X86TargetMachine.h" 25 26 using namespace llvm; 27 28 #define DEBUG_TYPE "x86-partial-reduction" 29 30 namespace { 31 32 class X86PartialReduction : public FunctionPass { 33 const DataLayout *DL; 34 const X86Subtarget *ST; 35 36 public: 37 static char ID; // Pass identification, replacement for typeid. 38 39 X86PartialReduction() : FunctionPass(ID) { } 40 41 bool runOnFunction(Function &Fn) override; 42 43 void getAnalysisUsage(AnalysisUsage &AU) const override { 44 AU.setPreservesCFG(); 45 } 46 47 StringRef getPassName() const override { 48 return "X86 Partial Reduction"; 49 } 50 51 private: 52 bool tryMAddPattern(BinaryOperator *BO); 53 bool tryMAddReplacement(Value *Op, BinaryOperator *Add); 54 55 bool trySADPattern(BinaryOperator *BO); 56 bool trySADReplacement(Value *Op, BinaryOperator *Add); 57 }; 58 } 59 60 FunctionPass *llvm::createX86PartialReductionPass() { 61 return new X86PartialReduction(); 62 } 63 64 char X86PartialReduction::ID = 0; 65 66 INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE, 67 "X86 Partial Reduction", false, false) 68 69 static bool isVectorReductionOp(const BinaryOperator &BO) { 70 if (!BO.getType()->isVectorTy()) 71 return false; 72 73 unsigned Opcode = BO.getOpcode(); 74 75 switch (Opcode) { 76 case Instruction::Add: 77 case Instruction::Mul: 78 case Instruction::And: 79 case Instruction::Or: 80 case Instruction::Xor: 81 break; 82 case Instruction::FAdd: 83 case Instruction::FMul: 84 if (auto *FPOp = dyn_cast<FPMathOperator>(&BO)) 85 if (FPOp->getFastMathFlags().isFast()) 86 break; 87 LLVM_FALLTHROUGH; 88 default: 89 return false; 90 } 91 92 unsigned ElemNum = cast<VectorType>(BO.getType())->getNumElements(); 93 // Ensure the reduction size is a power of 2. 94 if (!isPowerOf2_32(ElemNum)) 95 return false; 96 97 unsigned ElemNumToReduce = ElemNum; 98 99 // Do DFS search on the def-use chain from the given instruction. We only 100 // allow four kinds of operations during the search until we reach the 101 // instruction that extracts the first element from the vector: 102 // 103 // 1. The reduction operation of the same opcode as the given instruction. 104 // 105 // 2. PHI node. 106 // 107 // 3. ShuffleVector instruction together with a reduction operation that 108 // does a partial reduction. 109 // 110 // 4. ExtractElement that extracts the first element from the vector, and we 111 // stop searching the def-use chain here. 112 // 113 // 3 & 4 above perform a reduction on all elements of the vector. We push defs 114 // from 1-3 to the stack to continue the DFS. The given instruction is not 115 // a reduction operation if we meet any other instructions other than those 116 // listed above. 117 118 SmallVector<const User *, 16> UsersToVisit{&BO}; 119 SmallPtrSet<const User *, 16> Visited; 120 bool ReduxExtracted = false; 121 122 while (!UsersToVisit.empty()) { 123 auto User = UsersToVisit.back(); 124 UsersToVisit.pop_back(); 125 if (!Visited.insert(User).second) 126 continue; 127 128 for (const auto *U : User->users()) { 129 auto *Inst = dyn_cast<Instruction>(U); 130 if (!Inst) 131 return false; 132 133 if (Inst->getOpcode() == Opcode || isa<PHINode>(U)) { 134 if (auto *FPOp = dyn_cast<FPMathOperator>(Inst)) 135 if (!isa<PHINode>(FPOp) && !FPOp->getFastMathFlags().isFast()) 136 return false; 137 UsersToVisit.push_back(U); 138 } else if (auto *ShufInst = dyn_cast<ShuffleVectorInst>(U)) { 139 // Detect the following pattern: A ShuffleVector instruction together 140 // with a reduction that do partial reduction on the first and second 141 // ElemNumToReduce / 2 elements, and store the result in 142 // ElemNumToReduce / 2 elements in another vector. 143 144 unsigned ResultElements = ShufInst->getType()->getNumElements(); 145 if (ResultElements < ElemNum) 146 return false; 147 148 if (ElemNumToReduce == 1) 149 return false; 150 if (!isa<UndefValue>(U->getOperand(1))) 151 return false; 152 for (unsigned i = 0; i < ElemNumToReduce / 2; ++i) 153 if (ShufInst->getMaskValue(i) != int(i + ElemNumToReduce / 2)) 154 return false; 155 for (unsigned i = ElemNumToReduce / 2; i < ElemNum; ++i) 156 if (ShufInst->getMaskValue(i) != -1) 157 return false; 158 159 // There is only one user of this ShuffleVector instruction, which 160 // must be a reduction operation. 161 if (!U->hasOneUse()) 162 return false; 163 164 auto *U2 = dyn_cast<BinaryOperator>(*U->user_begin()); 165 if (!U2 || U2->getOpcode() != Opcode) 166 return false; 167 168 // Check operands of the reduction operation. 169 if ((U2->getOperand(0) == U->getOperand(0) && U2->getOperand(1) == U) || 170 (U2->getOperand(1) == U->getOperand(0) && U2->getOperand(0) == U)) { 171 UsersToVisit.push_back(U2); 172 ElemNumToReduce /= 2; 173 } else 174 return false; 175 } else if (isa<ExtractElementInst>(U)) { 176 // At this moment we should have reduced all elements in the vector. 177 if (ElemNumToReduce != 1) 178 return false; 179 180 auto *Val = dyn_cast<ConstantInt>(U->getOperand(1)); 181 if (!Val || !Val->isZero()) 182 return false; 183 184 ReduxExtracted = true; 185 } else 186 return false; 187 } 188 } 189 return ReduxExtracted; 190 } 191 192 bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) { 193 BasicBlock *BB = Add->getParent(); 194 195 auto *BO = dyn_cast<BinaryOperator>(Op); 196 if (!BO || BO->getOpcode() != Instruction::Mul || !BO->hasOneUse() || 197 BO->getParent() != BB) 198 return false; 199 200 Value *LHS = BO->getOperand(0); 201 Value *RHS = BO->getOperand(1); 202 203 // LHS and RHS should be only used once or if they are the same then only 204 // used twice. Only check this when SSE4.1 is enabled and we have zext/sext 205 // instructions, otherwise we use punpck to emulate zero extend in stages. The 206 // trunc/ we need to do likely won't introduce new instructions in that case. 207 if (ST->hasSSE41()) { 208 if (LHS == RHS) { 209 if (!isa<Constant>(LHS) && !LHS->hasNUses(2)) 210 return false; 211 } else { 212 if (!isa<Constant>(LHS) && !LHS->hasOneUse()) 213 return false; 214 if (!isa<Constant>(RHS) && !RHS->hasOneUse()) 215 return false; 216 } 217 } 218 219 auto CanShrinkOp = [&](Value *Op) { 220 auto IsFreeTruncation = [&](Value *Op) { 221 if (auto *Cast = dyn_cast<CastInst>(Op)) { 222 if (Cast->getParent() == BB && 223 (Cast->getOpcode() == Instruction::SExt || 224 Cast->getOpcode() == Instruction::ZExt) && 225 Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 16) 226 return true; 227 } 228 229 return isa<Constant>(Op); 230 }; 231 232 // If the operation can be freely truncated and has enough sign bits we 233 // can shrink. 234 if (IsFreeTruncation(Op) && 235 ComputeNumSignBits(Op, *DL, 0, nullptr, BO) > 16) 236 return true; 237 238 // SelectionDAG has limited support for truncating through an add or sub if 239 // the inputs are freely truncatable. 240 if (auto *BO = dyn_cast<BinaryOperator>(Op)) { 241 if (BO->getParent() == BB && 242 IsFreeTruncation(BO->getOperand(0)) && 243 IsFreeTruncation(BO->getOperand(1)) && 244 ComputeNumSignBits(Op, *DL, 0, nullptr, BO) > 16) 245 return true; 246 } 247 248 return false; 249 }; 250 251 // Both Ops need to be shrinkable. 252 if (!CanShrinkOp(LHS) && !CanShrinkOp(RHS)) 253 return false; 254 255 IRBuilder<> Builder(Add); 256 257 auto *MulTy = cast<VectorType>(Op->getType()); 258 unsigned NumElts = MulTy->getNumElements(); 259 260 // Extract even elements and odd elements and add them together. This will 261 // be pattern matched by SelectionDAG to pmaddwd. This instruction will be 262 // half the original width. 263 SmallVector<int, 16> EvenMask(NumElts / 2); 264 SmallVector<int, 16> OddMask(NumElts / 2); 265 for (int i = 0, e = NumElts / 2; i != e; ++i) { 266 EvenMask[i] = i * 2; 267 OddMask[i] = i * 2 + 1; 268 } 269 Value *EvenElts = Builder.CreateShuffleVector(BO, BO, EvenMask); 270 Value *OddElts = Builder.CreateShuffleVector(BO, BO, OddMask); 271 Value *MAdd = Builder.CreateAdd(EvenElts, OddElts); 272 273 // Concatenate zeroes to extend back to the original type. 274 SmallVector<int, 32> ConcatMask(NumElts); 275 std::iota(ConcatMask.begin(), ConcatMask.end(), 0); 276 Value *Zero = Constant::getNullValue(MAdd->getType()); 277 Value *Concat = Builder.CreateShuffleVector(MAdd, Zero, ConcatMask); 278 279 // Replaces the use of mul in the original Add with the pmaddwd and zeroes. 280 Add->replaceUsesOfWith(BO, Concat); 281 Add->setHasNoSignedWrap(false); 282 Add->setHasNoUnsignedWrap(false); 283 284 return true; 285 } 286 287 // Try to replace operans of this add with pmaddwd patterns. 288 bool X86PartialReduction::tryMAddPattern(BinaryOperator *BO) { 289 if (!ST->hasSSE2()) 290 return false; 291 292 // Need at least 8 elements. 293 if (cast<VectorType>(BO->getType())->getNumElements() < 8) 294 return false; 295 296 // Element type should be i32. 297 if (!cast<VectorType>(BO->getType())->getElementType()->isIntegerTy(32)) 298 return false; 299 300 bool Changed = false; 301 Changed |= tryMAddReplacement(BO->getOperand(0), BO); 302 Changed |= tryMAddReplacement(BO->getOperand(1), BO); 303 return Changed; 304 } 305 306 bool X86PartialReduction::trySADReplacement(Value *Op, BinaryOperator *Add) { 307 // Operand should be a select. 308 auto *SI = dyn_cast<SelectInst>(Op); 309 if (!SI) 310 return false; 311 312 // Select needs to implement absolute value. 313 Value *LHS, *RHS; 314 auto SPR = matchSelectPattern(SI, LHS, RHS); 315 if (SPR.Flavor != SPF_ABS) 316 return false; 317 318 // Need a subtract of two values. 319 auto *Sub = dyn_cast<BinaryOperator>(LHS); 320 if (!Sub || Sub->getOpcode() != Instruction::Sub) 321 return false; 322 323 // Look for zero extend from i8. 324 auto getZeroExtendedVal = [](Value *Op) -> Value * { 325 if (auto *ZExt = dyn_cast<ZExtInst>(Op)) 326 if (cast<VectorType>(ZExt->getOperand(0)->getType()) 327 ->getElementType() 328 ->isIntegerTy(8)) 329 return ZExt->getOperand(0); 330 331 return nullptr; 332 }; 333 334 // Both operands of the subtract should be extends from vXi8. 335 Value *Op0 = getZeroExtendedVal(Sub->getOperand(0)); 336 Value *Op1 = getZeroExtendedVal(Sub->getOperand(1)); 337 if (!Op0 || !Op1) 338 return false; 339 340 IRBuilder<> Builder(Add); 341 342 auto *OpTy = cast<VectorType>(Op->getType()); 343 unsigned NumElts = OpTy->getNumElements(); 344 345 unsigned IntrinsicNumElts; 346 Intrinsic::ID IID; 347 if (ST->hasBWI() && NumElts >= 64) { 348 IID = Intrinsic::x86_avx512_psad_bw_512; 349 IntrinsicNumElts = 64; 350 } else if (ST->hasAVX2() && NumElts >= 32) { 351 IID = Intrinsic::x86_avx2_psad_bw; 352 IntrinsicNumElts = 32; 353 } else { 354 IID = Intrinsic::x86_sse2_psad_bw; 355 IntrinsicNumElts = 16; 356 } 357 358 Function *PSADBWFn = Intrinsic::getDeclaration(Add->getModule(), IID); 359 360 if (NumElts < 16) { 361 // Pad input with zeroes. 362 SmallVector<int, 32> ConcatMask(16); 363 for (unsigned i = 0; i != NumElts; ++i) 364 ConcatMask[i] = i; 365 for (unsigned i = NumElts; i != 16; ++i) 366 ConcatMask[i] = (i % NumElts) + NumElts; 367 368 Value *Zero = Constant::getNullValue(Op0->getType()); 369 Op0 = Builder.CreateShuffleVector(Op0, Zero, ConcatMask); 370 Op1 = Builder.CreateShuffleVector(Op1, Zero, ConcatMask); 371 NumElts = 16; 372 } 373 374 // Intrinsics produce vXi64 and need to be casted to vXi32. 375 Type *I32Ty = VectorType::get(Builder.getInt32Ty(), IntrinsicNumElts / 4); 376 377 assert(NumElts % IntrinsicNumElts == 0 && "Unexpected number of elements!"); 378 unsigned NumSplits = NumElts / IntrinsicNumElts; 379 380 // First collect the pieces we need. 381 SmallVector<Value *, 4> Ops(NumSplits); 382 for (unsigned i = 0; i != NumSplits; ++i) { 383 SmallVector<int, 64> ExtractMask(IntrinsicNumElts); 384 std::iota(ExtractMask.begin(), ExtractMask.end(), i * IntrinsicNumElts); 385 Value *ExtractOp0 = Builder.CreateShuffleVector(Op0, Op0, ExtractMask); 386 Value *ExtractOp1 = Builder.CreateShuffleVector(Op1, Op0, ExtractMask); 387 Ops[i] = Builder.CreateCall(PSADBWFn, {ExtractOp0, ExtractOp1}); 388 Ops[i] = Builder.CreateBitCast(Ops[i], I32Ty); 389 } 390 391 assert(isPowerOf2_32(NumSplits) && "Expected power of 2 splits"); 392 unsigned Stages = Log2_32(NumSplits); 393 for (unsigned s = Stages; s > 0; --s) { 394 unsigned NumConcatElts = 395 cast<VectorType>(Ops[0]->getType())->getNumElements() * 2; 396 for (unsigned i = 0; i != 1U << (s - 1); ++i) { 397 SmallVector<int, 64> ConcatMask(NumConcatElts); 398 std::iota(ConcatMask.begin(), ConcatMask.end(), 0); 399 Ops[i] = Builder.CreateShuffleVector(Ops[i*2], Ops[i*2+1], ConcatMask); 400 } 401 } 402 403 // At this point the final value should be in Ops[0]. Now we need to adjust 404 // it to the final original type. 405 NumElts = cast<VectorType>(OpTy)->getNumElements(); 406 if (NumElts == 2) { 407 // Extract down to 2 elements. 408 Ops[0] = Builder.CreateShuffleVector(Ops[0], Ops[0], ArrayRef<int>{0, 1}); 409 } else if (NumElts >= 8) { 410 SmallVector<int, 32> ConcatMask(NumElts); 411 unsigned SubElts = cast<VectorType>(Ops[0]->getType())->getNumElements(); 412 for (unsigned i = 0; i != SubElts; ++i) 413 ConcatMask[i] = i; 414 for (unsigned i = SubElts; i != NumElts; ++i) 415 ConcatMask[i] = (i % SubElts) + SubElts; 416 417 Value *Zero = Constant::getNullValue(Ops[0]->getType()); 418 Ops[0] = Builder.CreateShuffleVector(Ops[0], Zero, ConcatMask); 419 } 420 421 // Replaces the uses of Op in Add with the new sequence. 422 Add->replaceUsesOfWith(Op, Ops[0]); 423 Add->setHasNoSignedWrap(false); 424 Add->setHasNoUnsignedWrap(false); 425 426 return true; 427 } 428 429 bool X86PartialReduction::trySADPattern(BinaryOperator *BO) { 430 if (!ST->hasSSE2()) 431 return false; 432 433 // TODO: There's nothing special about i32, any integer type above i16 should 434 // work just as well. 435 if (!cast<VectorType>(BO->getType())->getElementType()->isIntegerTy(32)) 436 return false; 437 438 bool Changed = false; 439 Changed |= trySADReplacement(BO->getOperand(0), BO); 440 Changed |= trySADReplacement(BO->getOperand(1), BO); 441 return Changed; 442 } 443 444 bool X86PartialReduction::runOnFunction(Function &F) { 445 if (skipFunction(F)) 446 return false; 447 448 auto *TPC = getAnalysisIfAvailable<TargetPassConfig>(); 449 if (!TPC) 450 return false; 451 452 auto &TM = TPC->getTM<X86TargetMachine>(); 453 ST = TM.getSubtargetImpl(F); 454 455 DL = &F.getParent()->getDataLayout(); 456 457 bool MadeChange = false; 458 for (auto &BB : F) { 459 for (auto &I : BB) { 460 auto *BO = dyn_cast<BinaryOperator>(&I); 461 if (!BO) 462 continue; 463 464 if (!isVectorReductionOp(*BO)) 465 continue; 466 467 if (BO->getOpcode() == Instruction::Add) { 468 if (tryMAddPattern(BO)) { 469 MadeChange = true; 470 continue; 471 } 472 if (trySADPattern(BO)) { 473 MadeChange = true; 474 continue; 475 } 476 } 477 } 478 } 479 480 return MadeChange; 481 } 482