1 //===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Armv6 introduced instructions to perform 32-bit SIMD operations. The
11 /// purpose of this pass is do some IR pattern matching to create ACLE
12 /// DSP intrinsics, which map on these 32-bit SIMD operations.
13 /// This pass runs only when unaligned accesses is supported/enabled.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/Analysis/AliasAnalysis.h"
20 #include "llvm/Analysis/LoopAccessAnalysis.h"
21 #include "llvm/Analysis/LoopPass.h"
22 #include "llvm/Analysis/LoopInfo.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/NoFolder.h"
25 #include "llvm/Transforms/Scalar.h"
26 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
27 #include "llvm/Transforms/Utils/LoopUtils.h"
28 #include "llvm/Pass.h"
29 #include "llvm/PassRegistry.h"
30 #include "llvm/PassSupport.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/IR/PatternMatch.h"
33 #include "llvm/CodeGen/TargetPassConfig.h"
34 #include "ARM.h"
35 #include "ARMSubtarget.h"
36 
37 using namespace llvm;
38 using namespace PatternMatch;
39 
40 #define DEBUG_TYPE "arm-parallel-dsp"
41 
// Counts every SMLAD/SMLALD (and exchange-variant) intrinsic call created by
// this pass; bumped in CreateSMLADCall.
STATISTIC(NumSMLAD , "Number of smlad instructions generated");

// Command-line escape hatch to disable the transformation without changing
// the pass pipeline.
static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));
47 
namespace {
  struct OpChain;
  struct BinOpChain;
  struct Reduction;

  // Container shorthands used throughout the pass.
  using OpChainList     = SmallVector<std::unique_ptr<OpChain>, 8>;
  using ReductionList   = SmallVector<Reduction, 8>;
  using ValueList       = SmallVector<Value*, 8>;
  using MemInstList     = SmallVector<LoadInst*, 8>;
  using PMACPair        = std::pair<BinOpChain*,BinOpChain*>;
  using PMACPairList    = SmallVector<PMACPair, 8>;
  using Instructions    = SmallVector<Instruction*,16>;
  using MemLocList      = SmallVector<MemoryLocation, 4>;
62   struct OpChain {
63     Instruction   *Root;
64     ValueList     AllValues;
65     MemInstList   VecLd;    // List of all load instructions.
66     MemLocList    MemLocs;  // All memory locations read by this tree.
67     bool          ReadOnly = true;
68 
69     OpChain(Instruction *I, ValueList &vl) : Root(I), AllValues(vl) { }
70     virtual ~OpChain() = default;
71 
72     void SetMemoryLocations() {
73       const auto Size = LocationSize::unknown();
74       for (auto *V : AllValues) {
75         if (auto *I = dyn_cast<Instruction>(V)) {
76           if (I->mayWriteToMemory())
77             ReadOnly = false;
78           if (auto *Ld = dyn_cast<LoadInst>(V))
79             MemLocs.push_back(MemoryLocation(Ld->getPointerOperand(), Size));
80         }
81       }
82     }
83 
84     unsigned size() const { return AllValues.size(); }
85   };
86 
87   // 'BinOpChain' and 'Reduction' are just some bookkeeping data structures.
88   // 'Reduction' contains the phi-node and accumulator statement from where we
89   // start pattern matching, and 'BinOpChain' the multiplication
90   // instructions that are candidates for parallel execution.
91   struct BinOpChain : public OpChain {
92     ValueList     LHS;      // List of all (narrow) left hand operands.
93     ValueList     RHS;      // List of all (narrow) right hand operands.
94     bool Exchange = false;
95 
96     BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
97       OpChain(I, lhs), LHS(lhs), RHS(rhs) {
98         for (auto *V : RHS)
99           AllValues.push_back(V);
100       }
101 
102     bool AreSymmetrical(BinOpChain *Other);
103   };
104 
  struct Reduction {
    PHINode         *Phi;             // The Phi-node from where we start
                                      // pattern matching.
    Instruction     *AccIntAdd;       // The accumulating integer add statement,
                                      // i.e, the reduction statement.
    OpChainList     MACCandidates;    // The MAC candidates associated with
                                      // this reduction statement.
    PMACPairList    PMACPairs;        // Candidate pairs that can execute in
                                      // parallel, filled in by
                                      // CreateParallelMACPairs.
    Reduction (PHINode *P, Instruction *Acc) : Phi(P), AccIntAdd(Acc) { };
  };
115 
116   class WidenedLoad {
117     LoadInst *NewLd = nullptr;
118     SmallVector<LoadInst*, 4> Loads;
119 
120   public:
121     WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
122       : NewLd(Wide) {
123       for (auto *I : Lds)
124         Loads.push_back(I);
125     }
126     LoadInst *getLoad() {
127       return NewLd;
128     }
129   };
130 
  // The pass itself: for single-block inner loops, find integer add
  // reductions of 16-bit multiplies and rewrite them as ARM DSP
  // SMLAD/SMLALD intrinsic calls, widening the feeding loads on the way.
  class ARMParallelDSP : public LoopPass {
    // Analyses and context, cached at the start of runOnLoop.
    ScalarEvolution   *SE;
    AliasAnalysis     *AA;
    TargetLibraryInfo *TLI;
    DominatorTree     *DT;
    LoopInfo          *LI;
    Loop              *L;
    const DataLayout  *DL;
    Module            *M;
    // Maps a load to the consecutive load that follows it in memory, as
    // discovered by RecordSequentialLoads.
    std::map<LoadInst*, LoadInst*> LoadPairs;
    // Pairs already widened, keyed by the first narrow load of the pair.
    std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

