//===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned memory accesses are supported/enabled.
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMSubtarget.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/PassSupport.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");

static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

namespace {
  struct OpChain;
  struct BinOpChain;
  struct Reduction;

  using OpChainList     = SmallVector<std::unique_ptr<OpChain>, 8>;
  using ReductionList   = SmallVector<Reduction, 8>;
  using ValueList       = SmallVector<Value*, 8>;
  using MemInstList     = SmallVector<Instruction*, 8>;
  using PMACPair        = std::pair<BinOpChain*, BinOpChain*>;
  using PMACPairList    = SmallVector<PMACPair, 8>;
  using Instructions    = SmallVector<Instruction*, 16>;
  using MemLocList      = SmallVector<MemoryLocation, 4>;

  struct OpChain {
    Instruction   *Root;
    ValueList     AllValues;
    MemInstList   VecLd;    // List of all load instructions.
    MemLocList    MemLocs;  // All memory locations read by this tree.
    bool          ReadOnly = true;

    OpChain(Instruction *I, ValueList &vl) : Root(I), AllValues(vl) { }
    virtual ~OpChain() = default;

    void SetMemoryLocations() {
      const auto Size = LocationSize::unknown();
      for (auto *V : AllValues) {
        if (auto *I = dyn_cast<Instruction>(V)) {
          if (I->mayWriteToMemory())
            ReadOnly = false;
          if (auto *Ld = dyn_cast<LoadInst>(V))
            MemLocs.push_back(MemoryLocation(Ld->getPointerOperand(), Size));
        }
      }
    }

    unsigned size() const { return AllValues.size(); }
  };

  // 'BinOpChain' and 'Reduction' are just some bookkeeping data structures.
  // 'Reduction' contains the phi-node and accumulator statement from where we
  // start pattern matching, and 'BinOpChain' holds the multiplication
  // instructions that are candidates for parallel execution.
  struct BinOpChain : public OpChain {
    ValueList     LHS;      // List of all (narrow) left hand operands.
    ValueList     RHS;      // List of all (narrow) right hand operands.
    bool Exchange = false;

    BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
      OpChain(I, lhs), LHS(lhs), RHS(rhs) {
        for (auto *V : RHS)
          AllValues.push_back(V);
      }

    bool AreSymmetrical(BinOpChain *Other);
  };

  struct Reduction {
    PHINode         *Phi;             // The Phi-node from where we start
                                      // pattern matching.
    Instruction     *AccIntAdd;       // The accumulating integer add statement,
                                      // i.e., the reduction statement.
    OpChainList     MACCandidates;    // The MAC candidates associated with
                                      // this reduction statement.
    PMACPairList    PMACPairs;
    Reduction(PHINode *P, Instruction *Acc) : Phi(P), AccIntAdd(Acc) { }
  };

  class ARMParallelDSP : public LoopPass {
    ScalarEvolution   *SE;
    AliasAnalysis     *AA;
    TargetLibraryInfo *TLI;
    DominatorTree     *DT;
    LoopInfo          *LI;
    Loop              *L;
    const DataLayout  *DL;
    Module            *M;
    std::map<LoadInst*, LoadInst*> LoadPairs;
    std::map<LoadInst*, SmallVector<LoadInst*, 4>> SequentialLoads;

    bool RecordSequentialLoads(BasicBlock *Header);
    bool InsertParallelMACs(Reduction &Reduction);
    bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
    void CreateParallelMACPairs(Reduction &R);
    Instruction *CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
                                 Instruction *Acc, bool Exchange,
                                 Instruction *InsertAfter);

    /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
    /// Dual performs two signed 16x16-bit multiplications. It adds the
    /// products to a 32-bit accumulate operand. Optionally, the instruction can
    /// exchange the halfwords of the second operand before performing the
    /// arithmetic.
    bool MatchSMLAD(Function &F);

  public:
    static char ID;

    ARMParallelDSP() : LoopPass(ID) { }

    void getAnalysisUsage(AnalysisUsage &AU) const override {
      LoopPass::getAnalysisUsage(AU);
      AU.addRequired<AssumptionCacheTracker>();
      AU.addRequired<ScalarEvolutionWrapperPass>();
      AU.addRequired<AAResultsWrapperPass>();
      AU.addRequired<TargetLibraryInfoWrapperPass>();
      AU.addRequired<LoopInfoWrapperPass>();
      AU.addRequired<DominatorTreeWrapperPass>();
      AU.addRequired<TargetPassConfig>();
      AU.addPreserved<LoopInfoWrapperPass>();
      AU.setPreservesCFG();
    }

    bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
      if (DisableParallelDSP)
        return false;
      L = TheLoop;
      SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
      AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
      TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
      DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
      LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
      auto &TPC = getAnalysis<TargetPassConfig>();

      BasicBlock *Header = TheLoop->getHeader();
      if (!Header)
        return false;

      // TODO: We assume the loop header and latch to be the same block.
      // This is not a fundamental restriction, but lifting this would just
      // require more work to do the transformation and then patch up the CFG.
      if (Header != TheLoop->getLoopLatch()) {
        LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      Function &F = *Header->getParent();
      M = F.getParent();
      DL = &M->getDataLayout();

      auto &TM = TPC.getTM<TargetMachine>();
      auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

      if (!ST->allowsUnalignedMem()) {
        LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      if (!ST->hasDSP()) {
        LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);
      bool Changes = false;

      LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
      LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");

      if (!RecordSequentialLoads(Header)) {
        LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
        return false;
      }

      Changes = MatchSMLAD(F);
      return Changes;
    }
  };
}

// MaxBitWidth: the maximum supported bit width of the elements in the DSP
// instructions, which is set to 16. So here we should collect all i8 and i16
// narrow operations.
// TODO: we currently only collect i16, and will support i8 later, which is
// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
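//
// For example, the kind of narrow sequence accepted below is a sign- or
// zero-extended narrow load feeding a mul:
//
//   %ld = load i16, i16* %addr
//   %sext = sext i16 %ld to i32
//
// Both the load and the extend are recorded in the value list VL.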
template<unsigned MaxBitWidth>
static bool IsNarrowSequence(Value *V, ValueList &VL) {
  LLVM_DEBUG(dbgs() << "Is narrow sequence? "; V->dump());
  ConstantInt *CInt;

  if (match(V, m_ConstantInt(CInt))) {
    // TODO: if a constant is used, it needs to fit within the bit width.
    return false;
  }

  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return false;

  Value *Val, *LHS, *RHS;
  if (match(V, m_Trunc(m_Value(Val)))) {
    if (cast<TruncInst>(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth)
      return IsNarrowSequence<MaxBitWidth>(Val, VL);
  } else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) {
    // TODO: we need to implement sadd16/sadd8 for this, which would also
    // enable the rewrite for smlad8.ll, but it is unsupported for now.
    LLVM_DEBUG(dbgs() << "No, unsupported Op:\t"; I->dump());
    return false;
  } else if (match(V, m_ZExtOrSExt(m_Value(Val)))) {
    if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth) {
      LLVM_DEBUG(dbgs() << "No, wrong SrcTy size: " <<
        cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() << "\n");
      return false;
    }

    if (match(Val, m_Load(m_Value()))) {
      LLVM_DEBUG(dbgs() << "Yes, found narrow Load:\t"; Val->dump());
      VL.push_back(Val);
      VL.push_back(I);
      return true;
    }
  }
  LLVM_DEBUG(dbgs() << "No, unsupported Op:\t"; I->dump());
  return false;
}

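// Two accesses are sequential when the second one starts where the first one
// ends; e.g. an i16 load from %p followed by an i16 load from:
//
//   %p.1 = getelementptr i16, i16* %p, i32 1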
template<typename MemInst>
static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
                                  const DataLayout &DL, ScalarEvolution &SE) {
  if (!MemOp0->isSimple() || !MemOp1->isSimple()) {
    LLVM_DEBUG(dbgs() << "No, not touching volatile access\n");
    return false;
  }
  if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE)) {
    LLVM_DEBUG(dbgs() << "OK: accesses are consecutive.\n");
    return true;
  }
  LLVM_DEBUG(dbgs() << "No, accesses aren't consecutive.\n");
  return false;
}

bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!Ld0 || !Ld1)
    return false;

  LLVM_DEBUG(dbgs() << "Are consecutive loads:\n";
    dbgs() << "Ld0:"; Ld0->dump();
    dbgs() << "Ld1:"; Ld1->dump();
  );

  if (!Ld0->hasOneUse() || !Ld1->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "No, load has more than one use.\n");
    return false;
  }

  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;

  VecMem.clear();
  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
  return true;
}

/// Iterate through the block and record base, offset pairs of loads as well as
/// maximal sequences of sequential loads.
bool ARMParallelDSP::RecordSequentialLoads(BasicBlock *Header) {
  SmallVector<LoadInst*, 8> Loads;
  for (auto &I : *Header) {
    auto *Ld = dyn_cast<LoadInst>(&I);
    if (!Ld)
      continue;
    Loads.push_back(Ld);
  }

  std::map<LoadInst*, LoadInst*> BaseLoads;
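
  // BaseLoads maps each load in a chain back to the first load of that chain;
  // e.g. for sequential loads Ld0 -> Ld1 -> Ld2, both Ld1 and Ld2 map to Ld0,
  // and SequentialLoads[Ld0] collects { Ld1, Ld2 }.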
  for (auto *Ld0 : Loads) {
    for (auto *Ld1 : Loads) {
      if (Ld0 == Ld1)
        continue;

      if (AreSequentialAccesses<LoadInst>(Ld0, Ld1, *DL, *SE)) {
        LoadPairs[Ld0] = Ld1;
        if (BaseLoads.count(Ld0)) {
          LoadInst *Base = BaseLoads[Ld0];
          BaseLoads[Ld1] = Base;
          SequentialLoads[Base].push_back(Ld1);
        } else {
          BaseLoads[Ld1] = Ld0;
          SequentialLoads[Ld0].push_back(Ld1);
        }
      }
    }
  }
  return LoadPairs.size() > 1;
}

void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) {
  OpChainList &Candidates = R.MACCandidates;
  PMACPairList &PMACPairs = R.PMACPairs;
  const unsigned Elems = Candidates.size();

  if (Elems < 2)
    return;

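  // Two MAC candidates can be combined when both of them load from
  // consecutive addresses. The SMLADX variant exchanges the halfwords of its
  // second operand, which also lets us pair loads that are consecutive in
  // reversed order by setting the Exchange flag.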
  auto CanPair = [&](BinOpChain *PMul0, BinOpChain *PMul1) {
    if (!PMul0->AreSymmetrical(PMul1))
      return false;

    // The first elements of each vector should be loads with sexts. If we
    // find that these are two pairs of consecutive loads, then they can be
    // transformed into two wider loads and the users can be replaced with
    // DSP intrinsics.
    for (unsigned x = 0; x < PMul0->LHS.size(); x += 2) {
      auto *Ld0 = dyn_cast<LoadInst>(PMul0->LHS[x]);
      auto *Ld1 = dyn_cast<LoadInst>(PMul1->LHS[x]);
      auto *Ld2 = dyn_cast<LoadInst>(PMul0->RHS[x]);
      auto *Ld3 = dyn_cast<LoadInst>(PMul1->RHS[x]);

      if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
        return false;

      LLVM_DEBUG(dbgs() << "Looking at operands " << x << ":\n"
                 << "\t Ld0: " << *Ld0 << "\n"
                 << "\t Ld1: " << *Ld1 << "\n"
                 << "and operands " << x + 2 << ":\n"
                 << "\t Ld2: " << *Ld2 << "\n"
                 << "\t Ld3: " << *Ld3 << "\n");

      if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
        if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
          LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
          PMACPairs.push_back(std::make_pair(PMul0, PMul1));
          return true;
        } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
          LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
          LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
          PMul1->Exchange = true;
          PMACPairs.push_back(std::make_pair(PMul0, PMul1));
          return true;
        }
      } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
                 AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
        LLVM_DEBUG(dbgs() << "    and swapping muls\n");
        PMul0->Exchange = true;
        // Only the second operand can be exchanged, so swap the muls.
        PMACPairs.push_back(std::make_pair(PMul1, PMul0));
        return true;
      }
    }
    return false;
  };

  SmallPtrSet<const Instruction*, 4> Paired;
  for (unsigned i = 0; i < Elems; ++i) {
    BinOpChain *PMul0 = static_cast<BinOpChain*>(Candidates[i].get());
    if (Paired.count(PMul0->Root))
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      BinOpChain *PMul1 = static_cast<BinOpChain*>(Candidates[j].get());
      if (Paired.count(PMul1->Root))
        continue;

      const Instruction *Mul0 = PMul0->Root;
      const Instruction *Mul1 = PMul1->Root;
      if (Mul0 == Mul1)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      LLVM_DEBUG(dbgs() << "\nCheck parallel muls:\n";
                 dbgs() << "- "; Mul0->dump();
                 dbgs() << "- "; Mul1->dump());

      if (CanPair(PMul0, PMul1)) {
        Paired.insert(Mul0);
        Paired.insert(Mul1);
        break;
      }
    }
  }
}

bool ARMParallelDSP::InsertParallelMACs(Reduction &Reduction) {
  Instruction *Acc = Reduction.Phi;
  Instruction *InsertAfter = Reduction.AccIntAdd;

  for (auto &Pair : Reduction.PMACPairs) {
    BinOpChain *PMul0 = Pair.first;
    BinOpChain *PMul1 = Pair.second;
    LLVM_DEBUG(dbgs() << "Found parallel MACs!!\n";
               dbgs() << "- "; PMul0->Root->dump();
               dbgs() << "- "; PMul1->Root->dump());

    auto *VecLd0 = cast<LoadInst>(PMul0->VecLd[0]);
    auto *VecLd1 = cast<LoadInst>(PMul1->VecLd[0]);
    Acc = CreateSMLADCall(VecLd0, VecLd1, Acc, PMul1->Exchange, InsertAfter);
    InsertAfter = Acc;
  }

  if (Acc != Reduction.Phi) {
    LLVM_DEBUG(dbgs() << "Replace Accumulate: "; Acc->dump());
    Reduction.AccIntAdd->replaceAllUsesWith(Acc);
    return true;
  }
  return false;
}

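// Find accumulating integer additions in the loop header, i.e. add reductions
// rooted at a phi, e.g.:
//
//   %acc = phi i32 [ 0, %entry ], [ %acc.next, %loop.body ]
//   ...
//   %acc.next = add i32 %val, %acc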
static void MatchReductions(Function &F, Loop *TheLoop, BasicBlock *Header,
                            ReductionList &Reductions) {
  RecurrenceDescriptor RecDesc;
  const bool HasFnNoNaNAttr =
    F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
  const BasicBlock *Latch = TheLoop->getLoopLatch();

  // We need a preheader as getIncomingValueForBlock assumes there is one.
  if (!TheLoop->getLoopPreheader()) {
    LLVM_DEBUG(dbgs() << "No preheader found, bailing out\n");
    return;
  }

  for (PHINode &Phi : Header->phis()) {
    const auto *Ty = Phi.getType();
    if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
      continue;

    const bool IsReduction =
      RecurrenceDescriptor::AddReductionVar(&Phi,
                                            RecurrenceDescriptor::RK_IntegerAdd,
                                            TheLoop, HasFnNoNaNAttr, RecDesc);
    if (!IsReduction)
      continue;

    Instruction *Acc =
      dyn_cast<Instruction>(Phi.getIncomingValueForBlock(Latch));
    if (!Acc)
      continue;

    Reductions.push_back(Reduction(&Phi, Acc));
  }

  LLVM_DEBUG(
    dbgs() << "\nAccumulating integer additions (reductions) found:\n";
    for (auto &R : Reductions) {
      dbgs() << "-  "; R.Phi->dump();
      dbgs() << "-> "; R.AccIntAdd->dump();
    }
  );
}

static void AddMACCandidate(OpChainList &Candidates,
                            Instruction *Mul,
                            Value *MulOp0, Value *MulOp1) {
  LLVM_DEBUG(dbgs() << "OK, found acc mul:\t"; Mul->dump());
  assert(Mul->getOpcode() == Instruction::Mul &&
         "expected mul instruction");
  ValueList LHS;
  ValueList RHS;
  if (IsNarrowSequence<16>(MulOp0, LHS) &&
      IsNarrowSequence<16>(MulOp1, RHS)) {
    LLVM_DEBUG(dbgs() << "OK, found narrow mul: "; Mul->dump());
    Candidates.push_back(make_unique<BinOpChain>(Mul, LHS, RHS));
  }
}

static void MatchParallelMACSequences(Reduction &R,
                                      OpChainList &Candidates) {
  Instruction *Acc = R.AccIntAdd;
  LLVM_DEBUG(dbgs() << "\n- Analysing:\t" << *Acc);

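  // Walk the add/mul tree that feeds the accumulator and record every narrow
  // mul as a MAC candidate, e.g. for:
  //
  //   %acc.next = add (add (mul %a, %b), %acc), (mul %c, %d)
  //
  // both muls end up in Candidates.
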
  // Returns false to signal the search should be stopped.
  std::function<bool(Value*)> Match =
    [&Candidates, &Match](Value *V) -> bool {

    auto *I = dyn_cast<Instruction>(V);
    if (!I)
      return false;

    switch (I->getOpcode()) {
    case Instruction::Add:
      if (Match(I->getOperand(0)) || Match(I->getOperand(1)))
        return true;
      break;
    case Instruction::Mul: {
      Value *MulOp0 = I->getOperand(0);
      Value *MulOp1 = I->getOperand(1);
      if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1))
        AddMACCandidate(Candidates, I, MulOp0, MulOp1);
      return false;
    }
    case Instruction::SExt:
      return Match(I->getOperand(0));
    }
    return false;
  };

  while (Match(Acc));
  LLVM_DEBUG(dbgs() << "Finished matching MAC sequences, found "
             << Candidates.size() << " candidates.\n");
}

// Collect all instructions in the block that may read or write memory; these
// are the instructions that can potentially alias with the MAC operands.
static void AliasCandidates(BasicBlock *Header, Instructions &Reads,
                            Instructions &Writes) {
  for (auto &I : *Header) {
    if (I.mayReadFromMemory())
      Reads.push_back(&I);
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
  }
}

// Check whether statements in the basic block that write to memory alias with
// the memory locations accessed by the MAC-chains.
// TODO: we need the read statements when we accept more complicated chains.
static bool AreAliased(AliasAnalysis *AA, Instructions &Reads,
                       Instructions &Writes, OpChainList &MACCandidates) {
  LLVM_DEBUG(dbgs() << "Alias checks:\n");
  for (auto &MAC : MACCandidates) {
    LLVM_DEBUG(dbgs() << "mul: "; MAC->Root->dump());

    // At the moment, we only allow simple chains that consist of reads and
    // accumulate their result with an integer add; they must not write to
    // memory, and we simply bail if they do.
    if (!MAC->ReadOnly)
      return true;

    // Now for all writes in the basic block, check that they don't alias with
    // the memory locations accessed by our MAC-chain:
    for (auto *I : Writes) {
      LLVM_DEBUG(dbgs() << "- "; I->dump());
      assert(MAC->MemLocs.size() >= 2 && "expecting at least 2 memlocs");
      for (auto &MemLoc : MAC->MemLocs) {
        if (isModOrRefSet(intersectModRef(AA->getModRefInfo(I, MemLoc),
                                          ModRefInfo::ModRef))) {
          LLVM_DEBUG(dbgs() << "Yes, aliases found\n");
          return true;
        }
      }
    }
  }

  LLVM_DEBUG(dbgs() << "OK: no aliases found!\n");
  return false;
}

static bool CheckMACMemory(OpChainList &Candidates) {
  for (auto &C : Candidates) {
    // A mul has 2 operands, and a narrow op consists of a sext and a load;
    // thus we expect at least 4 items in this operand value list.
    if (C->size() < 4) {
      LLVM_DEBUG(dbgs() << "Operand list too short.\n");
      return false;
    }
    C->SetMemoryLocations();
    ValueList &LHS = static_cast<BinOpChain*>(C.get())->LHS;
    ValueList &RHS = static_cast<BinOpChain*>(C.get())->RHS;

    // Use +=2 to skip over the expected extend instructions.
    for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
      if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
        return false;
    }
  }
  return true;
}

// Loop Pass that needs to identify integer add reductions of 16-bit vector
// multiplications.
// To use SMLAD:
// 1) we first need to find integer add reduction PHIs,
// 2) then from the PHI, look for this pattern:
//
// acc0 = phi i32 [0, %entry], [%acc1, %loop.body]
// ld0 = load i16
// sext0 = sext i16 %ld0 to i32
// ld1 = load i16
// sext1 = sext i16 %ld1 to i32
// mul0 = mul i32 %sext0, %sext1
// ld2 = load i16
// sext2 = sext i16 %ld2 to i32
// ld3 = load i16
// sext3 = sext i16 %ld3 to i32
// mul1 = mul i32 %sext2, %sext3
// add0 = add i32 %mul0, %acc0
// acc1 = add i32 %add0, %mul1
//
// Which can be selected to:
//
// ldr.h r0
// ldr.h r1
// smlad r2, r0, r1, r2
//
// If constants are used instead of loads, these will need to be hoisted
// out and into a register.
//
// If loop invariants are used instead of loads, these need to be packed
// before the loop begins.
//
bool ARMParallelDSP::MatchSMLAD(Function &F) {
  BasicBlock *Header = L->getHeader();
  LLVM_DEBUG(dbgs() << "= Matching SMLAD =\n";
             dbgs() << "Header block:\n"; Header->dump();
             dbgs() << "Loop info:\n\n"; L->dump());

  bool Changed = false;
  ReductionList Reductions;
  MatchReductions(F, L, Header, Reductions);

  for (auto &R : Reductions) {
    OpChainList MACCandidates;
    MatchParallelMACSequences(R, MACCandidates);
    if (!CheckMACMemory(MACCandidates))
      continue;

    R.MACCandidates = std::move(MACCandidates);

    LLVM_DEBUG(dbgs() << "MAC candidates:\n";
      for (auto &M : R.MACCandidates)
        M->Root->dump();
      dbgs() << "\n";);
  }

  // Collect all instructions that may read or write memory. Our alias
  // analysis checks bail out if any of these instructions aliases with an
  // instruction from the MAC-chain.
  Instructions Reads, Writes;
  AliasCandidates(Header, Reads, Writes);

  for (auto &R : Reductions) {
    if (AreAliased(AA, Reads, Writes, R.MACCandidates))
      return false;
    CreateParallelMACPairs(R);
    Changed |= InsertParallelMACs(R);
  }

  LLVM_DEBUG(if (Changed) {
    dbgs() << "Header block:\n";
    Header->dump();
  });
  return Changed;
}

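// Widen a pair of sequential narrow loads into a single load of twice the
// width: the pointer of the base load is bitcast to the wider type, e.g. two
// adjacent i16 loads become one i32 load.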
static LoadInst *CreateLoadIns(IRBuilder<NoFolder> &IRB, LoadInst &BaseLoad,
                               Type *LoadTy) {
  const unsigned AddrSpace = BaseLoad.getPointerAddressSpace();

  Value *VecPtr = IRB.CreateBitCast(BaseLoad.getPointerOperand(),
                                    LoadTy->getPointerTo(AddrSpace));
  return IRB.CreateAlignedLoad(LoadTy, VecPtr, BaseLoad.getAlignment());
}

Instruction *ARMParallelDSP::CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
                                             Instruction *Acc, bool Exchange,
                                             Instruction *InsertAfter) {
  LLVM_DEBUG(dbgs() << "Create SMLAD intrinsic using:\n"
             << "- " << *VecLd0 << "\n"
             << "- " << *VecLd1 << "\n"
             << "- " << *Acc << "\n"
             << "Exchange: " << Exchange << "\n");

  IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                              ++BasicBlock::iterator(InsertAfter));

  // Replace the reduction chain with an intrinsic call
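  // For a 32-bit accumulator, this creates, e.g.:
  //
  //   %smlad = call i32 @llvm.arm.smlad(i32 %ld0, i32 %ld1, i32 %acc)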
  Type *Ty = IntegerType::get(M->getContext(), 32);
  LoadInst *NewLd0 = CreateLoadIns(Builder, *VecLd0, Ty);
  LoadInst *NewLd1 = CreateLoadIns(Builder, *VecLd1, Ty);
  Value *Args[] = { NewLd0, NewLd1, Acc };
  Function *SMLAD = nullptr;
  if (Exchange)
    SMLAD = Acc->getType()->isIntegerTy(32) ?
      Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
      Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
  else
    SMLAD = Acc->getType()->isIntegerTy(32) ?
      Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
      Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
  CallInst *Call = Builder.CreateCall(SMLAD, Args);
  ++NumSMLAD;
  return Call;
}


// Compare the value lists in Other to this chain.
bool BinOpChain::AreSymmetrical(BinOpChain *Other) {
  // Element-by-element comparison of Value lists returning true if they are
  // instructions with the same opcode or constants with the same value.
  auto CompareValueList = [](const ValueList &VL0,
                             const ValueList &VL1) {
    if (VL0.size() != VL1.size()) {
      LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
                        << VL0.size() << " != " << VL1.size() << "\n");
      return false;
    }

    const unsigned Pairs = VL0.size();
    LLVM_DEBUG(dbgs() << "Number of operand pairs: " << Pairs << "\n");

    for (unsigned i = 0; i < Pairs; ++i) {
      const Value *V0 = VL0[i];
      const Value *V1 = VL1[i];
      const auto *Inst0 = dyn_cast<Instruction>(V0);
      const auto *Inst1 = dyn_cast<Instruction>(V1);

      LLVM_DEBUG(dbgs() << "Pair " << i << ":\n";
                 dbgs() << "mul1: "; V0->dump();
                 dbgs() << "mul2: "; V1->dump());

      if (!Inst0 || !Inst1)
        return false;

      if (Inst0->isSameOperationAs(Inst1)) {
        LLVM_DEBUG(dbgs() << "OK: same operation found!\n");
        continue;
      }

      // Compare the constants by value, not by pointer identity.
      const APInt *C0, *C1;
      if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) &&
            APInt::isSameValue(*C0, *C1)))
        return false;
    }

    LLVM_DEBUG(dbgs() << "OK: found symmetrical operand lists.\n");
    return true;
  };

  return CompareValueList(LHS, Other->LHS) &&
         CompareValueList(RHS, Other->RHS);
}

Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

char ARMParallelDSP::ID = 0;

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                      "Transform loops to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                    "Transform loops to use DSP intrinsics", false, false)