1 //===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Armv6 introduced instructions to perform 32-bit SIMD operations. The
11 /// purpose of this pass is do some IR pattern matching to create ACLE
12 /// DSP intrinsics, which map on these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/Analysis/AliasAnalysis.h"
20 #include "llvm/Analysis/LoopAccessAnalysis.h"
21 #include "llvm/Analysis/LoopPass.h"
22 #include "llvm/Analysis/LoopInfo.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/NoFolder.h"
25 #include "llvm/Transforms/Scalar.h"
26 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
27 #include "llvm/Transforms/Utils/LoopUtils.h"
28 #include "llvm/Pass.h"
29 #include "llvm/PassRegistry.h"
30 #include "llvm/PassSupport.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/IR/PatternMatch.h"
33 #include "llvm/CodeGen/TargetPassConfig.h"
34 #include "ARM.h"
35 #include "ARMSubtarget.h"
36 
37 using namespace llvm;
38 using namespace PatternMatch;
39 
40 #define DEBUG_TYPE "arm-parallel-dsp"
41 
// Counts the number of SMLAD-family intrinsic calls created by this pass.
STATISTIC(NumSMLAD , "Number of smlad instructions generated");

// Escape hatch to turn the whole pass off from the command line.
static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));
47 
48 namespace {
  struct OpChain;
  struct BinOpChain;
  class Reduction;

  // Shorthand container types used throughout the pass.
  using OpChainList     = SmallVector<std::unique_ptr<BinOpChain>, 8>;
  using ReductionList   = SmallVector<Reduction, 8>;
  using ValueList       = SmallVector<Value*, 8>;
  using MemInstList     = SmallVector<LoadInst*, 8>;
  // A pair of mul chains that can be executed in parallel.
  using PMACPair        = std::pair<BinOpChain*,BinOpChain*>;
  using PMACPairList    = SmallVector<PMACPair, 8>;
  using Instructions    = SmallVector<Instruction*,16>;
  using MemLocList      = SmallVector<MemoryLocation, 4>;
61 
62   // 'BinOpChain' holds the multiplication instructions that are candidates
63   // for parallel execution.
64   struct BinOpChain {
65     Instruction   *Root;
66     ValueList     AllValues;
67     MemInstList   VecLd;    // List of all load instructions.
68     ValueList     LHS;      // List of all (narrow) left hand operands.
69     ValueList     RHS;      // List of all (narrow) right hand operands.
70     bool          Exchange = false;
71     bool          ReadOnly = true;
72 
73     BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
74       Root(I), LHS(lhs), RHS(rhs) {
75         for (auto *V : LHS)
76           AllValues.push_back(V);
77         for (auto *V : RHS)
78           AllValues.push_back(V);
79     }
80 
81     unsigned size() const { return AllValues.size(); }
82   };
83 
84   /// Represent a sequence of multiply-accumulate operations with the aim to
85   /// perform the multiplications in parallel.
  /// Represent a sequence of multiply-accumulate operations with the aim to
  /// perform the multiplications in parallel.
  class Reduction {
    Instruction     *Root = nullptr;     // The final add of the reduction.
    Value           *Acc = nullptr;      // The incoming accumulator value.
    OpChainList     Muls;                // The candidate mul chains.
    PMACPairList        MulPairs;        // Muls paired for parallel execution.
    SmallPtrSet<Instruction*, 4> Adds;   // All adds forming the chain.

  public:
    Reduction() = delete;

    /// A reduction is always rooted at an add instruction.
    Reduction (Instruction *Add) : Root(Add) { }

    /// Record an Add instruction that is a part of this reduction.
    void InsertAdd(Instruction *I) { Adds.insert(I); }

    /// Record a BinOpChain, rooted at a Mul instruction, that is a part of
    /// this reduction.
    void InsertMul(Instruction *I, ValueList &LHS, ValueList &RHS) {
      Muls.push_back(make_unique<BinOpChain>(I, LHS, RHS));
    }

    /// Add the incoming accumulator value, returns true if a value had not
    /// already been added. Returning false signals to the user that this
    /// reduction already has a value to initialise the accumulator.
    bool InsertAcc(Value *V) {
      if (Acc)
        return false;
      Acc = V;
      return true;
    }

    /// Set two BinOpChains, rooted at muls, that can be executed as a single
    /// parallel operation.
    void AddMulPair(BinOpChain *Mul0, BinOpChain *Mul1) {
      MulPairs.push_back(std::make_pair(Mul0, Mul1));
    }

    /// Return true if enough mul operations are found that can be executed in
    /// parallel.
    bool CreateParallelPairs();

    /// Return the add instruction which is the root of the reduction.
    Instruction *getRoot() { return Root; }

    /// Return the incoming value to be accumulated. This may be null.
    Value *getAccumulator() { return Acc; }

    /// Return the set of adds that comprise the reduction.
    SmallPtrSetImpl<Instruction*> &getAdds() { return Adds; }

    /// Return the BinOpChains, rooted at mul instructions, that comprise
    /// the reduction.
    OpChainList &getMuls() { return Muls; }

    /// Return the BinOpChains, rooted at mul instructions, that have been
    /// paired for parallel execution.
    PMACPairList &getMulPairs() { return MulPairs; }

    /// To finalise, replace the uses of the root with the intrinsic call.
    void UpdateRoot(Instruction *SMLAD) {
      Root->replaceAllUsesWith(SMLAD);
    }
  };
149 
150   class WidenedLoad {
151     LoadInst *NewLd = nullptr;
152     SmallVector<LoadInst*, 4> Loads;
153 
154   public:
155     WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
156       : NewLd(Wide) {
157       for (auto *I : Lds)
158         Loads.push_back(I);
159     }
160     LoadInst *getLoad() {
161       return NewLd;
162     }
163   };
164 
  /// The pass itself: a LoopPass that pattern-matches multiply-accumulate
  /// reductions in single-block loops and rewrites them as ARM DSP
  /// intrinsics (smlad/smladx/smlald/smlaldx).
  class ARMParallelDSP : public LoopPass {
    ScalarEvolution   *SE;
    AliasAnalysis     *AA;
    TargetLibraryInfo *TLI;
    DominatorTree     *DT;
    LoopInfo          *LI;
    Loop              *L;
    const DataLayout  *DL;
    Module            *M;
    // Maps a base load to the consecutive (offset) load that follows it.
    std::map<LoadInst*, LoadInst*> LoadPairs;
    // All loads that appear as the offset half of some pair.
    SmallPtrSet<LoadInst*, 4> OffsetLoads;
    // Maps a base load to the wide load that replaced its pair.
    std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

    template<unsigned>
    bool IsNarrowSequence(Value *V, ValueList &VL);

    bool RecordMemoryOps(BasicBlock *BB);
    void InsertParallelMACs(Reduction &Reduction);
    bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
    LoadInst* CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
                             IntegerType *LoadTy);
    bool CreateParallelPairs(Reduction &R);

    /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
    /// Dual performs two signed 16x16-bit multiplications. It adds the
    /// products to a 32-bit accumulate operand. Optionally, the instruction can
    /// exchange the halfwords of the second operand before performing the
    /// arithmetic.
    bool MatchSMLAD(Loop *L);

  public:
    static char ID;

    ARMParallelDSP() : LoopPass(ID) { }

    // Reset per-loop state; the maps must not leak entries between loops.
    bool doInitialization(Loop *L, LPPassManager &LPM) override {
      LoadPairs.clear();
      WideLoads.clear();
      return true;
    }

    void getAnalysisUsage(AnalysisUsage &AU) const override {
      LoopPass::getAnalysisUsage(AU);
      AU.addRequired<AssumptionCacheTracker>();
      AU.addRequired<ScalarEvolutionWrapperPass>();
      AU.addRequired<AAResultsWrapperPass>();
      AU.addRequired<TargetLibraryInfoWrapperPass>();
      AU.addRequired<LoopInfoWrapperPass>();
      AU.addRequired<DominatorTreeWrapperPass>();
      AU.addRequired<TargetPassConfig>();
      AU.addPreserved<LoopInfoWrapperPass>();
      AU.setPreservesCFG();
    }

    bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
      if (DisableParallelDSP)
        return false;
      if (skipLoop(TheLoop))
        return false;

      L = TheLoop;
      SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
      AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
      TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
      DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
      LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
      auto &TPC = getAnalysis<TargetPassConfig>();

      BasicBlock *Header = TheLoop->getHeader();
      if (!Header)
        return false;

      // TODO: We assume the loop header and latch to be the same block.
      // This is not a fundamental restriction, but lifting this would just
      // require more work to do the transformation and then patch up the CFG.
      if (Header != TheLoop->getLoopLatch()) {
        LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      // A preheader is required so the wide loads' users have somewhere
      // dominating to live if needed.
      if (!TheLoop->getLoopPreheader())
        InsertPreheaderForLoop(L, DT, LI, nullptr, true);

      Function &F = *Header->getParent();
      M = F.getParent();
      DL = &M->getDataLayout();

      auto &TM = TPC.getTM<TargetMachine>();
      auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

      // The widened loads may be unaligned, so bail out if that's illegal.
      if (!ST->allowsUnalignedMem()) {
        LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      if (!ST->hasDSP()) {
        LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      // The lshr/trunc splitting of the wide load assumes little-endian
      // half-word layout.
      if (!ST->isLittle()) {
        LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                          << "ARMParallelDSP\n");
        return false;
      }

      // NOTE(review): LAI appears unused below — candidate for removal;
      // confirm no side effect of constructing LoopAccessInfo is relied on.
      LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);

      LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
      LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");

      if (!RecordMemoryOps(Header)) {
        LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
        return false;
      }

      bool Changes = MatchSMLAD(L);
      return Changes;
    }
  };