    bool RecordSequentialLoads(BasicBlock *BB);
    bool InsertParallelMACs(Reduction &Reduction);
    bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
    LoadInst* CreateLoadIns(IRBuilder<NoFolder> &IRB,
                            SmallVectorImpl<LoadInst*> &Loads,
                            IntegerType *LoadTy);
    void CreateParallelMACPairs(Reduction &R);
    Instruction *CreateSMLADCall(SmallVectorImpl<LoadInst*> &VecLd0,
                                 SmallVectorImpl<LoadInst*> &VecLd1,
                                 Instruction *Acc, bool Exchange,
                                 Instruction *InsertAfter);

    /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
    /// Dual performs two signed 16x16-bit multiplications. It adds the
    /// products to a 32-bit accumulate operand. Optionally, the instruction can
    /// exchange the halfwords of the second operand before performing the
    /// arithmetic.
    bool MatchSMLAD(Function &F);

  public:
    static char ID;

    ARMParallelDSP() : LoopPass(ID) { }

    void getAnalysisUsage(AnalysisUsage &AU) const override {
      LoopPass::getAnalysisUsage(AU);
      AU.addRequired<AssumptionCacheTracker>();
      AU.addRequired<ScalarEvolutionWrapperPass>();
      AU.addRequired<AAResultsWrapperPass>();
      AU.addRequired<TargetLibraryInfoWrapperPass>();
      AU.addRequired<LoopInfoWrapperPass>();
      AU.addRequired<DominatorTreeWrapperPass>();
      AU.addRequired<TargetPassConfig>();
      AU.addPreserved<LoopInfoWrapperPass>();
      AU.setPreservesCFG();
    }

    bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
      if (DisableParallelDSP)
        return false;
      // Cache the loop and all required analyses.
      L = TheLoop;
      SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
      AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
      TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
      DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
      LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
      auto &TPC = getAnalysis<TargetPassConfig>();

      BasicBlock *Header = TheLoop->getHeader();
      if (!Header)
        return false;

