1 //===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled.
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "ARM.h"
#include "ARMSubtarget.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/PassSupport.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
36 
37 using namespace llvm;
38 using namespace PatternMatch;
39 
40 #define DEBUG_TYPE "arm-parallel-dsp"
41 
STATISTIC(NumSMLAD, "Number of smlad instructions generated");
43 
44 static cl::opt<bool>
45 DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
46                    cl::desc("Disable the ARM Parallel DSP pass"));
47 
48 namespace {
49   struct OpChain;
50   struct BinOpChain;
51   struct Reduction;
52 
53   using OpChainList     = SmallVector<std::unique_ptr<OpChain>, 8>;
54   using ReductionList   = SmallVector<Reduction, 8>;
55   using ValueList       = SmallVector<Value*, 8>;
56   using MemInstList     = SmallVector<LoadInst*, 8>;
  using PMACPair        = std::pair<BinOpChain*, BinOpChain*>;
  using PMACPairList    = SmallVector<PMACPair, 8>;
  using Instructions    = SmallVector<Instruction*, 16>;
60   using MemLocList      = SmallVector<MemoryLocation, 4>;
61 
62   struct OpChain {
63     Instruction   *Root;
64     ValueList     AllValues;
    MemInstList   VecLd;    // Pair of sequential loads that will be widened.
    MemInstList   Loads;    // All load instructions in this chain.
67     bool          ReadOnly = true;
68 
69     OpChain(Instruction *I, ValueList &vl) : Root(I), AllValues(vl) { }
70     virtual ~OpChain() = default;
71 
72     void PopulateLoads() {
73       for (auto *V : AllValues) {
74         if (auto *Ld = dyn_cast<LoadInst>(V))
75           Loads.push_back(Ld);
76       }
77     }
78 
79     unsigned size() const { return AllValues.size(); }
80   };
81 
  // 'BinOpChain' and 'Reduction' are just bookkeeping data structures.
  // 'Reduction' contains the phi-node and accumulator statement from where we
  // start pattern matching, and 'BinOpChain' holds the multiplication
  // instructions that are candidates for parallel execution.
86   struct BinOpChain : public OpChain {
87     ValueList     LHS;      // List of all (narrow) left hand operands.
88     ValueList     RHS;      // List of all (narrow) right hand operands.
89     bool Exchange = false;
90 
91     BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
92       OpChain(I, lhs), LHS(lhs), RHS(rhs) {
93         for (auto *V : RHS)
94           AllValues.push_back(V);
95       }
96 
97     bool AreSymmetrical(BinOpChain *Other);
98   };
99 
100   struct Reduction {
101     PHINode         *Phi;             // The Phi-node from where we start
102                                       // pattern matching.
103     Instruction     *AccIntAdd;       // The accumulating integer add statement,
104                                       // i.e, the reduction statement.
105     OpChainList     MACCandidates;    // The MAC candidates associated with
106                                       // this reduction statement.
107     PMACPairList    PMACPairs;
    Reduction(PHINode *P, Instruction *Acc) : Phi(P), AccIntAdd(Acc) { }
109   };
110 
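  // Bookkeeping for a wide load created to replace a pair of narrow,
  // sequential loads.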
111   class WidenedLoad {
112     LoadInst *NewLd = nullptr;
113     SmallVector<LoadInst*, 4> Loads;
114 
115   public:
116     WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
117       : NewLd(Wide) {
118       for (auto *I : Lds)
119         Loads.push_back(I);
120     }
121     LoadInst *getLoad() {
122       return NewLd;
123     }
124   };
125 
126   class ARMParallelDSP : public LoopPass {
127     ScalarEvolution   *SE;
128     AliasAnalysis     *AA;
129     TargetLibraryInfo *TLI;
130     DominatorTree     *DT;
131     LoopInfo          *LI;
132     Loop              *L;
133     const DataLayout  *DL;
134     Module            *M;
135     std::map<LoadInst*, LoadInst*> LoadPairs;
136     std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;
137 
138     bool RecordMemoryOps(BasicBlock *BB);
139     bool InsertParallelMACs(Reduction &Reduction);
140     bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
141     LoadInst* CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
142                              IntegerType *LoadTy);
143     void CreateParallelMACPairs(Reduction &R);
144     Instruction *CreateSMLADCall(SmallVectorImpl<LoadInst*> &VecLd0,
145                                  SmallVectorImpl<LoadInst*> &VecLd1,
146                                  Instruction *Acc, bool Exchange,
147                                  Instruction *InsertAfter);
148 
    /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
    /// Dual. These perform two signed 16x16-bit multiplications and add the
    /// products to a 32-bit accumulate operand. Optionally, the instruction
    /// can exchange the halfwords of the second operand before performing the
    /// arithmetic.
154     bool MatchSMLAD(Function &F);
155 
156   public:
157     static char ID;
158 
159     ARMParallelDSP() : LoopPass(ID) { }
160 
161     bool doInitialization(Loop *L, LPPassManager &LPM) override {
162       LoadPairs.clear();
163       WideLoads.clear();
164       return true;
165     }
166 
167     void getAnalysisUsage(AnalysisUsage &AU) const override {
168       LoopPass::getAnalysisUsage(AU);
169       AU.addRequired<AssumptionCacheTracker>();
170       AU.addRequired<ScalarEvolutionWrapperPass>();
171       AU.addRequired<AAResultsWrapperPass>();
172       AU.addRequired<TargetLibraryInfoWrapperPass>();
173       AU.addRequired<LoopInfoWrapperPass>();
174       AU.addRequired<DominatorTreeWrapperPass>();
175       AU.addRequired<TargetPassConfig>();
176       AU.addPreserved<LoopInfoWrapperPass>();
177       AU.setPreservesCFG();
178     }
179 
180     bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
181       if (DisableParallelDSP)
182         return false;
183       L = TheLoop;
184       SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
185       AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
186       TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
187       DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
188       LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
189       auto &TPC = getAnalysis<TargetPassConfig>();
190 
191       BasicBlock *Header = TheLoop->getHeader();
192       if (!Header)
193         return false;
194 
195       // TODO: We assume the loop header and latch to be the same block.
196       // This is not a fundamental restriction, but lifting this would just
197       // require more work to do the transformation and then patch up the CFG.
198       if (Header != TheLoop->getLoopLatch()) {
199         LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
200                              "running pass ARMParallelDSP\n");
201         return false;
202       }
203 
204       // We need a preheader as getIncomingValueForBlock assumes there is one.
205       if (!TheLoop->getLoopPreheader()) {
206         LLVM_DEBUG(dbgs() << "No preheader found, bailing out\n");
207         return false;
208       }
209 
210       Function &F = *Header->getParent();
211       M = F.getParent();
212       DL = &M->getDataLayout();
213 
214       auto &TM = TPC.getTM<TargetMachine>();
215       auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
216 
217       if (!ST->allowsUnalignedMem()) {
218         LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
219                              "running pass ARMParallelDSP\n");
220         return false;
221       }
222 
223       if (!ST->hasDSP()) {
224         LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
225                              "ARMParallelDSP\n");
226         return false;
227       }
228 
229       if (!ST->isLittle()) {
230         LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
231                           << "ARMParallelDSP\n");
232         return false;
233       }
234 
235       LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);
236 
237       LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
238       LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");
239 
240       if (!RecordMemoryOps(Header)) {
241         LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
242         return false;
243       }
244 
245       bool Changes = MatchSMLAD(F);
246       return Changes;
247     }
248   };
249 }
250 
// MaxBitWidth: the maximum supported bitwidth of the elements in the DSP
252 // instructions, which is set to 16. So here we should collect all i8 and i16
253 // narrow operations.
254 // TODO: we currently only collect i16, and will support i8 later, so that's
255 // why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
256 template<unsigned MaxBitWidth>
257 static bool IsNarrowSequence(Value *V, ValueList &VL) {
258   ConstantInt *CInt;
259 
260   if (match(V, m_ConstantInt(CInt))) {
261     // TODO: if a constant is used, it needs to fit within the bit width.
262     return false;
263   }
264 
265   auto *I = dyn_cast<Instruction>(V);
266   if (!I)
    return false;
268 
269   Value *Val, *LHS, *RHS;
270   if (match(V, m_Trunc(m_Value(Val)))) {
271     if (cast<TruncInst>(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth)
272       return IsNarrowSequence<MaxBitWidth>(Val, VL);
273   } else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) {
274     // TODO: we need to implement sadd16/sadd8 for this, which enables to
275     // also do the rewrite for smlad8.ll, but it is unsupported for now.
276     return false;
277   } else if (match(V, m_ZExtOrSExt(m_Value(Val)))) {
278     if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
279       return false;
280 
281     if (match(Val, m_Load(m_Value()))) {
282       VL.push_back(Val);
283       VL.push_back(I);
284       return true;
285     }
286   }
287   return false;
288 }
289 
template<typename MemInst>
static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
                                  const DataLayout &DL, ScalarEvolution &SE) {
  return isConsecutiveAccess(MemOp0, MemOp1, DL, SE);
}
297 
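// Check that Ld0 and Ld1 form a base/offset pair that RecordMemoryOps found
// to be consecutive and safe to merge, and if so record them as the loads to
// be widened for this chain.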
298 bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
299                                         MemInstList &VecMem) {
300   if (!Ld0 || !Ld1)
301     return false;
302 
303   if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
304     return false;
305 
306   LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
307     dbgs() << "Ld0:"; Ld0->dump();
308     dbgs() << "Ld1:"; Ld1->dump();
309   );
310 
311   VecMem.clear();
312   VecMem.push_back(Ld0);
313   VecMem.push_back(Ld1);
314   return true;
315 }
316 
317 /// Iterate through the block and record base, offset pairs of loads which can
318 /// be widened into a single load.
319 bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
320   SmallVector<LoadInst*, 8> Loads;
321   SmallVector<Instruction*, 8> Writes;
322 
  // Collect loads and instructions that may write to memory. For now we only
  // record loads which are simple, sign-extended and have a single user.
325   // TODO: Allow zero-extended loads.
326   for (auto &I : *BB) {
327     if (I.mayWriteToMemory())
328       Writes.push_back(&I);
329     auto *Ld = dyn_cast<LoadInst>(&I);
330     if (!Ld || !Ld->isSimple() ||
331         !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
332       continue;
333     Loads.push_back(Ld);
334   }
335 
336   using InstSet = std::set<Instruction*>;
337   using DepMap = std::map<Instruction*, InstSet>;
338   DepMap RAWDeps;
339 
340   // Record any writes that may alias a load.
341   const auto Size = LocationSize::unknown();
  for (auto *Read : Loads) {
    MemoryLocation ReadLoc(Read->getPointerOperand(), Size);

    for (auto *Write : Writes) {
      if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
                                         ModRefInfo::ModRef)))
        continue;
      if (DT->dominates(Write, Read))
        RAWDeps[Read].insert(Write);
    }
  }
354 
  // Check that there isn't a write between the two loads which would prevent
  // them from being safely merged.
  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
    bool BaseFirst = DT->dominates(Base, Offset);
    LoadInst *Dominator = BaseFirst ? Base : Offset;
    LoadInst *Dominated = BaseFirst ? Offset : Base;

    if (RAWDeps.count(Dominated)) {
      InstSet &WritesBefore = RAWDeps[Dominated];

      for (auto *Before : WritesBefore) {
        // We can't move the second load backward, past a write, to merge
        // with the first load.
        if (DT->dominates(Dominator, Before))
          return false;
      }
    }
    return true;
  };
374 
375   // Record base, offset load pairs.
376   for (auto *Base : Loads) {
377     for (auto *Offset : Loads) {
378       if (Base == Offset)
379         continue;
380 
381       if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
382           SafeToPair(Base, Offset)) {
383         LoadPairs[Base] = Offset;
384         break;
385       }
386     }
387   }
388 
  LLVM_DEBUG(if (!LoadPairs.empty()) {
               dbgs() << "Consecutive load pairs:\n";
               for (auto &MapIt : LoadPairs)
                 dbgs() << *MapIt.first << ", " << *MapIt.second << "\n";
             });
396   return LoadPairs.size() > 1;
397 }
398 
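// Search the MAC candidates of a reduction for pairs of multiplies whose
// narrow operands are sequential loads; such a pair can share two widened
// loads and later be combined into a single SMLAD/SMLADX call.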
399 void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) {
400   OpChainList &Candidates = R.MACCandidates;
401   PMACPairList &PMACPairs = R.PMACPairs;
402   const unsigned Elems = Candidates.size();
403 
404   if (Elems < 2)
405     return;
406 
407   auto CanPair = [&](BinOpChain *PMul0, BinOpChain *PMul1) {
408     if (!PMul0->AreSymmetrical(PMul1))
409       return false;
410 
    // The first elements of each vector should be loads with sexts. If we
    // find that these are two pairs of consecutive loads, then they can be
    // transformed into two wider loads and the users can be replaced with
    // DSP intrinsics.
415     for (unsigned x = 0; x < PMul0->LHS.size(); x += 2) {
416       auto *Ld0 = dyn_cast<LoadInst>(PMul0->LHS[x]);
417       auto *Ld1 = dyn_cast<LoadInst>(PMul1->LHS[x]);
418       auto *Ld2 = dyn_cast<LoadInst>(PMul0->RHS[x]);
419       auto *Ld3 = dyn_cast<LoadInst>(PMul1->RHS[x]);
420 
421       if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
422         return false;
423 
424       LLVM_DEBUG(dbgs() << "Loads:\n"
425                  << " - " << *Ld0 << "\n"
426                  << " - " << *Ld1 << "\n"
427                  << " - " << *Ld2 << "\n"
428                  << " - " << *Ld3 << "\n");
429 
430       if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
431         if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
432           LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
433           PMACPairs.push_back(std::make_pair(PMul0, PMul1));
434           return true;
435         } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
436           LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
437           LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
438           PMul1->Exchange = true;
439           PMACPairs.push_back(std::make_pair(PMul0, PMul1));
440           return true;
441         }
442       } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
443                  AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
444         LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
445         LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
446         LLVM_DEBUG(dbgs() << "    and swapping muls\n");
447         PMul0->Exchange = true;
448         // Only the second operand can be exchanged, so swap the muls.
449         PMACPairs.push_back(std::make_pair(PMul1, PMul0));
450         return true;
451       }
452     }
453     return false;
454   };
455 
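  // Greedily pair up the candidates; each multiply is paired with at most one
  // other multiply and is not considered again once paired.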
456   SmallPtrSet<const Instruction*, 4> Paired;
457   for (unsigned i = 0; i < Elems; ++i) {
458     BinOpChain *PMul0 = static_cast<BinOpChain*>(Candidates[i].get());
459     if (Paired.count(PMul0->Root))
460       continue;
461 
462     for (unsigned j = 0; j < Elems; ++j) {
463       if (i == j)
464         continue;
465 
466       BinOpChain *PMul1 = static_cast<BinOpChain*>(Candidates[j].get());
467       if (Paired.count(PMul1->Root))
468         continue;
469 
470       const Instruction *Mul0 = PMul0->Root;
471       const Instruction *Mul1 = PMul1->Root;
472       if (Mul0 == Mul1)
473         continue;
474 
475       assert(PMul0 != PMul1 && "expected different chains");
476 
477       if (CanPair(PMul0, PMul1)) {
478         Paired.insert(Mul0);
479         Paired.insert(Mul1);
480         break;
481       }
482     }
483   }
484 }
485 
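// For each pair of parallel MACs, create the SMLAD/SMLALD call (widening the
// loads if necessary) and chain it onto the accumulator. Finally, replace the
// original reduction statement with the last call.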
486 bool ARMParallelDSP::InsertParallelMACs(Reduction &Reduction) {
487   Instruction *Acc = Reduction.Phi;
488   Instruction *InsertAfter = Reduction.AccIntAdd;
489 
490   for (auto &Pair : Reduction.PMACPairs) {
491     BinOpChain *PMul0 = Pair.first;
492     BinOpChain *PMul1 = Pair.second;
493     LLVM_DEBUG(dbgs() << "Found parallel MACs:\n"
494                << "- " << *PMul0->Root << "\n"
495                << "- " << *PMul1->Root << "\n");
496 
497     Acc = CreateSMLADCall(PMul0->VecLd, PMul1->VecLd, Acc, PMul1->Exchange,
498                           InsertAfter);
499     InsertAfter = Acc;
500   }
501 
502   if (Acc != Reduction.Phi) {
503     LLVM_DEBUG(dbgs() << "Replace Accumulate: "; Acc->dump());
504     Reduction.AccIntAdd->replaceAllUsesWith(Acc);
505     return true;
506   }
507   return false;
508 }
509 
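// Starting from the reduction statement, walk back through the chain of adds
// and collect the multiplies whose operands are sign-extended, narrow (i16)
// sequences; these become the MAC candidates of the reduction.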
510 static void MatchParallelMACSequences(Reduction &R,
511                                       OpChainList &Candidates) {
512   Instruction *Acc = R.AccIntAdd;
513   LLVM_DEBUG(dbgs() << "\n- Analysing:\t" << *Acc << "\n");
514 
515   // Returns false to signal the search should be stopped.
516   std::function<bool(Value*)> Match =
517     [&Candidates, &Match](Value *V) -> bool {
518 
519     auto *I = dyn_cast<Instruction>(V);
520     if (!I)
521       return false;
522 
523     switch (I->getOpcode()) {
524     case Instruction::Add:
      if (Match(I->getOperand(0)) || Match(I->getOperand(1)))
526         return true;
527       break;
528     case Instruction::Mul: {
529       Value *MulOp0 = I->getOperand(0);
530       Value *MulOp1 = I->getOperand(1);
531       if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1)) {
532         ValueList LHS;
533         ValueList RHS;
534         if (IsNarrowSequence<16>(MulOp0, LHS) &&
535             IsNarrowSequence<16>(MulOp1, RHS)) {
536           Candidates.push_back(make_unique<BinOpChain>(I, LHS, RHS));
537         }
538       }
539       return false;
540     }
541     case Instruction::SExt:
542       return Match(I->getOperand(0));
543     }
544     return false;
545   };
546 
  while (Match(Acc));
548   LLVM_DEBUG(dbgs() << "Finished matching MAC sequences, found "
549              << Candidates.size() << " candidates.\n");
550 }
551 
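// Check that the narrow operands of each MAC candidate are fed by loads, as
// only loads can be widened into the operands of the DSP intrinsics.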
552 static bool CheckMACMemory(OpChainList &Candidates) {
553   for (auto &C : Candidates) {
    // A mul has 2 operands, and a narrow op consists of a sext and a load;
    // thus we expect at least 4 items in this operand value list.
556     if (C->size() < 4) {
557       LLVM_DEBUG(dbgs() << "Operand list too short.\n");
558       return false;
559     }
560     C->PopulateLoads();
561     ValueList &LHS = static_cast<BinOpChain*>(C.get())->LHS;
562     ValueList &RHS = static_cast<BinOpChain*>(C.get())->RHS;
563 
564     // Use +=2 to skip over the expected extend instructions.
565     for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
566       if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
567         return false;
568     }
569   }
570   return true;
571 }
572 
573 // Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
574 // multiplications.
575 // To use SMLAD:
576 // 1) we first need to find integer add reduction PHIs,
577 // 2) then from the PHI, look for this pattern:
578 //
579 // acc0 = phi i32 [0, %entry], [%acc1, %loop.body]
580 // ld0 = load i16
581 // sext0 = sext i16 %ld0 to i32
582 // ld1 = load i16
583 // sext1 = sext i16 %ld1 to i32
584 // mul0 = mul %sext0, %sext1
585 // ld2 = load i16
586 // sext2 = sext i16 %ld2 to i32
587 // ld3 = load i16
588 // sext3 = sext i16 %ld3 to i32
589 // mul1 = mul i32 %sext2, %sext3
590 // add0 = add i32 %mul0, %acc0
591 // acc1 = add i32 %add0, %mul1
592 //
593 // Which can be selected to:
594 //
595 // ldr.h r0
596 // ldr.h r1
597 // smlad r2, r0, r1, r2
598 //
599 // If constants are used instead of loads, these will need to be hoisted
600 // out and into a register.
601 //
602 // If loop invariants are used instead of loads, these need to be packed
603 // before the loop begins.
604 //
bool ARMParallelDSP::MatchSMLAD(Function &F) {
607   auto FindReductions = [&](ReductionList &Reductions) {
608     RecurrenceDescriptor RecDesc;
609     const bool HasFnNoNaNAttr =
610       F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
611     BasicBlock *Latch = L->getLoopLatch();
612 
613     for (PHINode &Phi : Latch->phis()) {
614       const auto *Ty = Phi.getType();
615       if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
616         continue;
617 
618       const bool IsReduction = RecurrenceDescriptor::AddReductionVar(
619         &Phi, RecurrenceDescriptor::RK_IntegerAdd, L, HasFnNoNaNAttr, RecDesc);
620 
621       if (!IsReduction)
622         continue;
623 
      Instruction *Acc =
          dyn_cast<Instruction>(Phi.getIncomingValueForBlock(Latch));
625       if (!Acc)
626         continue;
627 
628       Reductions.push_back(Reduction(&Phi, Acc));
629     }
630     return !Reductions.empty();
631   };
632 
633   ReductionList Reductions;
634   if (!FindReductions(Reductions))
635     return false;
636 
637   for (auto &R : Reductions) {
638     OpChainList MACCandidates;
639     MatchParallelMACSequences(R, MACCandidates);
640     if (!CheckMACMemory(MACCandidates))
641       continue;
642 
643     R.MACCandidates = std::move(MACCandidates);
644 
645     LLVM_DEBUG(dbgs() << "MAC candidates:\n";
646       for (auto &M : R.MACCandidates)
647         M->Root->dump();
648       dbgs() << "\n";);
649   }
650 
  bool Changed = false;
  // Now pair up the parallel MAC candidates of each reduction and replace the
  // reduction statements with calls to the DSP intrinsics.
654   for (auto &R : Reductions) {
655     CreateParallelMACPairs(R);
656     Changed |= InsertParallelMACs(R);
657   }
658 
659   return Changed;
660 }
661 
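// Widen a pair of sequential loads into a single load of twice the width and
// rewrite their sign-extending users to extract the bottom and top halves of
// the wide value.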
662 LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
663                                          IntegerType *LoadTy) {
664   assert(Loads.size() == 2 && "currently only support widening two loads");
665 
666   LoadInst *Base = Loads[0];
667   LoadInst *Offset = Loads[1];
668 
669   Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
670   Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());
671 
672   assert((BaseSExt && OffsetSExt)
673          && "Loads should have a single, extending, user");
674 
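  // Move A before B when they share a block and A does not already dominate
  // B, then repeat for each user of A so that A still precedes its users.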
675   std::function<void(Value*, Value*)> MoveBefore =
676     [&](Value *A, Value *B) -> void {
677       if (!isa<Instruction>(A) || !isa<Instruction>(B))
678         return;
679 
680       auto *Source = cast<Instruction>(A);
681       auto *Sink = cast<Instruction>(B);
682 
683       if (DT->dominates(Source, Sink) ||
684           Source->getParent() != Sink->getParent() ||
685           isa<PHINode>(Source) || isa<PHINode>(Sink))
686         return;
687 
688       Source->moveBefore(Sink);
689       for (auto &U : Source->uses())
690         MoveBefore(Source, U.getUser());
691     };
692 
693   // Insert the load at the point of the original dominating load.
694   LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
695   IRBuilder<NoFolder> IRB(DomLoad->getParent(),
696                           ++BasicBlock::iterator(DomLoad));
697 
698   // Bitcast the pointer to a wider type and create the wide load, while making
699   // sure to maintain the original alignment as this prevents ldrd from being
700   // generated when it could be illegal due to memory alignment.
701   const unsigned AddrSpace = DomLoad->getPointerAddressSpace();
702   Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
703                                     LoadTy->getPointerTo(AddrSpace));
704   LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr,
705                                              Base->getAlignment());
706 
707   // Make sure everything is in the correct order in the basic block.
708   MoveBefore(Base->getPointerOperand(), VecPtr);
709   MoveBefore(VecPtr, WideLoad);
710 
711   // From the wide load, create two values that equal the original two loads.
712   // Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
713   // TODO: Support big-endian as well.
714   Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
715   BaseSExt->setOperand(0, Bottom);
716 
717   IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
718   Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
719   Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
720   Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
721   OffsetSExt->setOperand(0, Trunc);
722 
723   WideLoads.emplace(std::make_pair(Base,
724                                    make_unique<WidenedLoad>(Loads, WideLoad)));
725   return WideLoad;
726 }
727 
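// Create (or reuse) the two wide loads and emit a call to the appropriate
// smlad/smlald intrinsic, inserted after InsertAfter, accumulating into Acc.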
728 Instruction *ARMParallelDSP::CreateSMLADCall(SmallVectorImpl<LoadInst*> &VecLd0,
729                                              SmallVectorImpl<LoadInst*> &VecLd1,
730                                              Instruction *Acc, bool Exchange,
731                                              Instruction *InsertAfter) {
732   LLVM_DEBUG(dbgs() << "Create SMLAD intrinsic using:\n"
733              << "- " << *VecLd0[0] << "\n"
734              << "- " << *VecLd0[1] << "\n"
735              << "- " << *VecLd1[0] << "\n"
736              << "- " << *VecLd1[1] << "\n"
737              << "- " << *Acc << "\n"
738              << "- Exchange: " << Exchange << "\n");
739 
  // Replace the reduction chain with an intrinsic call.
741   IntegerType *Ty = IntegerType::get(M->getContext(), 32);
742   LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ?
743     WideLoads[VecLd0[0]]->getLoad() : CreateWideLoad(VecLd0, Ty);
744   LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ?
745     WideLoads[VecLd1[0]]->getLoad() : CreateWideLoad(VecLd1, Ty);
746 
  Value *Args[] = { WideLd0, WideLd1, Acc };
748   Function *SMLAD = nullptr;
749   if (Exchange)
750     SMLAD = Acc->getType()->isIntegerTy(32) ?
751       Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
752       Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
753   else
754     SMLAD = Acc->getType()->isIntegerTy(32) ?
755       Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
756       Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
757 
758   IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
759                               ++BasicBlock::iterator(InsertAfter));
760   CallInst *Call = Builder.CreateCall(SMLAD, Args);
  ++NumSMLAD;
762   return Call;
763 }
764 
765 // Compare the value lists in Other to this chain.
766 bool BinOpChain::AreSymmetrical(BinOpChain *Other) {
767   // Element-by-element comparison of Value lists returning true if they are
768   // instructions with the same opcode or constants with the same value.
769   auto CompareValueList = [](const ValueList &VL0,
770                              const ValueList &VL1) {
771     if (VL0.size() != VL1.size()) {
772       LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
773                         << VL0.size() << " != " << VL1.size() << "\n");
774       return false;
775     }
776 
777     const unsigned Pairs = VL0.size();
778 
779     for (unsigned i = 0; i < Pairs; ++i) {
780       const Value *V0 = VL0[i];
781       const Value *V1 = VL1[i];
782       const auto *Inst0 = dyn_cast<Instruction>(V0);
783       const auto *Inst1 = dyn_cast<Instruction>(V1);
784 
785       if (!Inst0 || !Inst1)
786         return false;
787 
788       if (Inst0->isSameOperationAs(Inst1))
789         continue;
790 
      const APInt *C0, *C1;
      // Compare the constant values rather than the APInt pointers.
      if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) &&
            APInt::isSameValue(*C0, *C1)))
        return false;
794     }
795 
796     return true;
797   };
798 
799   return CompareValueList(LHS, Other->LHS) &&
800          CompareValueList(RHS, Other->RHS);
801 }
802 
803 Pass *llvm::createARMParallelDSPPass() {
804   return new ARMParallelDSP();
805 }
806 
807 char ARMParallelDSP::ID = 0;
808 
INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                      "Transform loops to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                    "Transform loops to use DSP intrinsics", false, false)
813