288 }
289 
290 template<typename MemInst>
291 static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
292                                   const DataLayout &DL, ScalarEvolution &SE) {
293   if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE))
294     return true;
295   return false;
296 }
297 
298 bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
299                                         MemInstList &VecMem) {
300   if (!Ld0 || !Ld1)
301     return false;
302 
303   if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
304     return false;
305 
306   LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
307     dbgs() << "Ld0:"; Ld0->dump();
308     dbgs() << "Ld1:"; Ld1->dump();
309   );
310 
311   VecMem.clear();
312   VecMem.push_back(Ld0);
313   VecMem.push_back(Ld1);
314   return true;
315 }
316 
317 // MaxBitwidth: the maximum supported bitwidth of the elements in the DSP
318 // instructions, which is set to 16. So here we should collect all i8 and i16
319 // narrow operations.
320 // TODO: we currently only collect i16, and will support i8 later, so that's
321 // why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
322 template<unsigned MaxBitWidth>
323 bool ARMParallelDSP::IsNarrowSequence(Value *V, ValueList &VL) {
324   if (auto *SExt = dyn_cast<SExtInst>(V)) {
325     if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
326       return false;
327 
328     if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) {
329       // Check that these load could be paired.
330       if (!LoadPairs.count(Ld) && !OffsetLoads.count(Ld))
331         return false;
332 
333       VL.push_back(Ld);
334       VL.push_back(SExt);
335       return true;
336     }
337   }
338   return false;
339 }
340 
341 /// Iterate through the block and record base, offset pairs of loads which can
342 /// be widened into a single load.
343 bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
344   SmallVector<LoadInst*, 8> Loads;
345   SmallVector<Instruction*, 8> Writes;
346 
347   // Collect loads and instruction that may write to memory. For now we only
348   // record loads which are simple, sign-extended and have a single user.
349   // TODO: Allow zero-extended loads.
350   for (auto &I : *BB) {
351     if (I.mayWriteToMemory())
352       Writes.push_back(&I);
353     auto *Ld = dyn_cast<LoadInst>(&I);
354     if (!Ld || !Ld->isSimple() ||
355         !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
356       continue;
357     Loads.push_back(Ld);
358   }
359 
360   using InstSet = std::set<Instruction*>;
361   using DepMap = std::map<Instruction*, InstSet>;
362   DepMap RAWDeps;
363 
364   // Record any writes that may alias a load.
365   const auto Size = LocationSize::unknown();
366   for (auto Read : Loads) {
367     for (auto Write : Writes) {
368       MemoryLocation ReadLoc =
369         MemoryLocation(Read->getPointerOperand(), Size);
370 
371       if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
372           ModRefInfo::ModRef)))
373         continue;
374       if (DT->dominates(Write, Read))
375         RAWDeps[Read].insert(Write);
376     }
377   }
378 
379   // Check whether there's not a write between the two loads which would
380   // prevent them from being safely merged.
381   auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
382     LoadInst *Dominator = DT->dominates(Base, Offset) ? Base : Offset;
383     LoadInst *Dominated = DT->dominates(Base, Offset) ? Offset : Base;
384 
385     if (RAWDeps.count(Dominated)) {
386       InstSet &WritesBefore = RAWDeps[Dominated];
387 
388       for (auto Before : WritesBefore) {
389 
390         // We can't move the second load backward, past a write, to merge
391         // with the first load.
392         if (DT->dominates(Dominator, Before))
393           return false;
394       }
395     }
396     return true;
397   };
398 
399   // Record base, offset load pairs.
400   for (auto *Base : Loads) {
401     for (auto *Offset : Loads) {
402       if (Base == Offset)
403         continue;
404 
405       if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
406           SafeToPair(Base, Offset)) {
407         LoadPairs[Base] = Offset;
408         OffsetLoads.insert(Offset);
409         break;
410       }
411     }
412   }
413 
414   LLVM_DEBUG(if (!LoadPairs.empty()) {
415                dbgs() << "Consecutive load pairs:\n";
416                for (auto &MapIt : LoadPairs) {
417                  LLVM_DEBUG(dbgs() << *MapIt.first << ", "
418                             << *MapIt.second << "\n");
419                }
420              });
421   return LoadPairs.size() > 1;
422 }
423 
424 // Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
425 // multiplications.
426 // To use SMLAD:
427 // 1) we first need to find integer add then look for this pattern:
428 //
429 // acc0 = ...
430 // ld0 = load i16
431 // sext0 = sext i16 %ld0 to i32
432 // ld1 = load i16
433 // sext1 = sext i16 %ld1 to i32
434 // mul0 = mul %sext0, %sext1
435 // ld2 = load i16
436 // sext2 = sext i16 %ld2 to i32
437 // ld3 = load i16
438 // sext3 = sext i16 %ld3 to i32
439 // mul1 = mul i32 %sext2, %sext3
440 // add0 = add i32 %mul0, %acc0
441 // acc1 = add i32 %add0, %mul1
442 //
443 // Which can be selected to:
444 //
445 // ldr r0
446 // ldr r1
447 // smlad r2, r0, r1, r2
448 //
449 // If constants are used instead of loads, these will need to be hoisted
450 // out and into a register.
451 //
452 // If loop invariants are used instead of loads, these need to be packed
453 // before the loop begins.
454 //
455 bool ARMParallelDSP::MatchSMLAD(Loop *L) {
456   // Search recursively back through the operands to find a tree of values that
457   // form a multiply-accumulate chain. The search records the Add and Mul
458   // instructions that form the reduction and allows us to find a single value
459   // to be used as the initial input to the accumlator.
460   std::function<bool(Value*, Reduction&)> Search = [&]
461     (Value *V, Reduction &R) -> bool {
462 
463     // If we find a non-instruction, try to use it as the initial accumulator
464     // value. This may have already been found during the search in which case
465     // this function will return false, signaling a search fail.
466     auto *I = dyn_cast<Instruction>(V);
467     if (!I)
468       return R.InsertAcc(V);
469 
470     switch (I->getOpcode()) {
471     default:
472       break;
473     case Instruction::PHI:
474       // Could be the accumulator value.
475       return R.InsertAcc(V);
476     case Instruction::Add: {
477       // Adds should be adding together two muls, or another add and a mul to
478       // be within the mac chain. One of the operands may also be the
479       // accumulator value at which point we should stop searching.
480       bool ValidLHS = Search(I->getOperand(0), R);
481       bool ValidRHS = Search(I->getOperand(1), R);
482       if (!ValidLHS && !ValidLHS)
483         return false;
484       else if (ValidLHS && ValidRHS) {
485         R.InsertAdd(I);
486         return true;
487       } else {
488         R.InsertAdd(I);
489         return R.InsertAcc(I);
490       }
491     }
492     case Instruction::Mul: {
493       Value *MulOp0 = I->getOperand(0);
494       Value *MulOp1 = I->getOperand(1);
495       if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1)) {
496         ValueList LHS;
497         ValueList RHS;
498         if (IsNarrowSequence<16>(MulOp0, LHS) &&
499             IsNarrowSequence<16>(MulOp1, RHS)) {
500           R.InsertMul(I, LHS, RHS);
501           return true;
502         }
503       }
504       return false;
505     }
506     case Instruction::SExt:
507       return Search(I->getOperand(0), R);
508     }
509     return false;
510   };
511 
512   bool Changed = false;
513   SmallPtrSet<Instruction*, 4> AllAdds;
514   BasicBlock *Latch = L->getLoopLatch();
515 
516   for (Instruction &I : reverse(*Latch)) {
517     if (I.getOpcode() != Instruction::Add)
518       continue;
519 
520     if (AllAdds.count(&I))
521       continue;
522 
523     const auto *Ty = I.getType();
524     if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
525       continue;
526 
527     Reduction R(&I);
528     if (!Search(&I, R))
529       continue;
530 
531     if (!CreateParallelPairs(R))
532       continue;
533 
534     InsertParallelMACs(R);
535     Changed = true;
536     AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
537   }
538 
539   return Changed;
540 }
541 
542 bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {
543 
544   // Not enough mul operations to make a pair.
545   if (R.getMuls().size() < 2)
546     return false;
547 
548   // Check that the muls operate directly upon sign extended loads.
549   for (auto &MulChain : R.getMuls()) {
550     // A mul has 2 operands, and a narrow op consist of sext and a load; thus
551     // we expect at least 4 items in this operand value list.
552     if (MulChain->size() < 4) {
553       LLVM_DEBUG(dbgs() << "Operand list too short.\n");
554       return false;
555     }
556     ValueList &LHS = static_cast<BinOpChain*>(MulChain.get())->LHS;
557     ValueList &RHS = static_cast<BinOpChain*>(MulChain.get())->RHS;
558 
559     // Use +=2 to skip over the expected extend instructions.
560     for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
561       if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
562         return false;
563     }
564   }
565 
566   auto CanPair = [&](Reduction &R, BinOpChain *PMul0, BinOpChain *PMul1) {
567     // The first elements of each vector should be loads with sexts. If we
568     // find that its two pairs of consecutive loads, then these can be
569     // transformed into two wider loads and the users can be replaced with
570     // DSP intrinsics.
571     for (unsigned x = 0; x < PMul0->LHS.size(); x += 2) {
572       auto *Ld0 = dyn_cast<LoadInst>(PMul0->LHS[x]);
573       auto *Ld1 = dyn_cast<LoadInst>(PMul1->LHS[x]);
574       auto *Ld2 = dyn_cast<LoadInst>(PMul0->RHS[x]);
575       auto *Ld3 = dyn_cast<LoadInst>(PMul1->RHS[x]);
576 
577       if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
578         return false;
579 
580       LLVM_DEBUG(dbgs() << "Loads:\n"
581                  << " - " << *Ld0 << "\n"
582                  << " - " << *Ld1 << "\n"
583                  << " - " << *Ld2 << "\n"
584                  << " - " << *Ld3 << "\n");
585 
586       if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
587         if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
588           LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
589           R.AddMulPair(PMul0, PMul1);
590           return true;
591         } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
592           LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
593           LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
594           PMul1->Exchange = true;
595           R.AddMulPair(PMul0, PMul1);
596           return true;
597         }
598       } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
599                  AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
600         LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
601         LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
602         LLVM_DEBUG(dbgs() << "    and swapping muls\n");
603         PMul0->Exchange = true;
604         // Only the second operand can be exchanged, so swap the muls.
605         R.AddMulPair(PMul1, PMul0);
606         return true;
607       }
608     }
609     return false;
610   };
611 
612   OpChainList &Muls = R.getMuls();
613   const unsigned Elems = Muls.size();
614   SmallPtrSet<const Instruction*, 4> Paired;
615   for (unsigned i = 0; i < Elems; ++i) {
616     BinOpChain *PMul0 = static_cast<BinOpChain*>(Muls[i].get());
617     if (Paired.count(PMul0->Root))
618       continue;
619 
620     for (unsigned j = 0; j < Elems; ++j) {
621       if (i == j)
622         continue;
623 
624       BinOpChain *PMul1 = static_cast<BinOpChain*>(Muls[j].get());
625       if (Paired.count(PMul1->Root))
626         continue;
627 
628       const Instruction *Mul0 = PMul0->Root;
629       const Instruction *Mul1 = PMul1->Root;
630       if (Mul0 == Mul1)
631         continue;
632 
633       assert(PMul0 != PMul1 && "expected different chains");
634 
635       if (CanPair(R, PMul0, PMul1)) {
636         Paired.insert(Mul0);
637         Paired.insert(Mul1);
638         break;
639       }
640     }
641   }
642   return !R.getMulPairs().empty();
643 }
644 
645 
646 void ARMParallelDSP::InsertParallelMACs(Reduction &R) {
647 
648   auto CreateSMLADCall = [&](SmallVectorImpl<LoadInst*> &VecLd0,
649                              SmallVectorImpl<LoadInst*> &VecLd1,
650                              Value *Acc, bool Exchange,
651                              Instruction *InsertAfter) {
652     // Replace the reduction chain with an intrinsic call
653     IntegerType *Ty = IntegerType::get(M->getContext(), 32);
654     LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ?
655       WideLoads[VecLd0[0]]->getLoad() : CreateWideLoad(VecLd0, Ty);
656     LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ?
657       WideLoads[VecLd1[0]]->getLoad() : CreateWideLoad(VecLd1, Ty);
658 
659     Value* Args[] = { WideLd0, WideLd1, Acc };
660     Function *SMLAD = nullptr;
661     if (Exchange)
662       SMLAD = Acc->getType()->isIntegerTy(32) ?
663         Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
664         Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
665     else
666       SMLAD = Acc->getType()->isIntegerTy(32) ?
667         Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
668         Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
669 
670     IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
671                                 ++BasicBlock::iterator(InsertAfter));
672     Instruction *Call = Builder.CreateCall(SMLAD, Args);
673     NumSMLAD++;
674     return Call;
675   };
676 
677   Instruction *InsertAfter = R.getRoot();
678   Value *Acc = R.getAccumulator();
679   if (!Acc)
680     Acc = ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
681 
682   LLVM_DEBUG(dbgs() << "Root: " << *InsertAfter << "\n"
683              << "Acc: " << *Acc << "\n");
684   for (auto &Pair : R.getMulPairs()) {
685     BinOpChain *PMul0 = Pair.first;
686     BinOpChain *PMul1 = Pair.second;
687     LLVM_DEBUG(dbgs() << "Muls:\n"
688                << "- " << *PMul0->Root << "\n"
689                << "- " << *PMul1->Root << "\n");
690 
691     Acc = CreateSMLADCall(PMul0->VecLd, PMul1->VecLd, Acc, PMul1->Exchange,
692                           InsertAfter);
693     InsertAfter = cast<Instruction>(Acc);
694   }
695   R.UpdateRoot(cast<Instruction>(Acc));
696 }
697 
/// Replace the two narrow loads in \p Loads (base, then offset) with a
/// single load of type \p LoadTy, rewiring the loads' sext users to read
/// the low and high halves of the wide value. Returns the new wide load.
LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];

  // RecordMemoryOps only collected loads whose single user is a sext, so
  // user_back() is that extension.
  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());

  assert((BaseSExt && OffsetSExt)
         && "Loads should have a single, extending, user");

  // Recursively hoist an instruction (and the operands it depends on)
  // above its sink within the same block, so def-use order stays valid
  // after the wide load is inserted.
  std::function<void(Value*, Value*)> MoveBefore =
    [&](Value *A, Value *B) -> void {
      if (!isa<Instruction>(A) || !isa<Instruction>(B))
        return;

      auto *Source = cast<Instruction>(A);
      auto *Sink = cast<Instruction>(B);

      // Nothing to do if the order is already correct, the instructions
      // live in different blocks, or either is a PHI (must stay first).
      if (DT->dominates(Source, Sink) ||
          Source->getParent() != Sink->getParent() ||
          isa<PHINode>(Source) || isa<PHINode>(Sink))
        return;

      Source->moveBefore(Sink);
      for (auto &Op : Source->operands())
        MoveBefore(Op, Source);
    };

  // Insert the load at the point of the original dominating load.
  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
                          ++BasicBlock::iterator(DomLoad));

  // Bitcast the pointer to a wider type and create the wide load, while making
  // sure to maintain the original alignment as this prevents ldrd from being
  // generated when it could be illegal due to memory alignment.
  const unsigned AddrSpace = DomLoad->getPointerAddressSpace();
  Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
                                    LoadTy->getPointerTo(AddrSpace));
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr,
                                             Base->getAlignment());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);

  // From the wide load, create two values that equal the original two loads.
  // Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
  // TODO: Support big-endian as well.
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  BaseSExt->setOperand(0, Bottom);

  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  OffsetSExt->setOperand(0, Trunc);

  // Remember the widening so the base load is not widened again.
  WideLoads.emplace(std::make_pair(Base,
                                   make_unique<WidenedLoad>(Loads, WideLoad)));
  return WideLoad;
}
763 
/// Factory used by the ARM target to add this pass to the pipeline.
Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}
767 
// Unique pass identity; the address of ID is what the pass manager uses.
char ARMParallelDSP::ID = 0;

// Register the pass with the legacy pass manager.
INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                "Transform loops to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                "Transform loops to use DSP intrinsics", false, false)
774