//===-- X86PartialReduction.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass looks for add instructions used by a horizontal reduction to see
// if we might be able to use pmaddwd or psadbw. Some cases of this require
// cross basic block knowledge and can't be done in SelectionDAG.
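//
// For example, a dot-product style reduction may produce IR like the
// following sketch (illustrative only; the element count and value names are
// assumptions, not the only shape handled):
//
//   %mul  = mul <8 x i32> %a.ext, %b.ext    ; operands extended from i16
//   %shuf = shufflevector <8 x i32> %mul, <8 x i32> poison,
//           <8 x i32> <i32 4, i32 5, i32 6, i32 7, i32 undef, ...>
//   %sum  = add <8 x i32> %mul, %shuf
//   ... log2(8) shuffle/add stages in total ...
//   %res  = extractelement <8 x i32> %sum.final, i32 0
//
// Here the multiply could instead be done at i16 and matched to pmaddwd.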
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86TargetMachine.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/KnownBits.h"

using namespace llvm;

#define DEBUG_TYPE "x86-partial-reduction"

namespace {

class X86PartialReduction : public FunctionPass {
  const DataLayout *DL;
  const X86Subtarget *ST;

public:
  static char ID; // Pass identification, replacement for typeid.

  X86PartialReduction() : FunctionPass(ID) { }

  bool runOnFunction(Function &Fn) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
  }

  StringRef getPassName() const override {
    return "X86 Partial Reduction";
  }

private:
  bool tryMAddReplacement(Instruction *Op, bool ReduceInOneBB);
  bool trySADReplacement(Instruction *Op);
};
}

FunctionPass *llvm::createX86PartialReductionPass() {
  return new X86PartialReduction();
}

char X86PartialReduction::ID = 0;

INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE,
                "X86 Partial Reduction", false, false)
// This function should be kept in sync with detectExtMul() in
// X86ISelLowering.cpp.
static bool matchVPDPBUSDPattern(const X86Subtarget *ST, BinaryOperator *Mul,
                                 const DataLayout *DL) {
  if (!ST->hasVNNI() && !ST->hasAVXVNNI())
    return false;

  Value *LHS = Mul->getOperand(0);
  Value *RHS = Mul->getOperand(1);

  if (isa<SExtInst>(LHS))
    std::swap(LHS, RHS);

  if (!isa<ZExtInst>(LHS))
    return false;

  auto IsFreeTruncation = [&](Value *Op) {
    if (auto *Cast = dyn_cast<CastInst>(Op)) {
      if (Cast->getParent() == Mul->getParent() &&
          (Cast->getOpcode() == Instruction::SExt ||
           Cast->getOpcode() == Instruction::ZExt) &&
          Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 8)
        return true;
    }
    // TODO: Support constants in ISel.
    return false;
  };

  // (dpbusd (zext a), (sext b)). The first operand must be an unsigned value,
  // so check that LHS is zero extended. RHS must be a signed value, so just
  // check its significant bits.
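  // Illustrative shape of the match (the element count is just an example):
  //   %lhs = zext <16 x i8> %a to <16 x i32>
  //   %rhs = sext <16 x i8> %b to <16 x i32>
  //   %mul = mul <16 x i32> %lhs, %rhs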
  if ((IsFreeTruncation(LHS) &&
       computeKnownBits(LHS, *DL).countMaxActiveBits() <= 8) &&
      (IsFreeTruncation(RHS) && ComputeMaxSignificantBits(RHS, *DL) <= 8))
    return true;

  return false;
}

bool X86PartialReduction::tryMAddReplacement(Instruction *Op,
                                             bool ReduceInOneBB) {
  if (!ST->hasSSE2())
    return false;

  // Need at least 8 elements.
  if (cast<FixedVectorType>(Op->getType())->getNumElements() < 8)
    return false;

  // Element type should be i32.
  if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32))
    return false;

  auto *Mul = dyn_cast<BinaryOperator>(Op);
  if (!Mul || Mul->getOpcode() != Instruction::Mul)
    return false;

  Value *LHS = Mul->getOperand(0);
  Value *RHS = Mul->getOperand(1);
  // If the target supports VNNI, leave it to ISel to combine the reduce
  // operation into a VNNI instruction.
  // TODO: we could also handle a cross-block reduce by emitting the VNNI
  // intrinsic in this pass.
  if (ReduceInOneBB && matchVPDPBUSDPattern(ST, Mul, DL))
    return false;

  // LHS and RHS should only be used once, or if they are the same then only
  // used twice. Only check this when SSE4.1 is enabled and we have zext/sext
  // instructions, otherwise we use punpck to emulate zero extend in stages.
  // The trunc we need to do likely won't introduce new instructions in that
  // case.
  if (ST->hasSSE41()) {
    if (LHS == RHS) {
      if (!isa<Constant>(LHS) && !LHS->hasNUses(2))
        return false;
    } else {
      if (!isa<Constant>(LHS) && !LHS->hasOneUse())
        return false;
      if (!isa<Constant>(RHS) && !RHS->hasOneUse())
        return false;
    }
  }

  auto CanShrinkOp = [&](Value *Op) {
    auto IsFreeTruncation = [&](Value *Op) {
      if (auto *Cast = dyn_cast<CastInst>(Op)) {
        if (Cast->getParent() == Mul->getParent() &&
            (Cast->getOpcode() == Instruction::SExt ||
             Cast->getOpcode() == Instruction::ZExt) &&
            Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 16)
          return true;
      }

      return isa<Constant>(Op);
    };
    // If the operand can be freely truncated and has enough sign bits we can
    // shrink it.
    if (IsFreeTruncation(Op) &&
        ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16)
      return true;

    // SelectionDAG has limited support for truncating through an add or sub if
    // the inputs are freely truncatable.
    if (auto *BO = dyn_cast<BinaryOperator>(Op)) {
      if (BO->getParent() == Mul->getParent() &&
          IsFreeTruncation(BO->getOperand(0)) &&
          IsFreeTruncation(BO->getOperand(1)) &&
          ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16)
        return true;
    }

    return false;
  };

  // Both Ops need to be shrinkable.
  if (!CanShrinkOp(LHS) && !CanShrinkOp(RHS))
    return false;

  IRBuilder<> Builder(Mul);

  auto *MulTy = cast<FixedVectorType>(Op->getType());
  unsigned NumElts = MulTy->getNumElements();

  // Extract even elements and odd elements and add them together. This will
  // be pattern matched by SelectionDAG to pmaddwd. This instruction will be
  // half the original width.
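  // E.g. for 8 elements the rewrite produces (an illustrative sketch):
  //   %even = shufflevector <8 x i32> %mul, <8 x i32> %mul,
  //           <4 x i32> <i32 0, i32 2, i32 4, i32 6>
  //   %odd  = shufflevector <8 x i32> %mul, <8 x i32> %mul,
  //           <4 x i32> <i32 1, i32 3, i32 5, i32 7>
  //   %madd = add <4 x i32> %even, %odd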
  SmallVector<int, 16> EvenMask(NumElts / 2);
  SmallVector<int, 16> OddMask(NumElts / 2);
  for (int i = 0, e = NumElts / 2; i != e; ++i) {
    EvenMask[i] = i * 2;
    OddMask[i] = i * 2 + 1;
  }
  // Creating a new mul so the replaceAllUsesWith below doesn't replace the
  // uses in the shuffles we're creating.
  Value *NewMul = Builder.CreateMul(Mul->getOperand(0), Mul->getOperand(1));
  Value *EvenElts = Builder.CreateShuffleVector(NewMul, NewMul, EvenMask);
  Value *OddElts = Builder.CreateShuffleVector(NewMul, NewMul, OddMask);
  Value *MAdd = Builder.CreateAdd(EvenElts, OddElts);

  // Concatenate zeroes to extend back to the original type.
  SmallVector<int, 32> ConcatMask(NumElts);
  std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
  Value *Zero = Constant::getNullValue(MAdd->getType());
  Value *Concat = Builder.CreateShuffleVector(MAdd, Zero, ConcatMask);

  Mul->replaceAllUsesWith(Concat);
  Mul->eraseFromParent();

  return true;
}

bool X86PartialReduction::trySADReplacement(Instruction *Op) {
  if (!ST->hasSSE2())
    return false;
  // TODO: There's nothing special about i32, any integer type wider than i16
  // should work just as well.
  if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32))
    return false;

  // Operand should be a select.
  auto *SI = dyn_cast<SelectInst>(Op);
  if (!SI)
    return false;

  // Select needs to implement absolute value.
  Value *LHS, *RHS;
  auto SPR = matchSelectPattern(SI, LHS, RHS);
  if (SPR.Flavor != SPF_ABS)
    return false;

  // Need a subtract of two values.
  auto *Sub = dyn_cast<BinaryOperator>(LHS);
  if (!Sub || Sub->getOpcode() != Instruction::Sub)
    return false;

  // Look for zero extend from i8.
  auto getZeroExtendedVal = [](Value *Op) -> Value * {
    if (auto *ZExt = dyn_cast<ZExtInst>(Op))
      if (cast<VectorType>(ZExt->getOperand(0)->getType())
              ->getElementType()
              ->isIntegerTy(8))
        return ZExt->getOperand(0);

    return nullptr;
  };

  // Both operands of the subtract should be zero extends from vXi8.
  Value *Op0 = getZeroExtendedVal(Sub->getOperand(0));
  Value *Op1 = getZeroExtendedVal(Sub->getOperand(1));
  if (!Op0 || !Op1)
    return false;
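  // At this point we have matched an absolute-difference pattern, e.g.
  // (an illustrative sketch; the element count is an assumption):
  //   %sub = sub <16 x i32> %zext.a, %zext.b   ; both zext from <16 x i8>
  //   %neg = sub <16 x i32> zeroinitializer, %sub
  //   %cmp = icmp sgt <16 x i32> %sub, <i32 -1, i32 -1, ...>
  //   %sel = select <16 x i1> %cmp, <16 x i32> %sub, <16 x i32> %neg
  // which psadbw can compute directly on the vXi8 inputs, summing the
  // absolute differences of each 8-byte group into an i64.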

  IRBuilder<> Builder(SI);

  auto *OpTy = cast<FixedVectorType>(Op->getType());
  unsigned NumElts = OpTy->getNumElements();

  unsigned IntrinsicNumElts;
  Intrinsic::ID IID;
  if (ST->hasBWI() && NumElts >= 64) {
    IID = Intrinsic::x86_avx512_psad_bw_512;
    IntrinsicNumElts = 64;
  } else if (ST->hasAVX2() && NumElts >= 32) {
    IID = Intrinsic::x86_avx2_psad_bw;
    IntrinsicNumElts = 32;
  } else {
    IID = Intrinsic::x86_sse2_psad_bw;
    IntrinsicNumElts = 16;
  }

  Function *PSADBWFn = Intrinsic::getDeclaration(SI->getModule(), IID);

  if (NumElts < 16) {
    // Pad input with zeroes.
    SmallVector<int, 32> ConcatMask(16);
    for (unsigned i = 0; i != NumElts; ++i)
      ConcatMask[i] = i;
    for (unsigned i = NumElts; i != 16; ++i)
      ConcatMask[i] = (i % NumElts) + NumElts;

    Value *Zero = Constant::getNullValue(Op0->getType());
    Op0 = Builder.CreateShuffleVector(Op0, Zero, ConcatMask);
    Op1 = Builder.CreateShuffleVector(Op1, Zero, ConcatMask);
    NumElts = 16;
  }
  // The intrinsics produce vXi64, which needs to be cast to vXi32.
  auto *I32Ty =
      FixedVectorType::get(Builder.getInt32Ty(), IntrinsicNumElts / 4);

  assert(NumElts % IntrinsicNumElts == 0 && "Unexpected number of elements!");
  unsigned NumSplits = NumElts / IntrinsicNumElts;

  // First collect the pieces we need.
  SmallVector<Value *, 4> Ops(NumSplits);
  for (unsigned i = 0; i != NumSplits; ++i) {
    SmallVector<int, 64> ExtractMask(IntrinsicNumElts);
    std::iota(ExtractMask.begin(), ExtractMask.end(), i * IntrinsicNumElts);
    Value *ExtractOp0 = Builder.CreateShuffleVector(Op0, Op0, ExtractMask);
    Value *ExtractOp1 = Builder.CreateShuffleVector(Op1, Op1, ExtractMask);
    Ops[i] = Builder.CreateCall(PSADBWFn, {ExtractOp0, ExtractOp1});
    Ops[i] = Builder.CreateBitCast(Ops[i], I32Ty);
  }

  assert(isPowerOf2_32(NumSplits) && "Expected power of 2 splits");
  unsigned Stages = Log2_32(NumSplits);
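  // Merge the pieces back together pairwise: each stage concatenates two
  // vectors into one of twice the width. E.g. with 4 pieces (illustrative):
  //   stage 2: P01 = concat(P0, P1), P23 = concat(P2, P3)
  //   stage 1: P = concat(P01, P23)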
  for (unsigned s = Stages; s > 0; --s) {
    unsigned NumConcatElts =
        cast<FixedVectorType>(Ops[0]->getType())->getNumElements() * 2;
    for (unsigned i = 0; i != 1U << (s - 1); ++i) {
      SmallVector<int, 64> ConcatMask(NumConcatElts);
      std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
      Ops[i] = Builder.CreateShuffleVector(Ops[i*2], Ops[i*2+1], ConcatMask);
    }
  }
  // At this point the final value should be in Ops[0]. Now we need to adjust
  // it back to the original type.
  NumElts = cast<FixedVectorType>(OpTy)->getNumElements();
  if (NumElts == 2) {
    // Extract down to 2 elements.
    Ops[0] = Builder.CreateShuffleVector(Ops[0], Ops[0], ArrayRef<int>{0, 1});
  } else if (NumElts >= 8) {
    SmallVector<int, 32> ConcatMask(NumElts);
    unsigned SubElts =
        cast<FixedVectorType>(Ops[0]->getType())->getNumElements();
    for (unsigned i = 0; i != SubElts; ++i)
      ConcatMask[i] = i;
    for (unsigned i = SubElts; i != NumElts; ++i)
      ConcatMask[i] = (i % SubElts) + SubElts;

    Value *Zero = Constant::getNullValue(Ops[0]->getType());
    Ops[0] = Builder.CreateShuffleVector(Ops[0], Zero, ConcatMask);
  }

  SI->replaceAllUsesWith(Ops[0]);
  SI->eraseFromParent();

  return true;
}

// Walk backwards from the ExtractElementInst and determine if it is the end of
// a horizontal reduction. Return the input to the reduction if we find one.
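// The expected shape is a log2(N) shuffle/add pyramid, e.g. for <4 x i32>
// (an illustrative sketch):
//   %shuf1 = shufflevector <4 x i32> %v, <4 x i32> poison,
//            <4 x i32> <i32 2, i32 3, i32 undef, i32 undef>
//   %add1  = add <4 x i32> %v, %shuf1
//   %shuf2 = shufflevector <4 x i32> %add1, <4 x i32> poison,
//            <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
//   %add2  = add <4 x i32> %add1, %shuf2
//   %res   = extractelement <4 x i32> %add2, i32 0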
static Value *matchAddReduction(const ExtractElementInst &EE,
                                bool &ReduceInOneBB) {
  ReduceInOneBB = true;
  // Make sure we're extracting index 0.
  auto *Index = dyn_cast<ConstantInt>(EE.getIndexOperand());
  if (!Index || !Index->isNullValue())
    return nullptr;

  const auto *BO = dyn_cast<BinaryOperator>(EE.getVectorOperand());
  if (!BO || BO->getOpcode() != Instruction::Add || !BO->hasOneUse())
    return nullptr;
  if (EE.getParent() != BO->getParent())
    ReduceInOneBB = false;

  unsigned NumElems = cast<FixedVectorType>(BO->getType())->getNumElements();
  // Ensure the reduction size is a power of 2.
  if (!isPowerOf2_32(NumElems))
    return nullptr;

  const Value *Op = BO;
  unsigned Stages = Log2_32(NumElems);
  for (unsigned i = 0; i != Stages; ++i) {
    const auto *BO = dyn_cast<BinaryOperator>(Op);
    if (!BO || BO->getOpcode() != Instruction::Add)
      return nullptr;
    if (EE.getParent() != BO->getParent())
      ReduceInOneBB = false;

    // If this isn't the first add, then it should only have 2 users, the
    // shuffle and another add which we checked in the previous iteration.
    if (i != 0 && !BO->hasNUses(2))
      return nullptr;

    Value *LHS = BO->getOperand(0);
    Value *RHS = BO->getOperand(1);

    auto *Shuffle = dyn_cast<ShuffleVectorInst>(LHS);
    if (Shuffle) {
      Op = RHS;
    } else {
      Shuffle = dyn_cast<ShuffleVectorInst>(RHS);
      Op = LHS;
    }

    // The first operand of the shuffle should be the same as the other operand
    // of the bin op.
    if (!Shuffle || Shuffle->getOperand(0) != Op)
      return nullptr;

    // Verify the shuffle has the expected (at this stage of the pyramid) mask.
    unsigned MaskEnd = 1 << i;
    for (unsigned Index = 0; Index < MaskEnd; ++Index)
      if (Shuffle->getMaskValue(Index) != (int)(MaskEnd + Index))
        return nullptr;
  }

  return const_cast<Value *>(Op);
}
// See if this BO is reachable from this Phi by walking forward through single
// use BinaryOperators with the same opcode. If we get back to BO then we know
// we've found a loop and it is safe to step through this Add to find more
// leaves.
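// E.g. (illustrative): if %phi's only user is %a = add %phi, %x, and %a's
// only user is BO, the walk reaches BO and we report the loop.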
static bool isReachableFromPHI(PHINode *Phi, BinaryOperator *BO) {
  // The PHI itself should only have one use.
  if (!Phi->hasOneUse())
    return false;

  Instruction *U = cast<Instruction>(*Phi->user_begin());
  if (U == BO)
    return true;

  while (U->hasOneUse() && U->getOpcode() == BO->getOpcode())
    U = cast<Instruction>(*U->user_begin());

  return U == BO;
}

// Collect all the leaves of the tree of adds that feeds into the horizontal
// reduction. Root is the Value that is used by the horizontal reduction.
// We look through single use phis, single use adds, or adds that are used by
// a phi that forms a loop with the add.
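// A typical loop accumulator case we step through (an illustrative sketch):
//   loop:
//     %acc = phi <8 x i32> [ zeroinitializer, %entry ], [ %sum, %loop ]
//     %sum = add <8 x i32> %acc, %leaf     ; %leaf gets collected
//     ...
//   exit:
//     ; horizontal reduction of %sum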
static void collectLeaves(Value *Root, SmallVectorImpl<Instruction *> &Leaves) {
  SmallPtrSet<Value *, 8> Visited;
  SmallVector<Value *, 8> Worklist;
  Worklist.push_back(Root);

  while (!Worklist.empty()) {
    Value *V = Worklist.pop_back_val();
    if (!Visited.insert(V).second)
      continue;

    if (auto *PN = dyn_cast<PHINode>(V)) {
      // A PHI node should have a single use unless it is the root node, in
      // which case it has 2 uses.
      if (!PN->hasNUses(PN == Root ? 2 : 1))
        break;

      // Push incoming values to the worklist.
      append_range(Worklist, PN->incoming_values());

      continue;
    }

    if (auto *BO = dyn_cast<BinaryOperator>(V)) {
      if (BO->getOpcode() == Instruction::Add) {
        // Simple case. Single use, just push its operands to the worklist.
        if (BO->hasNUses(BO == Root ? 2 : 1)) {
          append_range(Worklist, BO->operands());
          continue;
        }

        // If there is an additional use, make sure it is an unvisited phi
        // that gets us back to this node.
        if (BO->hasNUses(BO == Root ? 3 : 2)) {
          PHINode *PN = nullptr;
          for (auto *U : BO->users())
            if (auto *P = dyn_cast<PHINode>(U))
              if (!Visited.count(P))
                PN = P;

          // If we didn't find a 2-input PHI then this isn't a case we can
          // handle.
          if (!PN || PN->getNumIncomingValues() != 2)
            continue;

          // Walk forward from this phi to see if it reaches back to this add.
          if (!isReachableFromPHI(PN, BO))
            continue;

          // The phi forms a loop with this Add, push its operands.
          append_range(Worklist, BO->operands());
        }
      }
    }

    // Not an add or phi, make it a leaf.
    if (auto *I = dyn_cast<Instruction>(V)) {
      if (!V->hasNUses(I == Root ? 2 : 1))
        continue;

      // Add this as a leaf.
      Leaves.push_back(I);
    }
  }
}
bool X86PartialReduction::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
  if (!TPC)
    return false;

  auto &TM = TPC->getTM<X86TargetMachine>();
  ST = TM.getSubtargetImpl(F);

  DL = &F.getParent()->getDataLayout();

  bool MadeChange = false;
  for (auto &BB : F) {
    for (auto &I : BB) {
      auto *EE = dyn_cast<ExtractElementInst>(&I);
      if (!EE)
        continue;

      bool ReduceInOneBB;
      // First find a reduction tree.
      // FIXME: Do we need to handle other opcodes than Add?
      Value *Root = matchAddReduction(*EE, ReduceInOneBB);
      if (!Root)
        continue;

      SmallVector<Instruction *, 8> Leaves;
      collectLeaves(Root, Leaves);

      for (Instruction *I : Leaves) {
        if (tryMAddReplacement(I, ReduceInOneBB)) {
          MadeChange = true;
          continue;
        }

        // Don't do SAD matching on the root node. SelectionDAG already
        // has support for that and currently generates better code.
        if (I != Root && trySADReplacement(I))
          MadeChange = true;
      }
    }
  }

  return MadeChange;
}