//===- ARMParallelDSP.cpp - Parallel DSP Pass -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled.
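///
/// For example, a dot-product step such as:
///   acc += (int)a[i] * (int)b[i] + (int)a[i+1] * (int)b[i+1];
/// can, when the i16 loads are adjacent in memory, be lowered to a single
/// SMLAD that performs both 16-bit multiplications and the accumulation in
/// one instruction.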
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/PassSupport.h"
#include "llvm/Support/Debug.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "ARM.h"
#include "ARMSubtarget.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");

static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

namespace {
  struct MulCandidate;
  class Reduction;

  using MulCandList = SmallVector<std::unique_ptr<MulCandidate>, 8>;
  using MemInstList = SmallVectorImpl<LoadInst*>;
  using MulPairList = SmallVector<std::pair<MulCandidate*, MulCandidate*>, 8>;

  // 'MulCandidate' holds the multiplication instructions that are candidates
  // for parallel execution.
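  // A typical candidate is a multiply of two sign-extended i16 loads,
  // schematically: mul (sext i16 %ld0), (sext i16 %ld1), with LHS and RHS
  // holding the narrow load values that feed the extensions.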
  struct MulCandidate {
    Instruction   *Root;
    Value         *LHS;
    Value         *RHS;
    bool          Exchange = false;
    bool          ReadOnly = true;
    bool          Paired = false;
    SmallVector<LoadInst*, 2> VecLd;    // Container for loads to widen.

    MulCandidate(Instruction *I, Value *lhs, Value *rhs) :
      Root(I), LHS(lhs), RHS(rhs) { }

    bool HasTwoLoadInputs() const {
      return isa<LoadInst>(LHS) && isa<LoadInst>(RHS);
    }

    LoadInst *getBaseLoad() const {
      return VecLd.front();
    }
  };

  /// Represent a sequence of multiply-accumulate operations with the aim to
  /// perform the multiplications in parallel.
  class Reduction {
    Instruction     *Root = nullptr;
    Value           *Acc = nullptr;
    MulCandList     Muls;
    MulPairList     MulPairs;
    SetVector<Instruction*> Adds;

  public:
    Reduction() = delete;

    Reduction(Instruction *Add) : Root(Add) { }

    /// Record an Add instruction that is part of this reduction.
    void InsertAdd(Instruction *I) { Adds.insert(I); }

    /// Create MulCandidates, each rooted at a Mul instruction, that are part
    /// of this reduction.
    void InsertMuls() {
      auto GetMulOperand = [](Value *V) -> Instruction* {
        if (auto *SExt = dyn_cast<SExtInst>(V)) {
          if (auto *I = dyn_cast<Instruction>(SExt->getOperand(0)))
            if (I->getOpcode() == Instruction::Mul)
              return I;
        } else if (auto *I = dyn_cast<Instruction>(V)) {
          if (I->getOpcode() == Instruction::Mul)
            return I;
        }
        return nullptr;
      };

      auto InsertMul = [this](Instruction *I) {
        Value *LHS = cast<Instruction>(I->getOperand(0))->getOperand(0);
        Value *RHS = cast<Instruction>(I->getOperand(1))->getOperand(0);
        Muls.push_back(std::make_unique<MulCandidate>(I, LHS, RHS));
      };

      for (auto *Add : Adds) {
        if (Add == Acc)
          continue;
        if (auto *Mul = GetMulOperand(Add->getOperand(0)))
          InsertMul(Mul);
        if (auto *Mul = GetMulOperand(Add->getOperand(1)))
          InsertMul(Mul);
      }
    }

    /// Record the incoming accumulator value; return true if a value had not
    /// already been added. Returning false signals that this reduction
    /// already has a value to initialise the accumulator.
    bool InsertAcc(Value *V) {
      if (Acc)
        return false;
      Acc = V;
      return true;
    }

    /// Record a pair of MulCandidates, rooted at muls, that can be executed
    /// as a single parallel operation.
    void AddMulPair(MulCandidate *Mul0, MulCandidate *Mul1,
                    bool Exchange = false) {
      LLVM_DEBUG(dbgs() << "Pairing:\n"
                 << *Mul0->Root << "\n"
                 << *Mul1->Root << "\n");
      Mul0->Paired = true;
      Mul1->Paired = true;
      if (Exchange)
        Mul1->Exchange = true;
      MulPairs.push_back(std::make_pair(Mul0, Mul1));
    }

    /// Return true if enough mul operations are found that can be executed in
    /// parallel.
    bool CreateParallelPairs();

    /// Return the add instruction which is the root of the reduction.
    Instruction *getRoot() { return Root; }

    bool is64Bit() const { return Root->getType()->isIntegerTy(64); }

    Type *getType() const { return Root->getType(); }

    /// Return the incoming value to be accumulated. This may be null.
    Value *getAccumulator() { return Acc; }

    /// Return the set of adds that comprise the reduction.
    SetVector<Instruction*> &getAdds() { return Adds; }

    /// Return the MulCandidates, each rooted at a mul instruction, that
    /// comprise the reduction.
    MulCandList &getMuls() { return Muls; }

    /// Return the MulCandidates, rooted at mul instructions, that have been
    /// paired for parallel execution.
    MulPairList &getMulPairs() { return MulPairs; }

    /// To finalise, replace the uses of the root with the intrinsic call.
    void UpdateRoot(Instruction *SMLAD) {
      Root->replaceAllUsesWith(SMLAD);
    }

    void dump() {
      LLVM_DEBUG(dbgs() << "Reduction:\n";
        for (auto *Add : Adds)
          dbgs() << *Add << "\n";
        for (auto &Mul : Muls)
          dbgs() << *Mul->Root << "\n"
                 << "  " << *Mul->LHS << "\n"
                 << "  " << *Mul->RHS << "\n";
        if (Acc)
          dbgs() << "Acc in: " << *Acc << "\n";
      );
    }
  };

  class WidenedLoad {
    LoadInst *NewLd = nullptr;
    SmallVector<LoadInst*, 4> Loads;

  public:
    WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
      : NewLd(Wide) {
      for (auto *I : Lds)
        Loads.push_back(I);
    }
    LoadInst *getLoad() {
      return NewLd;
    }
  };

  class ARMParallelDSP : public FunctionPass {
    ScalarEvolution   *SE;
    AliasAnalysis     *AA;
    TargetLibraryInfo *TLI;
    DominatorTree     *DT;
    const DataLayout  *DL;
    Module            *M;
    std::map<LoadInst*, LoadInst*> LoadPairs;
    SmallPtrSet<LoadInst*, 4> OffsetLoads;
    std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

    template<unsigned>
    bool IsNarrowSequence(Value *V);
    bool Search(Value *V, BasicBlock *BB, Reduction &R);
    bool RecordMemoryOps(BasicBlock *BB);
    void InsertParallelMACs(Reduction &Reduction);
    bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
    LoadInst *CreateWideLoad(MemInstList &Loads, IntegerType *LoadTy);
    bool CreateParallelPairs(Reduction &R);

    /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
    /// Dual. These perform two signed 16x16-bit multiplications and add the
    /// products to a 32-bit accumulate operand. Optionally, the instruction
    /// can exchange the halfwords of the second operand before performing the
    /// arithmetic.
    bool MatchSMLAD(Function &F);

  public:
    static char ID;

    ARMParallelDSP() : FunctionPass(ID) { }

    void getAnalysisUsage(AnalysisUsage &AU) const override {
      FunctionPass::getAnalysisUsage(AU);
      AU.addRequired<AssumptionCacheTracker>();
      AU.addRequired<ScalarEvolutionWrapperPass>();
      AU.addRequired<AAResultsWrapperPass>();
      AU.addRequired<TargetLibraryInfoWrapperPass>();
      AU.addRequired<DominatorTreeWrapperPass>();
      AU.addRequired<TargetPassConfig>();
      AU.addPreserved<ScalarEvolutionWrapperPass>();
      AU.addPreserved<GlobalsAAWrapperPass>();
      AU.setPreservesCFG();
    }

    bool runOnFunction(Function &F) override {
      if (DisableParallelDSP)
        return false;
      if (skipFunction(F))
        return false;

      SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
      AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
      TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
      DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
      auto &TPC = getAnalysis<TargetPassConfig>();

      M = F.getParent();
      DL = &M->getDataLayout();

      auto &TM = TPC.getTM<TargetMachine>();
      auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

      if (!ST->allowsUnalignedMem()) {
        LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      if (!ST->hasDSP()) {
        LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      if (!ST->isLittle()) {
        LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                          << "ARMParallelDSP\n");
        return false;
      }

      LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
      LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");

      return MatchSMLAD(F);
    }
  };
}

template<typename MemInst>
static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
                                  const DataLayout &DL, ScalarEvolution &SE) {
  return isConsecutiveAccess(MemOp0, MemOp1, DL, SE);
}

bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!Ld0 || !Ld1)
    return false;

  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;

  LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
    dbgs() << "Ld0:"; Ld0->dump();
    dbgs() << "Ld1:"; Ld1->dump();
  );

  VecMem.clear();
  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
  return true;
}

// MaxBitWidth: the maximum supported bitwidth of the elements in the DSP
// instructions, which is set to 16. So here we should collect all i8 and i16
// narrow operations.
// TODO: we currently only collect i16, and will support i8 later, so that's
// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
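// A matching narrow sequence looks like:
//   %ld   = load i16, i16* %addr
//   %sext = sext i16 %ld to i32
// where %ld has been recorded as a pairable load.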
template<unsigned MaxBitWidth>
bool ARMParallelDSP::IsNarrowSequence(Value *V) {
  if (auto *SExt = dyn_cast<SExtInst>(V)) {
    if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
      return false;

    if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) {
      // Check that this load could be paired.
      return LoadPairs.count(Ld) || OffsetLoads.count(Ld);
    }
  }
  return false;
}

/// Iterate through the block and record base, offset pairs of loads which can
/// be widened into a single load.
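/// For example, i16 loads of a[i] and a[i+1] form such a pair and can later
/// be replaced by a single i32 load of a[i] (see CreateWideLoad).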
bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
  SmallVector<LoadInst*, 8> Loads;
  SmallVector<Instruction*, 8> Writes;
  LoadPairs.clear();
  WideLoads.clear();

  // Collect loads and instructions that may write to memory. For now we only
  // record loads which are simple, sign-extended and have a single user.
  // TODO: Allow zero-extended loads.
  for (auto &I : *BB) {
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
    auto *Ld = dyn_cast<LoadInst>(&I);
    if (!Ld || !Ld->isSimple() ||
        !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
      continue;
    Loads.push_back(Ld);
  }

  using InstSet = std::set<Instruction*>;
  using DepMap = std::map<Instruction*, InstSet>;
  DepMap RAWDeps;

  // Record any writes that may alias a load.
  const auto Size = LocationSize::unknown();
  for (auto *Read : Loads) {
    for (auto *Write : Writes) {
      MemoryLocation ReadLoc =
        MemoryLocation(Read->getPointerOperand(), Size);

      if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
          ModRefInfo::ModRef)))
        continue;
      if (DT->dominates(Write, Read))
        RAWDeps[Read].insert(Write);
    }
  }

  // Check that there isn't a write between the two loads which would prevent
  // them from being safely merged.
  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
    LoadInst *Dominator = DT->dominates(Base, Offset) ? Base : Offset;
    LoadInst *Dominated = DT->dominates(Base, Offset) ? Offset : Base;

    if (RAWDeps.count(Dominated)) {
      InstSet &WritesBefore = RAWDeps[Dominated];

      for (auto *Before : WritesBefore) {
        // We can't move the second load backward, past a write, to merge
        // with the first load.
        if (DT->dominates(Dominator, Before))
          return false;
      }
    }
    return true;
  };

  // Record base, offset load pairs.
  for (auto *Base : Loads) {
    for (auto *Offset : Loads) {
      if (Base == Offset)
        continue;

      if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
          SafeToPair(Base, Offset)) {
        LoadPairs[Base] = Offset;
        OffsetLoads.insert(Offset);
        break;
      }
    }
  }

  LLVM_DEBUG(if (!LoadPairs.empty()) {
               dbgs() << "Consecutive load pairs:\n";
               for (auto &MapIt : LoadPairs)
                 dbgs() << *MapIt.first << ", " << *MapIt.second << "\n";
             });
  return LoadPairs.size() > 1;
}

// Search recursively back through the operands to find a tree of values that
// form a multiply-accumulate chain. The search records the Add and Mul
// instructions that form the reduction and allows us to find a single value
// to be used as the initial input to the accumulator.
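// For example, in a chain such as:
//   %add0 = add i32 %mul0, %acc
//   %add1 = add i32 %add0, %mul1
// the search follows both operands of each add, recording the adds, and
// stops at a non-instruction or a PHI, which becomes the accumulator.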
bool ARMParallelDSP::Search(Value *V, BasicBlock *BB, Reduction &R) {
  // If we find a non-instruction, try to use it as the initial accumulator
  // value. This may have already been found during the search, in which case
  // this function will return false, signalling that the search has failed.
  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return R.InsertAcc(V);

  if (I->getParent() != BB)
    return false;

  switch (I->getOpcode()) {
  default:
    break;
  case Instruction::PHI:
    // Could be the accumulator value.
    return R.InsertAcc(V);
  case Instruction::Add: {
    // To be within the mac chain, adds should be adding together two muls,
    // or another add and a mul. One of the operands may also be the
    // accumulator value, at which point we should stop searching.
    R.InsertAdd(I);
    Value *LHS = I->getOperand(0);
    Value *RHS = I->getOperand(1);
    bool ValidLHS = Search(LHS, BB, R);
    bool ValidRHS = Search(RHS, BB, R);

    if (ValidLHS && ValidRHS)
      return true;

    return R.InsertAcc(I);
  }
  case Instruction::Mul: {
    Value *MulOp0 = I->getOperand(0);
    Value *MulOp1 = I->getOperand(1);
    return IsNarrowSequence<16>(MulOp0) && IsNarrowSequence<16>(MulOp1);
  }
  case Instruction::SExt:
    return Search(I->getOperand(0), BB, R);
  }
  return false;
}

// The pass needs to identify integer add/sub reductions of 16-bit vector
// multiplications.
// To use SMLAD:
// 1) we first need to find an integer add, then look for this pattern:
//
// acc0 = ...
// ld0 = load i16
// sext0 = sext i16 %ld0 to i32
// ld1 = load i16
// sext1 = sext i16 %ld1 to i32
// mul0 = mul i32 %sext0, %sext1
// ld2 = load i16
// sext2 = sext i16 %ld2 to i32
// ld3 = load i16
// sext3 = sext i16 %ld3 to i32
// mul1 = mul i32 %sext2, %sext3
// add0 = add i32 %mul0, %acc0
// acc1 = add i32 %add0, %mul1
//
// Which can be selected to:
//
// ldr r0
// ldr r1
// smlad r2, r0, r1, r2
//
// If constants are used instead of loads, these will need to be hoisted
// out and into a register.
//
// If loop invariants are used instead of loads, these need to be packed
// before the loop begins.
//
bool ARMParallelDSP::MatchSMLAD(Function &F) {
  bool Changed = false;

  for (auto &BB : F) {
    SmallPtrSet<Instruction*, 4> AllAdds;
    if (!RecordMemoryOps(&BB))
      continue;

    for (Instruction &I : reverse(BB)) {
      if (I.getOpcode() != Instruction::Add)
        continue;

      if (AllAdds.count(&I))
        continue;

      const auto *Ty = I.getType();
      if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
        continue;

      Reduction R(&I);
      if (!Search(&I, &BB, R))
        continue;

      R.InsertMuls();
      LLVM_DEBUG(dbgs() << "After search, Reduction:\n"; R.dump());

      if (!CreateParallelPairs(R))
        continue;

      InsertParallelMACs(R);
      Changed = true;
      AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
    }
  }

  return Changed;
}

bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {
  // Not enough mul operations to make a pair.
  if (R.getMuls().size() < 2)
    return false;

  // Check that the muls operate directly upon sign-extended loads.
  for (auto &MulCand : R.getMuls()) {
    if (!MulCand->HasTwoLoadInputs())
      return false;
  }

  auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) {
    // The first elements of each vector should be loads with sexts. If we
    // find that they are two pairs of consecutive loads, then these can be
    // transformed into two wider loads and the users can be replaced with
    // DSP intrinsics.
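    // Ld0 and Ld2 feed PMul0, while Ld1 and Ld3 feed PMul1, so the muls can
    // be paired when (Ld0, Ld1) and (Ld2, Ld3) - or (Ld3, Ld2), using the
    // exchanging form of the instruction - are consecutive in memory.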
    auto *Ld0 = static_cast<LoadInst*>(PMul0->LHS);
    auto *Ld1 = static_cast<LoadInst*>(PMul1->LHS);
    auto *Ld2 = static_cast<LoadInst*>(PMul0->RHS);
    auto *Ld3 = static_cast<LoadInst*>(PMul1->RHS);

    if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
      if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        R.AddMulPair(PMul0, PMul1);
        return true;
      } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
        R.AddMulPair(PMul0, PMul1, true);
        return true;
      }
    } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
               AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
      LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
      LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
      LLVM_DEBUG(dbgs() << "    and swapping muls\n");
      // Only the second operand can be exchanged, so swap the muls.
      R.AddMulPair(PMul1, PMul0, true);
      return true;
    }
    return false;
  };

  MulCandList &Muls = R.getMuls();
  const unsigned Elems = Muls.size();
  for (unsigned i = 0; i < Elems; ++i) {
    MulCandidate *PMul0 = Muls[i].get();
    if (PMul0->Paired)
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      MulCandidate *PMul1 = Muls[j].get();
      if (PMul1->Paired)
        continue;

      const Instruction *Mul0 = PMul0->Root;
      const Instruction *Mul1 = PMul1->Root;
      if (Mul0 == Mul1)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(R, PMul0, PMul1))
        break;
    }
  }
  return !R.getMulPairs().empty();
}

void ARMParallelDSP::InsertParallelMACs(Reduction &R) {
  auto CreateSMLAD = [&](LoadInst *WideLd0, LoadInst *WideLd1,
                         Value *Acc, bool Exchange,
                         Instruction *InsertAfter) {
    // Replace the reduction chain with an intrinsic call.
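    // For the 32-bit, non-exchange case, this creates a call along the
    // lines of:
    //   %acc1 = call i32 @llvm.arm.smlad(i32 %wide0, i32 %wide1, i32 %acc0)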

    Value *Args[] = { WideLd0, WideLd1, Acc };
    Function *SMLAD = nullptr;
    if (Exchange)
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
    else
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);

    IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                                ++BasicBlock::iterator(InsertAfter));
    Instruction *Call = Builder.CreateCall(SMLAD, Args);
    NumSMLAD++;
    return Call;
  };

  Instruction *InsertAfter = R.getRoot();
  Value *Acc = R.getAccumulator();

  // For any muls that were discovered but not paired, accumulate their values
  // as before.
  IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                              ++BasicBlock::iterator(InsertAfter));
  MulCandList &MulCands = R.getMuls();
  for (auto &MulCand : MulCands) {
    if (MulCand->Paired)
      continue;

    Value *Mul = MulCand->Root;
    LLVM_DEBUG(dbgs() << "Accumulating unpaired mul: " << *Mul << "\n");

    if (R.getType() != Mul->getType()) {
      assert(R.is64Bit() && "expected 64-bit result");
      Mul = Builder.CreateSExt(Mul, R.getType());
    }

    if (!Acc) {
      Acc = Mul;
      continue;
    }

    Acc = Builder.CreateAdd(Mul, Acc);
    InsertAfter = cast<Instruction>(Acc);
  }

  if (!Acc) {
    Acc = R.is64Bit() ?
      ConstantInt::get(IntegerType::get(M->getContext(), 64), 0) :
      ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
  } else if (Acc->getType() != R.getType()) {
    Builder.SetInsertPoint(R.getRoot());
    Acc = Builder.CreateSExt(Acc, R.getType());
  }

  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
  for (auto &Pair : R.getMulPairs()) {
    MulCandidate *LHSMul = Pair.first;
    MulCandidate *RHSMul = Pair.second;
    LoadInst *BaseLHS = LHSMul->getBaseLoad();
    LoadInst *BaseRHS = RHSMul->getBaseLoad();
    LoadInst *WideLHS = WideLoads.count(BaseLHS) ?
      WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty);
    LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
      WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);

    Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
    InsertAfter = cast<Instruction>(Acc);
  }
  R.UpdateRoot(cast<Instruction>(Acc));
}

LoadInst *ARMParallelDSP::CreateWideLoad(MemInstList &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];

  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());

  assert((BaseSExt && OffsetSExt)
         && "Loads should have a single, extending, user");

  std::function<void(Value*, Value*)> MoveBefore =
    [&](Value *A, Value *B) -> void {
      if (!isa<Instruction>(A) || !isa<Instruction>(B))
        return;

      auto *Source = cast<Instruction>(A);
      auto *Sink = cast<Instruction>(B);

      if (DT->dominates(Source, Sink) ||
          Source->getParent() != Sink->getParent() ||
          isa<PHINode>(Source) || isa<PHINode>(Sink))
        return;

      Source->moveBefore(Sink);
      for (auto &Op : Source->operands())
        MoveBefore(Op, Source);
    };

  // Insert the load at the point of the original dominating load.
  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
                          ++BasicBlock::iterator(DomLoad));

  // Bitcast the pointer to a wider type and create the wide load, while
  // making sure to maintain the original alignment, as this prevents an ldrd
  // from being generated when it would be illegal due to memory alignment.
  const unsigned AddrSpace = DomLoad->getPointerAddressSpace();
  Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
                                    LoadTy->getPointerTo(AddrSpace));
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr,
                                             Base->getAlignment());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);

  // From the wide load, create two values that equal the original two loads.
  // Loads[0] needs a trunc while Loads[1] needs an lshr and a trunc.
  // TODO: Support big-endian as well.
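  // For a pair of i16 loads widened to an i32, this produces, schematically:
  //   %bottom = trunc i32 %wide to i16
  //   %top    = trunc i32 (lshr i32 %wide, 16) to i16
  // and the original sext users are redirected to sexts of these values.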
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  Value *NewBaseSExt = IRB.CreateSExt(Bottom, BaseSExt->getType());
  BaseSExt->replaceAllUsesWith(NewBaseSExt);

  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  Value *NewOffsetSExt = IRB.CreateSExt(Trunc, OffsetSExt->getType());
  OffsetSExt->replaceAllUsesWith(NewOffsetSExt);

  LLVM_DEBUG(dbgs() << "From Base and Offset:\n"
             << *Base << "\n" << *Offset << "\n"
             << "Created Wide Load:\n"
             << *WideLoad << "\n"
             << *Bottom << "\n"
             << *NewBaseSExt << "\n"
             << *Top << "\n"
             << *Trunc << "\n"
             << *NewOffsetSExt << "\n");
  WideLoads.emplace(Base, std::make_unique<WidenedLoad>(Loads, WideLoad));
  return WideLoad;
}

Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

char ARMParallelDSP::ID = 0;

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                      "Transform functions to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                    "Transform functions to use DSP intrinsics", false, false)