      // TODO: We assume the loop header and latch to be the same block.
      // This is not a fundamental restriction, but lifting this would just
      // require more work to do the transformation and then patch up the CFG.
      if (Header != TheLoop->getLoopLatch()) {
        LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      // We need a preheader as getIncomingValueForBlock assumes there is one.
      if (!TheLoop->getLoopPreheader()) {
        LLVM_DEBUG(dbgs() << "No preheader found, bailing out\n");
        return false;
      }

      Function &F = *Header->getParent();
      M = F.getParent();
      DL = &M->getDataLayout();

      auto &TM = TPC.getTM<TargetMachine>();
      auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

      // The widened loads this pass creates may be unaligned, so the
      // subtarget must tolerate unaligned accesses.
      if (!ST->allowsUnalignedMem()) {
        LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      if (!ST->hasDSP()) {
        LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      if (!ST->isLittle()) {
        LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      // NOTE(review): LAI is constructed but never consulted below —
      // confirm whether it is still needed.
      LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);

      LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
      LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");

      // Without at least two consecutive load pairs there is nothing to
      // widen, so bail early.
      if (!RecordSequentialLoads(Header)) {
        LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
        return false;
      }

      bool Changes = MatchSMLAD(F);
      return Changes;
    }
  };
249 }
250 
251 // MaxBitwidth: the maximum supported bitwidth of the elements in the DSP
252 // instructions, which is set to 16. So here we should collect all i8 and i16
253 // narrow operations.
254 // TODO: we currently only collect i16, and will support i8 later, so that's
255 // why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
// Returns true when V terminates a "narrow sequence": a zext/sext whose
// source is MaxBitWidth wide and is a load, possibly reached through a trunc
// to MaxBitWidth. On success the load and the extend are appended to VL; any
// other shape returns false and may leave VL partially filled.
template<unsigned MaxBitWidth>
static bool IsNarrowSequence(Value *V, ValueList &VL) {
  ConstantInt *CInt;

  if (match(V, m_ConstantInt(CInt))) {
    // TODO: if a constant is used, it needs to fit within the bit width.
    return false;
  }

  auto *I = dyn_cast<Instruction>(V);
  if (!I)
   return false;

  Value *Val, *LHS, *RHS;
  // A trunc down to the narrow width is looked through; the sequence check
  // continues on the trunc's operand.
  if (match(V, m_Trunc(m_Value(Val)))) {
    if (cast<TruncInst>(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth)
      return IsNarrowSequence<MaxBitWidth>(Val, VL);
  } else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) {
    // TODO: we need to implement sadd16/sadd8 for this, which enables to
    // also do the rewrite for smlad8.ll, but it is unsupported for now.
    return false;
  } else if (match(V, m_ZExtOrSExt(m_Value(Val)))) {
    if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
      return false;

    // Only an extend fed directly by a load qualifies; record both.
    if (match(Val, m_Load(m_Value()))) {
      VL.push_back(Val);
      VL.push_back(I);
      return true;
    }
  }
  return false;
}
289 
290 template<typename MemInst>
291 static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
292                                   const DataLayout &DL, ScalarEvolution &SE) {
293   if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE))
294     return true;
295   return false;
296 }
297 
298 bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
299                                         MemInstList &VecMem) {
300   if (!Ld0 || !Ld1)
301     return false;
302 
303   if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
304     return false;
305 
306   LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
307     dbgs() << "Ld0:"; Ld0->dump();
308     dbgs() << "Ld1:"; Ld1->dump();
309   );
310 
311   VecMem.clear();
312   VecMem.push_back(Ld0);
313   VecMem.push_back(Ld1);
314   return true;
315 }
316 
317 /// Iterate through the block and record base, offset pairs of loads as well as
318 /// maximal sequences of sequential loads.
319 bool ARMParallelDSP::RecordSequentialLoads(BasicBlock *BB) {
320   SmallVector<LoadInst*, 8> Loads;
321   for (auto &I : *BB) {
322     auto *Ld = dyn_cast<LoadInst>(&I);
323     if (!Ld || !Ld->isSimple() ||
324         !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
325       continue;
326     Loads.push_back(Ld);
327   }
328 
329   for (auto *Ld0 : Loads) {
330     for (auto *Ld1 : Loads) {
331       if (Ld0 == Ld1)
332         continue;
333 
334       if (AreSequentialAccesses<LoadInst>(Ld0, Ld1, *DL, *SE)) {
335         LoadPairs[Ld0] = Ld1;
336         break;
337       }
338     }
339   }
340 
341   LLVM_DEBUG(if (!LoadPairs.empty()) {
342                dbgs() << "Consecutive load pairs:\n";
343                for (auto &MapIt : LoadPairs) {
344                  LLVM_DEBUG(dbgs() << *MapIt.first << ", "
345                             << *MapIt.second << "\n");
346                }
347              });
348   return LoadPairs.size() > 1;
349 }
350 
// Greedily pair up the reduction's MAC candidates: two mul chains can run as
// one dual MAC when their operand lists are symmetrical and their feeding
// loads form consecutive pairs. Successful pairs are appended to R.PMACPairs.
void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) {
  OpChainList &Candidates = R.MACCandidates;
  PMACPairList &PMACPairs = R.PMACPairs;
  const unsigned Elems = Candidates.size();

  // Pairing requires at least two candidates.
  if (Elems < 2)
    return;

  // Decide whether PMul0 and PMul1 can execute as a dual MAC; records the
  // pair (and any required operand exchange) on success.
  auto CanPair = [&](BinOpChain *PMul0, BinOpChain *PMul1) {
    if (!PMul0->AreSymmetrical(PMul1))
      return false;

    // The first elements of each vector should be loads with sexts. If we
    // find that its two pairs of consecutive loads, then these can be
    // transformed into two wider loads and the users can be replaced with
    // DSP intrinsics.
    for (unsigned x = 0; x < PMul0->LHS.size(); x += 2) {
      auto *Ld0 = dyn_cast<LoadInst>(PMul0->LHS[x]);
      auto *Ld1 = dyn_cast<LoadInst>(PMul1->LHS[x]);
      auto *Ld2 = dyn_cast<LoadInst>(PMul0->RHS[x]);
      auto *Ld3 = dyn_cast<LoadInst>(PMul1->RHS[x]);

      if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
        return false;

      LLVM_DEBUG(dbgs() << "Loads:\n"
                 << " - " << *Ld0 << "\n"
                 << " - " << *Ld1 << "\n"
                 << " - " << *Ld2 << "\n"
                 << " - " << *Ld3 << "\n");

      if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
        if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
          LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
          PMACPairs.push_back(std::make_pair(PMul0, PMul1));
          return true;
        } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
          LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
          LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
          // The second mul reads its pair reversed; the exchanging
          // intrinsic variant (SMLADX) handles that.
          PMul1->Exchange = true;
          PMACPairs.push_back(std::make_pair(PMul0, PMul1));
          return true;
        }
      } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
                 AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
        LLVM_DEBUG(dbgs() << "    and swapping muls\n");
        PMul0->Exchange = true;
        // Only the second operand can be exchanged, so swap the muls.
        PMACPairs.push_back(std::make_pair(PMul1, PMul0));
        return true;
      }
    }
    return false;
  };

  // Greedy search: pair each candidate with the first compatible,
  // not-yet-paired candidate.
  SmallPtrSet<const Instruction*, 4> Paired;
  for (unsigned i = 0; i < Elems; ++i) {
    BinOpChain *PMul0 = static_cast<BinOpChain*>(Candidates[i].get());
    if (Paired.count(PMul0->Root))
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      BinOpChain *PMul1 = static_cast<BinOpChain*>(Candidates[j].get());
      if (Paired.count(PMul1->Root))
        continue;

      const Instruction *Mul0 = PMul0->Root;
      const Instruction *Mul1 = PMul1->Root;
      if (Mul0 == Mul1)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(PMul0, PMul1)) {
        Paired.insert(Mul0);
        Paired.insert(Mul1);
        break;
      }
    }
  }
}
437 
438 bool ARMParallelDSP::InsertParallelMACs(Reduction &Reduction) {
439   Instruction *Acc = Reduction.Phi;
440   Instruction *InsertAfter = Reduction.AccIntAdd;
441 
442   for (auto &Pair : Reduction.PMACPairs) {
443     BinOpChain *PMul0 = Pair.first;
444     BinOpChain *PMul1 = Pair.second;
445     LLVM_DEBUG(dbgs() << "Found parallel MACs!!\n";
446                dbgs() << "- "; PMul0->Root->dump();
447                dbgs() << "- "; PMul1->Root->dump());
448 
449     Acc = CreateSMLADCall(PMul0->VecLd, PMul1->VecLd, Acc, PMul1->Exchange,
450                           InsertAfter);
451     InsertAfter = Acc;
452   }
453 
454   if (Acc != Reduction.Phi) {
455     LLVM_DEBUG(dbgs() << "Replace Accumulate: "; Acc->dump());
456     Reduction.AccIntAdd->replaceAllUsesWith(Acc);
457     return true;
458   }
459   return false;
460 }
461 
// Scan the header phis for integer add reductions over i32 or i64 — the
// accumulator widths the SMLAD/SMLALD intrinsics support — and record each
// as a Reduction: the phi plus the accumulating add feeding it from the
// latch.
static void MatchReductions(Function &F, Loop *TheLoop, BasicBlock *Header,
                            ReductionList &Reductions) {
  RecurrenceDescriptor RecDesc;
  const bool HasFnNoNaNAttr =
    F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
  const BasicBlock *Latch = TheLoop->getLoopLatch();

  for (PHINode &Phi : Header->phis()) {
    const auto *Ty = Phi.getType();
    if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
      continue;

    // Let the vectorizer's recurrence analysis confirm this phi is an
    // integer add reduction.
    const bool IsReduction =
      RecurrenceDescriptor::AddReductionVar(&Phi,
                                            RecurrenceDescriptor::RK_IntegerAdd,
                                            TheLoop, HasFnNoNaNAttr, RecDesc);
    if (!IsReduction)
      continue;

    // The accumulating statement is the value the phi receives from the
    // latch.
    Instruction *Acc = dyn_cast<Instruction>(Phi.getIncomingValueForBlock(Latch));
    if (!Acc)
      continue;

    Reductions.push_back(Reduction(&Phi, Acc));
  }

  LLVM_DEBUG(
    dbgs() << "\nAccumulating integer additions (reductions) found:\n";
    for (auto &R : Reductions) {
      dbgs() << "-  "; R.Phi->dump();
      dbgs() << "-> "; R.AccIntAdd->dump();
    }
  );
}
496 
497 static void AddMACCandidate(OpChainList &Candidates,
498                             Instruction *Mul,
499                             Value *MulOp0, Value *MulOp1) {
500   assert(Mul->getOpcode() == Instruction::Mul &&
501          "expected mul instruction");
502   ValueList LHS;
503   ValueList RHS;
504   if (IsNarrowSequence<16>(MulOp0, LHS) &&
505       IsNarrowSequence<16>(MulOp1, RHS)) {
506     Candidates.push_back(make_unique<BinOpChain>(Mul, LHS, RHS));
507   }
508 }
509 
// Walk the add/sext tree rooted at the reduction's accumulating add and
// collect every mul of two sign-extended operands as a MAC candidate.
static void MatchParallelMACSequences(Reduction &R,
                                      OpChainList &Candidates) {
  Instruction *Acc = R.AccIntAdd;
  LLVM_DEBUG(dbgs() << "\n- Analysing:\t" << *Acc << "\n");

  // Returns false to signal the search should be stopped.
  std::function<bool(Value*)> Match =
    [&Candidates, &Match](Value *V) -> bool {

    auto *I = dyn_cast<Instruction>(V);
    if (!I)
      return false;

    switch (I->getOpcode()) {
    case Instruction::Add:
      // Recurse into both operands of an add (the || short-circuits, but
      // candidates are only collected on paths that return false, so in
      // practice both operands get visited).
      if (Match(I->getOperand(0)) || (Match(I->getOperand(1))))
        return true;
      break;
    case Instruction::Mul: {
      // A mul of two sexts is a candidate leaf; don't search past it.
      Value *MulOp0 = I->getOperand(0);
      Value *MulOp1 = I->getOperand(1);
      if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1))
        AddMACCandidate(Candidates, I, MulOp0, MulOp1);
      return false;
    }
    case Instruction::SExt:
      // Look through sign extends.
      return Match(I->getOperand(0));
    }
    return false;
  };

  while (Match (Acc));
  LLVM_DEBUG(dbgs() << "Finished matching MAC sequences, found "
             << Candidates.size() << " candidates.\n");
}
545 
546 // Collects all instructions that are not part of the MAC chains, which is the
547 // set of instructions that can potentially alias with the MAC operands.
548 static void AliasCandidates(BasicBlock *Header, Instructions &Reads,
549                             Instructions &Writes) {
550   for (auto &I : *Header) {
551     if (I.mayReadFromMemory())
552       Reads.push_back(&I);
553     if (I.mayWriteToMemory())
554       Writes.push_back(&I);
555   }
556 }
557 
// Check whether statements in the basic block that write to memory alias with
// the memory locations accessed by the MAC-chains.
// TODO: we need the read statements when we accept more complicated chains.
static bool AreAliased(AliasAnalysis *AA, Instructions &Reads,
                       Instructions &Writes, OpChainList &MACCandidates) {
  LLVM_DEBUG(dbgs() << "Alias checks:\n");
  for (auto &MAC : MACCandidates) {
    LLVM_DEBUG(dbgs() << "mul: "; MAC->Root->dump());

    // At the moment, we allow only simple chains that only consist of reads,
    // accumulate their result with an integer add, and thus that don't write
    // memory, and simply bail if they do.
    if (!MAC->ReadOnly)
      return true;

    // Now for all writes in the basic block, check that they don't alias with
    // the memory locations accessed by our MAC-chain:
    for (auto *I : Writes) {
      LLVM_DEBUG(dbgs() << "- "; I->dump());
      assert(MAC->MemLocs.size() >= 2 && "expecting at least 2 memlocs");
      for (auto &MemLoc : MAC->MemLocs) {
        // Any mod or ref of a chain's memory location by a writer means the
        // transformation would be unsafe.
        if (isModOrRefSet(intersectModRef(AA->getModRefInfo(I, MemLoc),
                                          ModRefInfo::ModRef))) {
          LLVM_DEBUG(dbgs() << "Yes, aliases found\n");
          return true;
        }
      }
    }
  }

  LLVM_DEBUG(dbgs() << "OK: no aliases found!\n");
  return false;
}
591 
592 static bool CheckMACMemory(OpChainList &Candidates) {
593   for (auto &C : Candidates) {
594     // A mul has 2 operands, and a narrow op consist of sext and a load; thus
595     // we expect at least 4 items in this operand value list.
596     if (C->size() < 4) {
597       LLVM_DEBUG(dbgs() << "Operand list too short.\n");
598       return false;
599     }
600     C->SetMemoryLocations();
601     ValueList &LHS = static_cast<BinOpChain*>(C.get())->LHS;
602     ValueList &RHS = static_cast<BinOpChain*>(C.get())->RHS;
603 
604     // Use +=2 to skip over the expected extend instructions.
605     for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
606       if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
607         return false;
608     }
609   }
610   return true;
611 }
612 
613 // Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
614 // multiplications.
615 // To use SMLAD:
616 // 1) we first need to find integer add reduction PHIs,
617 // 2) then from the PHI, look for this pattern:
618 //
619 // acc0 = phi i32 [0, %entry], [%acc1, %loop.body]
620 // ld0 = load i16
621 // sext0 = sext i16 %ld0 to i32
622 // ld1 = load i16
623 // sext1 = sext i16 %ld1 to i32
624 // mul0 = mul %sext0, %sext1
625 // ld2 = load i16
626 // sext2 = sext i16 %ld2 to i32
627 // ld3 = load i16
628 // sext3 = sext i16 %ld3 to i32
629 // mul1 = mul i32 %sext2, %sext3
630 // add0 = add i32 %mul0, %acc0
631 // acc1 = add i32 %add0, %mul1
632 //
633 // Which can be selected to:
634 //
635 // ldr.h r0
636 // ldr.h r1
637 // smlad r2, r0, r1, r2
638 //
639 // If constants are used instead of loads, these will need to be hoisted
640 // out and into a register.
641 //
642 // If loop invariants are used instead of loads, these need to be packed
643 // before the loop begins.
644 //
645 bool ARMParallelDSP::MatchSMLAD(Function &F) {
646   BasicBlock *Header = L->getHeader();
647   LLVM_DEBUG(dbgs() << "= Matching SMLAD =\n";
648              dbgs() << "Header block:\n"; Header->dump();
649              dbgs() << "Loop info:\n\n"; L->dump());
650 
651   bool Changed = false;
652   ReductionList Reductions;
653   MatchReductions(F, L, Header, Reductions);
654 
655   for (auto &R : Reductions) {
656     OpChainList MACCandidates;
657     MatchParallelMACSequences(R, MACCandidates);
658     if (!CheckMACMemory(MACCandidates))
659       continue;
660 
661     R.MACCandidates = std::move(MACCandidates);
662 
663     LLVM_DEBUG(dbgs() << "MAC candidates:\n";
664       for (auto &M : R.MACCandidates)
665         M->Root->dump();
666       dbgs() << "\n";);
667   }
668 
669   // Collect all instructions that may read or write memory. Our alias
670   // analysis checks bail out if any of these instructions aliases with an
671   // instruction from the MAC-chain.
672   Instructions Reads, Writes;
673   AliasCandidates(Header, Reads, Writes);
674 
675   for (auto &R : Reductions) {
676     if (AreAliased(AA, Reads, Writes, R.MACCandidates))
677       return false;
678     CreateParallelMACPairs(R);
679     Changed |= InsertParallelMACs(R);
680   }
681 
682   LLVM_DEBUG(if (Changed) dbgs() << "Header block:\n"; Header->dump(););
683   return Changed;
684 }
685 
// Replace the two narrow, consecutive loads in Loads with one wide load of
// type LoadTy, rewiring their sext users to extract the low and high halves
// of the wide value. The widening is recorded in WideLoads and the wide load
// is returned.
LoadInst* ARMParallelDSP::CreateLoadIns(IRBuilder<NoFolder> &IRB,
                                        SmallVectorImpl<LoadInst*> &Loads,
                                        IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  // Reuse the first load's address, reinterpreted as a pointer to the wide
  // type; the narrow load's alignment is kept for the wide load.
  const unsigned AddrSpace = Loads[0]->getPointerAddressSpace();
  Value *VecPtr = IRB.CreateBitCast(Loads[0]->getPointerOperand(),
                                    LoadTy->getPointerTo(AddrSpace));
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr,
                                             Loads[0]->getAlignment());
  // Fix up users, Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
  Instruction *SExt0 = dyn_cast<SExtInst>(Loads[0]->user_back());
  Instruction *SExt1 = dyn_cast<SExtInst>(Loads[1]->user_back());

  assert((Loads[0]->hasOneUse() && Loads[1]->hasOneUse() && SExt0 && SExt1) &&
         "Loads should have a single, extending, user");

  // Move Sink (and recursively its same-block users) to just after Source so
  // the extraction code dominates its uses; PHIs and cross-block users are
  // left where they are.
  std::function<void(Instruction*, Instruction*)> MoveAfter =
    [&](Instruction* Source, Instruction* Sink) -> void {
    if (DT->dominates(Source, Sink) ||
        Source->getParent() != Sink->getParent() ||
        isa<PHINode>(Source) || isa<PHINode>(Sink))
      return;

    Sink->moveAfter(Source);
    for (auto &U : Sink->uses())
      MoveAfter(Sink, cast<Instruction>(U.getUser()));
  };

  // From the wide load, create two values that equal the original two loads.
  // Low half: a plain trunc stands in for Loads[0].
  Value *Bottom = IRB.CreateTrunc(WideLoad, Loads[0]->getType());
  SExt0->setOperand(0, Bottom);
  if (auto *I = dyn_cast<Instruction>(Bottom)) {
    I->moveAfter(WideLoad);
    MoveAfter(I, SExt0);
  }

  // High half: shift right by the narrow width, then trunc, stands in for
  // Loads[1].
  IntegerType *Ld1Ty = cast<IntegerType>(Loads[1]->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, Ld1Ty->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  if (auto *I = dyn_cast<Instruction>(Top))
    MoveAfter(WideLoad, I);

  Value *Trunc = IRB.CreateTrunc(Top, Ld1Ty);
  SExt1->setOperand(0, Trunc);
  if (auto *I = dyn_cast<Instruction>(Trunc))
    MoveAfter(I, SExt1);

  // Remember the widening so later SMLAD creation can reuse this wide load.
  WideLoads.emplace(std::make_pair(Loads[0],
                                   make_unique<WidenedLoad>(Loads, WideLoad)));
  return WideLoad;
}
738 
739 Instruction *ARMParallelDSP::CreateSMLADCall(SmallVectorImpl<LoadInst*> &VecLd0,
740                                              SmallVectorImpl<LoadInst*> &VecLd1,
741                                              Instruction *Acc, bool Exchange,
742                                              Instruction *InsertAfter) {
743   LLVM_DEBUG(dbgs() << "Create SMLAD intrinsic using:\n"
744              << "- " << *VecLd0[0] << "\n"
745              << "- " << *VecLd0[1] << "\n"
746              << "- " << *VecLd1[0] << "\n"
747              << "- " << *VecLd1[1] << "\n"
748              << "- " << *Acc << "\n"
749              << "- Exchange: " << Exchange << "\n");
750 
751   IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
752                               ++BasicBlock::iterator(InsertAfter));
753 
754   // Replace the reduction chain with an intrinsic call
755   IntegerType *Ty = IntegerType::get(M->getContext(), 32);
756   LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ?
757     WideLoads[VecLd0[0]]->getLoad() : CreateLoadIns(Builder, VecLd0, Ty);
758   LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ?
759     WideLoads[VecLd1[0]]->getLoad() : CreateLoadIns(Builder, VecLd1, Ty);
760   Value* Args[] = { WideLd0, WideLd1, Acc };
761   Function *SMLAD = nullptr;
762   if (Exchange)
763     SMLAD = Acc->getType()->isIntegerTy(32) ?
764       Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
765       Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
766   else
767     SMLAD = Acc->getType()->isIntegerTy(32) ?
768       Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
769       Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
770   CallInst *Call = Builder.CreateCall(SMLAD, Args);
771   NumSMLAD++;
772   return Call;
773 }
774 
775 // Compare the value lists in Other to this chain.
776 bool BinOpChain::AreSymmetrical(BinOpChain *Other) {
777   // Element-by-element comparison of Value lists returning true if they are
778   // instructions with the same opcode or constants with the same value.
779   auto CompareValueList = [](const ValueList &VL0,
780                              const ValueList &VL1) {
781     if (VL0.size() != VL1.size()) {
782       LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
783                         << VL0.size() << " != " << VL1.size() << "\n");
784       return false;
785     }
786 
787     const unsigned Pairs = VL0.size();
788 
789     for (unsigned i = 0; i < Pairs; ++i) {
790       const Value *V0 = VL0[i];
791       const Value *V1 = VL1[i];
792       const auto *Inst0 = dyn_cast<Instruction>(V0);
793       const auto *Inst1 = dyn_cast<Instruction>(V1);
794 
795       if (!Inst0 || !Inst1)
796         return false;
797 
798       if (Inst0->isSameOperationAs(Inst1))
799         continue;
800 
801       const APInt *C0, *C1;
802       if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) && C0 == C1))
803         return false;
804     }
805 
806     return true;
807   };
808 
809   return CompareValueList(LHS, Other->LHS) &&
810          CompareValueList(RHS, Other->RHS);
811 }
812 
813 Pass *llvm::createARMParallelDSPPass() {
814   return new ARMParallelDSP();
815 }
816 
// Pass identity; its address is what the pass registry keys on.
char ARMParallelDSP::ID = 0;

// Register the pass under the "arm-parallel-dsp" name.
INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                "Transform loops to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                "Transform loops to use DSP intrinsics", false, false)
823