//===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled.
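///
/// For example, a reduction of two multiplies of sign-extended i16 loads,
/// accumulated into a 32-bit integer, can be replaced by a call to
/// @llvm.arm.smlad, which maps onto the SMLAD instruction.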
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMSubtarget.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/PassSupport.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");

static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

namespace {
  struct OpChain;
  struct BinOpChain;
  struct Reduction;

  using OpChainList     = SmallVector<std::unique_ptr<OpChain>, 8>;
  using ReductionList   = SmallVector<Reduction, 8>;
  using ValueList       = SmallVector<Value*, 8>;
  using MemInstList     = SmallVector<LoadInst*, 8>;
  using PMACPair        = std::pair<BinOpChain*, BinOpChain*>;
  using PMACPairList    = SmallVector<PMACPair, 8>;
  using Instructions    = SmallVector<Instruction*, 16>;
  using MemLocList      = SmallVector<MemoryLocation, 4>;

  struct OpChain {
    Instruction   *Root;
    ValueList     AllValues;
    MemInstList   VecLd;    // The pair of sequential loads to be widened.
    MemInstList   Loads;    // All loads in AllValues, set by PopulateLoads.
    bool          ReadOnly = true;

    OpChain(Instruction *I, ValueList &vl) : Root(I), AllValues(vl) { }
    virtual ~OpChain() = default;

    void PopulateLoads() {
      for (auto *V : AllValues) {
        if (auto *Ld = dyn_cast<LoadInst>(V))
          Loads.push_back(Ld);
      }
    }

    unsigned size() const { return AllValues.size(); }
  };

  // 'BinOpChain' and 'Reduction' are bookkeeping data structures.
  // 'Reduction' contains the phi-node and accumulator statement from where we
  // start pattern matching, and 'BinOpChain' holds the multiplication
  // instructions that are candidates for parallel execution.
  struct BinOpChain : public OpChain {
    ValueList     LHS;      // List of all (narrow) left hand operands.
    ValueList     RHS;      // List of all (narrow) right hand operands.
    bool Exchange = false;  // Whether the halfwords of the second operand
                            // should be exchanged (SMLADX).

    BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
      OpChain(I, lhs), LHS(lhs), RHS(rhs) {
      for (auto *V : RHS)
        AllValues.push_back(V);
    }

    bool AreSymmetrical(BinOpChain *Other);
  };

  struct Reduction {
    PHINode         *Phi;             // The Phi-node from where we start
                                      // pattern matching.
    Instruction     *AccIntAdd;       // The accumulating integer add statement,
                                      // i.e., the reduction statement.
    OpChainList     MACCandidates;    // The MAC candidates associated with
                                      // this reduction statement.
    PMACPairList    PMACPairs;        // The pairs of MAC candidates that can
                                      // be executed in parallel.

    Reduction(PHINode *P, Instruction *Acc) : Phi(P), AccIntAdd(Acc) { }
  };

  /// Bookkeeping for a pair of narrow loads that have been replaced by a
  /// single wide load.
  class WidenedLoad {
    LoadInst *NewLd = nullptr;
    SmallVector<LoadInst*, 4> Loads;

  public:
    WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
      : NewLd(Wide) {
      for (auto *I : Lds)
        Loads.push_back(I);
    }
    LoadInst *getLoad() {
      return NewLd;
    }
  };

  class ARMParallelDSP : public LoopPass {
    ScalarEvolution   *SE;
    AliasAnalysis     *AA;
    TargetLibraryInfo *TLI;
    DominatorTree     *DT;
    LoopInfo          *LI;
    Loop              *L;
    const DataLayout  *DL;
    Module            *M;
    std::map<LoadInst*, LoadInst*> LoadPairs;
    std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

    bool RecordMemoryOps(BasicBlock *BB);
    bool InsertParallelMACs(Reduction &Reduction);
    bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
    LoadInst* CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
                             IntegerType *LoadTy);
    void CreateParallelMACPairs(Reduction &R);
    Instruction *CreateSMLADCall(SmallVectorImpl<LoadInst*> &VecLd0,
                                 SmallVectorImpl<LoadInst*> &VecLd1,
                                 Instruction *Acc, bool Exchange,
                                 Instruction *InsertAfter);

    /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
    /// Dual performs two signed 16x16-bit multiplications. It adds the
    /// products to a 32-bit accumulate operand. Optionally, the instruction can
    /// exchange the halfwords of the second operand before performing the
    /// arithmetic.
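    ///
    /// A sketch of the operation, assuming little-endian halfword order:
    ///   Res = Acc + (sext(Op0[15:0])  * sext(Op1[15:0]))
    ///             + (sext(Op0[31:16]) * sext(Op1[31:16]))
    /// For SMLADX, the halfwords of Op1 are exchanged first.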
    bool MatchSMLAD(Function &F);

  public:
    static char ID;

    ARMParallelDSP() : LoopPass(ID) { }

    bool doInitialization(Loop *L, LPPassManager &LPM) override {
      LoadPairs.clear();
      WideLoads.clear();
      return true;
    }

    void getAnalysisUsage(AnalysisUsage &AU) const override {
      LoopPass::getAnalysisUsage(AU);
      AU.addRequired<AssumptionCacheTracker>();
      AU.addRequired<ScalarEvolutionWrapperPass>();
      AU.addRequired<AAResultsWrapperPass>();
      AU.addRequired<TargetLibraryInfoWrapperPass>();
      AU.addRequired<LoopInfoWrapperPass>();
      AU.addRequired<DominatorTreeWrapperPass>();
      AU.addRequired<TargetPassConfig>();
      AU.addPreserved<LoopInfoWrapperPass>();
      AU.setPreservesCFG();
    }

    bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
      if (DisableParallelDSP)
        return false;
      L = TheLoop;
      SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
      AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
      TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
      DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
      LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
      auto &TPC = getAnalysis<TargetPassConfig>();

      BasicBlock *Header = TheLoop->getHeader();
      if (!Header)
        return false;

      // TODO: We assume the loop header and latch to be the same block.
      // This is not a fundamental restriction, but lifting this would just
      // require more work to do the transformation and then patch up the CFG.
      if (Header != TheLoop->getLoopLatch()) {
        LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      // We need a preheader as getIncomingValueForBlock assumes there is one.
      if (!TheLoop->getLoopPreheader()) {
        LLVM_DEBUG(dbgs() << "No preheader found, bailing out\n");
        return false;
      }

      Function &F = *Header->getParent();
      M = F.getParent();
      DL = &M->getDataLayout();

      auto &TM = TPC.getTM<TargetMachine>();
      auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

      if (!ST->allowsUnalignedMem()) {
        LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      if (!ST->hasDSP()) {
        LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      if (!ST->isLittle()) {
        LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
      LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");

      if (!RecordMemoryOps(Header)) {
        LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
        return false;
      }

      return MatchSMLAD(F);
    }
  };
} // end anonymous namespace

template<typename MemInst>
static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
                                  const DataLayout &DL, ScalarEvolution &SE) {
  return isConsecutiveAccess(MemOp0, MemOp1, DL, SE);
}

bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!Ld0 || !Ld1)
    return false;

  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;

  LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
    dbgs() << "Ld0:"; Ld0->dump();
    dbgs() << "Ld1:"; Ld1->dump();
  );

  VecMem.clear();
  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
  return true;
}

/// Iterate through the block and record base, offset pairs of loads which can
/// be widened into a single load.
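///
/// For example, two simple, sign-extended i16 loads from addresses Base and
/// Base+2 can be recorded as a pair and later merged into one i32 load.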
bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
  SmallVector<LoadInst*, 8> Loads;
  SmallVector<Instruction*, 8> Writes;

  // Collect loads and instructions that may write to memory. For now we only
  // record loads which are simple, sign-extended and have a single user.
  // TODO: Allow zero-extended loads.
  for (auto &I : *BB) {
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
    auto *Ld = dyn_cast<LoadInst>(&I);
    if (!Ld || !Ld->isSimple() ||
        !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
      continue;
    Loads.push_back(Ld);
  }

  using InstSet = std::set<Instruction*>;
  using DepMap = std::map<Instruction*, InstSet>;
  DepMap RAWDeps;

  // Record any writes that may alias a load.
  const auto Size = LocationSize::unknown();
  for (auto *Read : Loads) {
    for (auto *Write : Writes) {
      MemoryLocation ReadLoc =
        MemoryLocation(Read->getPointerOperand(), Size);

      if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
          ModRefInfo::ModRef)))
        continue;
      if (DT->dominates(Write, Read))
        RAWDeps[Read].insert(Write);
    }
  }

  // Check that there is no write between the two loads which would prevent
  // them from being safely merged.
  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
    LoadInst *Dominator = DT->dominates(Base, Offset) ? Base : Offset;
    LoadInst *Dominated = DT->dominates(Base, Offset) ? Offset : Base;

    if (RAWDeps.count(Dominated)) {
      InstSet &WritesBefore = RAWDeps[Dominated];

      for (auto *Before : WritesBefore) {
        // We can't move the second load backward, past a write, to merge
        // with the first load.
        if (DT->dominates(Dominator, Before))
          return false;
      }
    }
    return true;
  };

  // Record base, offset load pairs.
  for (auto *Base : Loads) {
    for (auto *Offset : Loads) {
      if (Base == Offset)
        continue;

      if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
          SafeToPair(Base, Offset)) {
        LoadPairs[Base] = Offset;
        break;
      }
    }
  }

  LLVM_DEBUG(if (!LoadPairs.empty()) {
               dbgs() << "Consecutive load pairs:\n";
               for (auto &MapIt : LoadPairs)
                 dbgs() << *MapIt.first << ", " << *MapIt.second << "\n";
             });
  return LoadPairs.size() > 1;
}

void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) {
  OpChainList &Candidates = R.MACCandidates;
  PMACPairList &PMACPairs = R.PMACPairs;
  const unsigned Elems = Candidates.size();

  if (Elems < 2)
    return;

  auto CanPair = [&](BinOpChain *PMul0, BinOpChain *PMul1) {
    if (!PMul0->AreSymmetrical(PMul1))
      return false;

    // The first elements of each vector should be loads with sexts. If we
    // find that these are two pairs of consecutive loads, then the pairs can
    // be transformed into two wider loads and the users can be replaced with
    // DSP intrinsics.
    for (unsigned x = 0; x < PMul0->LHS.size(); x += 2) {
      auto *Ld0 = dyn_cast<LoadInst>(PMul0->LHS[x]);
      auto *Ld1 = dyn_cast<LoadInst>(PMul1->LHS[x]);
      auto *Ld2 = dyn_cast<LoadInst>(PMul0->RHS[x]);
      auto *Ld3 = dyn_cast<LoadInst>(PMul1->RHS[x]);

      if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
        return false;

      LLVM_DEBUG(dbgs() << "Loads:\n"
                 << " - " << *Ld0 << "\n"
                 << " - " << *Ld1 << "\n"
                 << " - " << *Ld2 << "\n"
                 << " - " << *Ld3 << "\n");

      if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
        if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
          LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
          PMACPairs.push_back(std::make_pair(PMul0, PMul1));
          return true;
        } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
          LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
          LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
          PMul1->Exchange = true;
          PMACPairs.push_back(std::make_pair(PMul0, PMul1));
          return true;
        }
      } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
                 AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
        LLVM_DEBUG(dbgs() << "    and swapping muls\n");
        PMul0->Exchange = true;
        // Only the second operand can be exchanged, so swap the muls.
        PMACPairs.push_back(std::make_pair(PMul1, PMul0));
        return true;
      }
    }
    return false;
  };

  SmallPtrSet<const Instruction*, 4> Paired;
  for (unsigned i = 0; i < Elems; ++i) {
    BinOpChain *PMul0 = static_cast<BinOpChain*>(Candidates[i].get());
    if (Paired.count(PMul0->Root))
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      BinOpChain *PMul1 = static_cast<BinOpChain*>(Candidates[j].get());
      if (Paired.count(PMul1->Root))
        continue;

      const Instruction *Mul0 = PMul0->Root;
      const Instruction *Mul1 = PMul1->Root;
      if (Mul0 == Mul1)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(PMul0, PMul1)) {
        Paired.insert(Mul0);
        Paired.insert(Mul1);
        break;
      }
    }
  }
}

bool ARMParallelDSP::InsertParallelMACs(Reduction &Reduction) {
  Instruction *Acc = Reduction.Phi;
  Instruction *InsertAfter = Reduction.AccIntAdd;

  for (auto &Pair : Reduction.PMACPairs) {
    BinOpChain *PMul0 = Pair.first;
    BinOpChain *PMul1 = Pair.second;
    LLVM_DEBUG(dbgs() << "Found parallel MACs:\n"
               << "- " << *PMul0->Root << "\n"
               << "- " << *PMul1->Root << "\n");

    Acc = CreateSMLADCall(PMul0->VecLd, PMul1->VecLd, Acc, PMul1->Exchange,
                          InsertAfter);
    InsertAfter = Acc;
  }

  if (Acc != Reduction.Phi) {
    LLVM_DEBUG(dbgs() << "Replace Accumulate: "; Acc->dump());
    Reduction.AccIntAdd->replaceAllUsesWith(Acc);
    return true;
  }
  return false;
}

// Return true if V is an extending cast of type InstType whose source has
// the given bit width and is itself a load, e.g. sext i16 %ld to i32 for
// IsExtendingLoad<SExtInst, 16>.
template<typename InstType, unsigned BitWidth>
static bool IsExtendingLoad(Value *V) {
  auto *I = dyn_cast<InstType>(V);
  if (!I)
    return false;

  if (I->getSrcTy()->getIntegerBitWidth() != BitWidth)
    return false;

  return isa<LoadInst>(I->getOperand(0));
}

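// Search from the accumulating add for candidate multiply-accumulate chains.
// For example, in:
//   %sext0 = sext i16 %ld0 to i32
//   %sext1 = sext i16 %ld1 to i32
//   %mul   = mul i32 %sext0, %sext1
//   %acc1  = add i32 %mul, %acc0
// the mul of the two sign-extended loads is recorded as a candidate.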
static void MatchParallelMACSequences(Reduction &R,
                                      OpChainList &Candidates) {
  Instruction *Acc = R.AccIntAdd;
  LLVM_DEBUG(dbgs() << "\n- Analysing:\t" << *Acc << "\n");

  // Returns false to signal the search should be stopped.
  std::function<bool(Value*)> Match =
    [&Candidates, &Match](Value *V) -> bool {

    auto *I = dyn_cast<Instruction>(V);
    if (!I)
      return false;

    switch (I->getOpcode()) {
    case Instruction::Add:
      if (Match(I->getOperand(0)) || Match(I->getOperand(1)))
        return true;
      break;
    case Instruction::Mul: {
      Value *Op0 = I->getOperand(0);
      Value *Op1 = I->getOperand(1);
      if (IsExtendingLoad<SExtInst, 16>(Op0) &&
          IsExtendingLoad<SExtInst, 16>(Op1)) {
        ValueList LHS = { cast<SExtInst>(Op0)->getOperand(0), Op0 };
        ValueList RHS = { cast<SExtInst>(Op1)->getOperand(0), Op1 };
        Candidates.push_back(make_unique<BinOpChain>(I, LHS, RHS));
      }
      return false;
    }
    case Instruction::SExt:
      return Match(I->getOperand(0));
    }
    return false;
  };

  while (Match(Acc));
  LLVM_DEBUG(dbgs() << "Finished matching MAC sequences, found "
             << Candidates.size() << " candidates.\n");
}

static bool CheckMACMemory(OpChainList &Candidates) {
  for (auto &C : Candidates) {
    // A mul has 2 operands, and a narrow op consists of a sext and a load;
    // thus we expect at least 4 items in this operand value list.
    if (C->size() < 4) {
      LLVM_DEBUG(dbgs() << "Operand list too short.\n");
      return false;
    }
    C->PopulateLoads();
    ValueList &LHS = static_cast<BinOpChain*>(C.get())->LHS;
    ValueList &RHS = static_cast<BinOpChain*>(C.get())->RHS;

    // Use +=2 to skip over the expected extend instructions.
    for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
      if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
        return false;
    }
  }
  return true;
}

// Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
// multiplications.
// To use SMLAD:
// 1) we first need to find integer add reduction PHIs,
// 2) then from the PHI, look for this pattern:
//
// acc0 = phi i32 [0, %entry], [%acc1, %loop.body]
// ld0 = load i16
// sext0 = sext i16 %ld0 to i32
// ld1 = load i16
// sext1 = sext i16 %ld1 to i32
// mul0 = mul i32 %sext0, %sext1
// ld2 = load i16
// sext2 = sext i16 %ld2 to i32
// ld3 = load i16
// sext3 = sext i16 %ld3 to i32
// mul1 = mul i32 %sext2, %sext3
// add0 = add i32 %mul0, %acc0
// acc1 = add i32 %add0, %mul1
//
// Which can be selected to:
//
// ldr r0
// ldr r1
// smlad r2, r0, r1, r2
//
// If constants are used instead of loads, these will need to be hoisted
// out and into a register.
//
// If loop invariants are used instead of loads, these need to be packed
// before the loop begins.
//
bool ARMParallelDSP::MatchSMLAD(Function &F) {
  auto FindReductions = [&](ReductionList &Reductions) {
    RecurrenceDescriptor RecDesc;
    const bool HasFnNoNaNAttr =
      F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
    BasicBlock *Latch = L->getLoopLatch();

    for (PHINode &Phi : Latch->phis()) {
      const auto *Ty = Phi.getType();
      if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
        continue;

      const bool IsReduction = RecurrenceDescriptor::AddReductionVar(
        &Phi, RecurrenceDescriptor::RK_IntegerAdd, L, HasFnNoNaNAttr, RecDesc);

      if (!IsReduction)
        continue;

      Instruction *Acc =
        dyn_cast<Instruction>(Phi.getIncomingValueForBlock(Latch));
      if (!Acc)
        continue;

      Reductions.push_back(Reduction(&Phi, Acc));
    }
    return !Reductions.empty();
  };

  ReductionList Reductions;
  if (!FindReductions(Reductions))
    return false;

  for (auto &R : Reductions) {
    OpChainList MACCandidates;
    MatchParallelMACSequences(R, MACCandidates);
    if (!CheckMACMemory(MACCandidates))
      continue;

    R.MACCandidates = std::move(MACCandidates);

    LLVM_DEBUG(dbgs() << "MAC candidates:\n";
      for (auto &M : R.MACCandidates)
        M->Root->dump();
      dbgs() << "\n";);
  }

  // Pair up the MAC candidates for each reduction and insert the parallel
  // intrinsic calls.
  bool Changed = false;
  for (auto &R : Reductions) {
    CreateParallelMACPairs(R);
    Changed |= InsertParallelMACs(R);
  }

  return Changed;
}

LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];

  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());

  assert((BaseSExt && OffsetSExt)
         && "Loads should have a single, extending, user");

  std::function<void(Value*, Value*)> MoveBefore =
    [&](Value *A, Value *B) -> void {
      if (!isa<Instruction>(A) || !isa<Instruction>(B))
        return;

      auto *Source = cast<Instruction>(A);
      auto *Sink = cast<Instruction>(B);

      if (DT->dominates(Source, Sink) ||
          Source->getParent() != Sink->getParent() ||
          isa<PHINode>(Source) || isa<PHINode>(Sink))
        return;

      Source->moveBefore(Sink);
      for (auto &U : Source->uses())
        MoveBefore(Source, U.getUser());
    };

  // Insert the load at the point of the original dominating load.
  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
                          ++BasicBlock::iterator(DomLoad));

  // Bitcast the pointer to a wider type and create the wide load, while making
  // sure to maintain the original alignment as this prevents ldrd from being
  // generated when it could be illegal due to memory alignment.
  const unsigned AddrSpace = DomLoad->getPointerAddressSpace();
  Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
                                    LoadTy->getPointerTo(AddrSpace));
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr,
                                             Base->getAlignment());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);

  // From the wide load, create two values that equal the original two loads.
  // Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
  // TODO: Support big-endian as well.
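  // A sketch of the IR created below, for two i16 loads widened into an i32
  // load %wide (little-endian):
  //   %bottom = trunc i32 %wide to i16            ; replaces Loads[0]
  //   %shr    = lshr i32 %wide, 16
  //   %top    = trunc i32 %shr to i16             ; replaces Loads[1]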
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  BaseSExt->setOperand(0, Bottom);

  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  OffsetSExt->setOperand(0, Trunc);

  WideLoads.emplace(Base, make_unique<WidenedLoad>(Loads, WideLoad));
  return WideLoad;
}
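// Create a call to one of the [SMLAD, SMLADX, SMLALD, SMLALDX] intrinsics.
// A sketch of the result, assuming a 32-bit accumulator and no exchange:
//   %res = call i32 @llvm.arm.smlad(i32 %wideLd0, i32 %wideLd1, i32 %acc)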
Instruction *ARMParallelDSP::CreateSMLADCall(SmallVectorImpl<LoadInst*> &VecLd0,
                                             SmallVectorImpl<LoadInst*> &VecLd1,
                                             Instruction *Acc, bool Exchange,
                                             Instruction *InsertAfter) {
  LLVM_DEBUG(dbgs() << "Create SMLAD intrinsic using:\n"
             << "- " << *VecLd0[0] << "\n"
             << "- " << *VecLd0[1] << "\n"
             << "- " << *VecLd1[0] << "\n"
             << "- " << *VecLd1[1] << "\n"
             << "- " << *Acc << "\n"
             << "- Exchange: " << Exchange << "\n");

  // Replace the reduction chain with an intrinsic call.
  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
  LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ?
    WideLoads[VecLd0[0]]->getLoad() : CreateWideLoad(VecLd0, Ty);
  LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ?
    WideLoads[VecLd1[0]]->getLoad() : CreateWideLoad(VecLd1, Ty);

  Value *Args[] = { WideLd0, WideLd1, Acc };
  Function *SMLAD = nullptr;
  if (Exchange)
    SMLAD = Acc->getType()->isIntegerTy(32) ?
      Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
      Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
  else
    SMLAD = Acc->getType()->isIntegerTy(32) ?
      Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
      Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);

  IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                              ++BasicBlock::iterator(InsertAfter));
  CallInst *Call = Builder.CreateCall(SMLAD, Args);
  ++NumSMLAD;
  return Call;
}

// Compare the value lists in Other to this chain.
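// For example, two candidate muls whose operand lists are, element by
// element, instructions of the same kind (e.g. a load feeding a sext)
// compare as symmetrical.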
bool BinOpChain::AreSymmetrical(BinOpChain *Other) {
  // Element-by-element comparison of Value lists returning true if they are
  // instructions with the same opcode or constants with the same value.
  auto CompareValueList = [](const ValueList &VL0,
                             const ValueList &VL1) {
    if (VL0.size() != VL1.size()) {
      LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
                        << VL0.size() << " != " << VL1.size() << "\n");
      return false;
    }

    const unsigned Pairs = VL0.size();

    for (unsigned i = 0; i < Pairs; ++i) {
      const Value *V0 = VL0[i];
      const Value *V1 = VL1[i];
      const auto *Inst0 = dyn_cast<Instruction>(V0);
      const auto *Inst1 = dyn_cast<Instruction>(V1);

      if (Inst0 && Inst1 && Inst0->isSameOperationAs(Inst1))
        continue;

      // Otherwise accept a pair of constants, as long as they have the same
      // value.
      const APInt *C0, *C1;
      if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) &&
            C0->getBitWidth() == C1->getBitWidth() && *C0 == *C1))
        return false;
    }

    return true;
  };

  return CompareValueList(LHS, Other->LHS) &&
         CompareValueList(RHS, Other->RHS);
}

Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

char ARMParallelDSP::ID = 0;

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                "Transform loops to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                "Transform loops to use DSP intrinsics", false, false)
784