//===- ARMParallelDSP.cpp - Parallel DSP Pass -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled.
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMSubtarget.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/PassSupport.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <utility>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");

static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

namespace {
  struct OpChain;
  struct MulCandidate;
  class Reduction;

  using MulCandList     = SmallVector<std::unique_ptr<MulCandidate>, 8>;
  using ReductionList   = SmallVector<Reduction, 8>;
  using MemInstList     = SmallVector<LoadInst*, 8>;
  using PMACPair        = std::pair<MulCandidate*, MulCandidate*>;
  using PMACPairList    = SmallVector<PMACPair, 8>;

  // 'MulCandidate' holds the multiplication instructions that are candidates
  // for parallel execution.
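  // For example, given a multiply such as %mul = mul i32 %sext0, %sext1,
  // where each sext extends a narrow load, Root records %mul and LHS/RHS
  // record the underlying loads once the operands have been matched.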
  struct MulCandidate {
    Instruction   *Root;
    MemInstList   VecLd;    // Container for loads to widen.
    Value*        LHS;
    Value*        RHS;
    bool          Exchange = false;
    bool          ReadOnly = true;

    MulCandidate(Instruction *I, Value *lhs, Value *rhs) :
      Root(I), LHS(lhs), RHS(rhs) { }

    bool HasTwoLoadInputs() const {
      return isa<LoadInst>(LHS) && isa<LoadInst>(RHS);
    }

    LoadInst *getBaseLoad() const {
      return cast<LoadInst>(LHS);
    }
  };

  /// Represent a sequence of multiply-accumulate operations with the aim to
  /// perform the multiplications in parallel.
  class Reduction {
    Instruction     *Root = nullptr;
    Value           *Acc = nullptr;
    MulCandList     Muls;
    PMACPairList    MulPairs;
    SmallPtrSet<Instruction*, 4> Adds;

  public:
    Reduction() = delete;

    Reduction(Instruction *Add) : Root(Add) { }

    /// Record an Add instruction that is a part of this reduction.
    void InsertAdd(Instruction *I) { Adds.insert(I); }

    /// Record a MulCandidate, rooted at a Mul instruction, that is a part of
    /// this reduction.
    void InsertMul(Instruction *I, Value *LHS, Value *RHS) {
      Muls.push_back(make_unique<MulCandidate>(I, LHS, RHS));
    }

    /// Add the incoming accumulator value; returns true if a value had not
    /// already been added. Returning false signals to the user that this
    /// reduction already has a value to initialise the accumulator.
    bool InsertAcc(Value *V) {
      if (Acc)
        return false;
      Acc = V;
      return true;
    }

    /// Record two MulCandidates, rooted at muls, that can be executed as a
    /// single parallel operation.
    void AddMulPair(MulCandidate *Mul0, MulCandidate *Mul1) {
      MulPairs.push_back(std::make_pair(Mul0, Mul1));
    }

    /// Return true if enough mul operations are found that can be executed in
    /// parallel.
    bool CreateParallelPairs();

    /// Return the add instruction which is the root of the reduction.
    Instruction *getRoot() { return Root; }

    bool is64Bit() const { return Root->getType()->isIntegerTy(64); }

    /// Return the incoming value to be accumulated. This may be null.
    Value *getAccumulator() { return Acc; }

    /// Return the set of adds that comprise the reduction.
    SmallPtrSetImpl<Instruction*> &getAdds() { return Adds; }

    /// Return the MulCandidates, rooted at mul instructions, that comprise
    /// the reduction.
    MulCandList &getMuls() { return Muls; }

    /// Return the MulCandidates, rooted at mul instructions, that have been
    /// paired for parallel execution.
    PMACPairList &getMulPairs() { return MulPairs; }

    /// To finalise, replace the uses of the root with the intrinsic call.
    void UpdateRoot(Instruction *SMLAD) {
      Root->replaceAllUsesWith(SMLAD);
    }
  };
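
  // Bookkeeping for a pair of narrow, sequential loads that have been
  // replaced by a single wider load, recording which original loads the
  // new load subsumes.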
  class WidenedLoad {
    LoadInst *NewLd = nullptr;
    SmallVector<LoadInst*, 4> Loads;

  public:
    WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
      : NewLd(Wide) {
      for (auto *I : Lds)
        Loads.push_back(I);
    }
    LoadInst *getLoad() {
      return NewLd;
    }
  };

  class ARMParallelDSP : public FunctionPass {
    ScalarEvolution   *SE;
    AliasAnalysis     *AA;
    TargetLibraryInfo *TLI;
    DominatorTree     *DT;
    const DataLayout  *DL;
    Module            *M;
    std::map<LoadInst*, LoadInst*> LoadPairs;
    SmallPtrSet<LoadInst*, 4> OffsetLoads;
    std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

    template<unsigned>
    bool IsNarrowSequence(Value *V, Value *&Src);

    bool RecordMemoryOps(BasicBlock *BB);
    void InsertParallelMACs(Reduction &Reduction);
    bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
    LoadInst* CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
                             IntegerType *LoadTy);
    bool CreateParallelPairs(Reduction &R);

    /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
    /// Dual performs two signed 16x16-bit multiplications. It adds the
    /// products to a 32-bit accumulate operand. Optionally, the instruction
    /// can exchange the halfwords of the second operand before performing the
    /// arithmetic.
    bool MatchSMLAD(Function &F);

  public:
    static char ID;

    ARMParallelDSP() : FunctionPass(ID) { }

    void getAnalysisUsage(AnalysisUsage &AU) const override {
      FunctionPass::getAnalysisUsage(AU);
      AU.addRequired<AssumptionCacheTracker>();
      AU.addRequired<ScalarEvolutionWrapperPass>();
      AU.addRequired<AAResultsWrapperPass>();
      AU.addRequired<TargetLibraryInfoWrapperPass>();
      AU.addRequired<DominatorTreeWrapperPass>();
      AU.addRequired<TargetPassConfig>();
      AU.addPreserved<ScalarEvolutionWrapperPass>();
      AU.addPreserved<GlobalsAAWrapperPass>();
      AU.setPreservesCFG();
    }

    bool runOnFunction(Function &F) override {
      if (DisableParallelDSP)
        return false;
      if (skipFunction(F))
        return false;

      SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
      AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
      TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
      DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
      auto &TPC = getAnalysis<TargetPassConfig>();

      M = F.getParent();
      DL = &M->getDataLayout();

      auto &TM = TPC.getTM<TargetMachine>();
      auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

      if (!ST->allowsUnalignedMem()) {
        LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      if (!ST->hasDSP()) {
        LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      if (!ST->isLittle()) {
        LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
      LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");

      return MatchSMLAD(F);
    }
  };
}

template<typename MemInst>
static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
                                  const DataLayout &DL, ScalarEvolution &SE) {
  return isConsecutiveAccess(MemOp0, MemOp1, DL, SE);
}

bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!Ld0 || !Ld1)
    return false;

  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;

  LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
    dbgs() << "Ld0:"; Ld0->dump();
    dbgs() << "Ld1:"; Ld1->dump();
  );

  VecMem.clear();
  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
  return true;
}

// MaxBitWidth: the maximum supported bitwidth of the elements in the DSP
// instructions, which is set to 16. So here we should collect all i8 and i16
// narrow operations.
// TODO: we currently only collect i16, and will support i8 later, so that's
// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
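// For example, this matches the sequence:
//   %ld = load i16, i16* %addr
//   %sx = sext i16 %ld to i32
// and returns the narrow load in Src.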
template<unsigned MaxBitWidth>
bool ARMParallelDSP::IsNarrowSequence(Value *V, Value *&Src) {
  if (auto *SExt = dyn_cast<SExtInst>(V)) {
    if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
      return false;

    if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) {
      // Check that this load could be paired.
      if (!LoadPairs.count(Ld) && !OffsetLoads.count(Ld))
        return false;

      Src = Ld;
      return true;
    }
  }
  return false;
}

/// Iterate through the block and record base, offset pairs of loads which can
/// be widened into a single load.
bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
  SmallVector<LoadInst*, 8> Loads;
  SmallVector<Instruction*, 8> Writes;
  LoadPairs.clear();
  WideLoads.clear();

  // Collect loads and instructions that may write to memory. For now we only
  // record loads which are simple, sign-extended and have a single user.
  // TODO: Allow zero-extended loads.
  for (auto &I : *BB) {
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
    auto *Ld = dyn_cast<LoadInst>(&I);
    if (!Ld || !Ld->isSimple() ||
        !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
      continue;
    Loads.push_back(Ld);
  }

  using InstSet = std::set<Instruction*>;
  using DepMap = std::map<Instruction*, InstSet>;
  DepMap RAWDeps;

  // Record any writes that may alias a load.
  const auto Size = LocationSize::unknown();
  for (auto Read : Loads) {
    for (auto Write : Writes) {
      MemoryLocation ReadLoc =
        MemoryLocation(Read->getPointerOperand(), Size);

      if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
          ModRefInfo::ModRef)))
        continue;
      if (DT->dominates(Write, Read))
        RAWDeps[Read].insert(Write);
    }
  }

  // Check that there isn't a write between the two loads which would prevent
  // them from being safely merged.
  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
    LoadInst *Dominator = DT->dominates(Base, Offset) ? Base : Offset;
    LoadInst *Dominated = DT->dominates(Base, Offset) ? Offset : Base;

    if (RAWDeps.count(Dominated)) {
      InstSet &WritesBefore = RAWDeps[Dominated];

      for (auto Before : WritesBefore) {
        // We can't move the second load backward, past a write, to merge
        // with the first load.
        if (DT->dominates(Dominator, Before))
          return false;
      }
    }
    return true;
  };

  // Record base, offset load pairs.
  for (auto *Base : Loads) {
    for (auto *Offset : Loads) {
      if (Base == Offset)
        continue;

      if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
          SafeToPair(Base, Offset)) {
        LoadPairs[Base] = Offset;
        OffsetLoads.insert(Offset);
        break;
      }
    }
  }

  LLVM_DEBUG(if (!LoadPairs.empty()) {
               dbgs() << "Consecutive load pairs:\n";
               for (auto &MapIt : LoadPairs)
                 dbgs() << *MapIt.first << ", " << *MapIt.second << "\n";
             });
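  // Creating a single smlad requires two pairs of sequential loads, so bail
  // out unless at least two pairs were found.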
  return LoadPairs.size() > 1;
}

// The pass needs to identify integer add/sub reductions of 16-bit vector
// multiplications.
// To use SMLAD we first need to find an integer add and then look for this
// pattern:
//
// acc0 = ...
// ld0 = load i16
// sext0 = sext i16 %ld0 to i32
// ld1 = load i16
// sext1 = sext i16 %ld1 to i32
// mul0 = mul i32 %sext0, %sext1
// ld2 = load i16
// sext2 = sext i16 %ld2 to i32
// ld3 = load i16
// sext3 = sext i16 %ld3 to i32
// mul1 = mul i32 %sext2, %sext3
// add0 = add i32 %mul0, %acc0
// acc1 = add i32 %add0, %mul1
//
// Which can be selected to:
//
// ldr r0
// ldr r1
// smlad r2, r0, r1, r2
//
// If constants are used instead of loads, these will need to be hoisted
// out and into a register.
//
// If loop invariants are used instead of loads, these need to be packed
// before the loop begins.
//
bool ARMParallelDSP::MatchSMLAD(Function &F) {
  // Search recursively back through the operands to find a tree of values that
  // form a multiply-accumulate chain. The search records the Add and Mul
  // instructions that form the reduction and allows us to find a single value
  // to be used as the initial input to the accumulator.
  std::function<bool(Value*, BasicBlock*, Reduction&)> Search = [&]
    (Value *V, BasicBlock *BB, Reduction &R) -> bool {

    // If we find a non-instruction, try to use it as the initial accumulator
    // value. This may have already been found during the search in which case
    // this function will return false, signaling a search fail.
    auto *I = dyn_cast<Instruction>(V);
    if (!I)
      return R.InsertAcc(V);

    if (I->getParent() != BB)
      return false;

    switch (I->getOpcode()) {
    default:
      break;
    case Instruction::PHI:
      // Could be the accumulator value.
      return R.InsertAcc(V);
    case Instruction::Add: {
      // Adds should be adding together two muls, or another add and a mul to
      // be within the mac chain. One of the operands may also be the
      // accumulator value at which point we should stop searching.
      bool ValidLHS = Search(I->getOperand(0), BB, R);
      bool ValidRHS = Search(I->getOperand(1), BB, R);
      if (!ValidLHS && !ValidRHS)
        return false;
      else if (ValidLHS && ValidRHS) {
        R.InsertAdd(I);
        return true;
      } else {
        R.InsertAdd(I);
        return R.InsertAcc(I);
      }
    }
    case Instruction::Mul: {
      Value *MulOp0 = I->getOperand(0);
      Value *MulOp1 = I->getOperand(1);
      if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1)) {
        Value *LHS = nullptr;
        Value *RHS = nullptr;
        if (IsNarrowSequence<16>(MulOp0, LHS) &&
            IsNarrowSequence<16>(MulOp1, RHS)) {
          R.InsertMul(I, LHS, RHS);
          return true;
        }
      }
      return false;
    }
    case Instruction::SExt:
      return Search(I->getOperand(0), BB, R);
    }
    return false;
  };

  bool Changed = false;

  for (auto &BB : F) {
    SmallPtrSet<Instruction*, 4> AllAdds;
    if (!RecordMemoryOps(&BB))
      continue;

    for (Instruction &I : reverse(BB)) {
      if (I.getOpcode() != Instruction::Add)
        continue;

      if (AllAdds.count(&I))
        continue;

      const auto *Ty = I.getType();
      if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
        continue;

      Reduction R(&I);
      if (!Search(&I, &BB, R))
        continue;

      if (!CreateParallelPairs(R))
        continue;

      InsertParallelMACs(R);
      Changed = true;
      AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
    }
  }

  return Changed;
}
bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {
  // Not enough mul operations to make a pair.
  if (R.getMuls().size() < 2)
    return false;

  // Check that the muls operate directly upon sign extended loads.
  for (auto &MulCand : R.getMuls()) {
    if (!MulCand->HasTwoLoadInputs())
      return false;
  }

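  // Decide whether two mul candidates can be paired: their four narrow loads
  // must form two sequential pairs, possibly after exchanging the operands of
  // the second mul, which is what the smladx/smlaldx forms implement.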
  auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) {
    // The first elements of each vector should be loads with sexts. If we
    // find that these are two pairs of consecutive loads, then they can be
    // transformed into two wider loads and the users can be replaced with
    // DSP intrinsics.
    auto *Ld0 = cast<LoadInst>(PMul0->LHS);
    auto *Ld1 = cast<LoadInst>(PMul1->LHS);
    auto *Ld2 = cast<LoadInst>(PMul0->RHS);
    auto *Ld3 = cast<LoadInst>(PMul1->RHS);

    LLVM_DEBUG(dbgs() << "Loads:\n"
               << " - " << *Ld0 << "\n"
               << " - " << *Ld1 << "\n"
               << " - " << *Ld2 << "\n"
               << " - " << *Ld3 << "\n");

    if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
      if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        R.AddMulPair(PMul0, PMul1);
        return true;
      } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
        PMul1->Exchange = true;
        R.AddMulPair(PMul0, PMul1);
        return true;
      }
    } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
               AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
      LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
      LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
      LLVM_DEBUG(dbgs() << "    and swapping muls\n");
      PMul0->Exchange = true;
      // Only the second operand can be exchanged, so swap the muls.
      R.AddMulPair(PMul1, PMul0);
      return true;
    }
    return false;
  };

  MulCandList &Muls = R.getMuls();
  const unsigned Elems = Muls.size();
  SmallPtrSet<const Instruction*, 4> Paired;
  for (unsigned i = 0; i < Elems; ++i) {
    MulCandidate *PMul0 = Muls[i].get();
    if (Paired.count(PMul0->Root))
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      MulCandidate *PMul1 = Muls[j].get();
      if (Paired.count(PMul1->Root))
        continue;

      const Instruction *Mul0 = PMul0->Root;
      const Instruction *Mul1 = PMul1->Root;
      if (Mul0 == Mul1)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(R, PMul0, PMul1)) {
        Paired.insert(Mul0);
        Paired.insert(Mul1);
        break;
      }
    }
  }
  return !R.getMulPairs().empty();
}

void ARMParallelDSP::InsertParallelMACs(Reduction &R) {
  auto CreateSMLAD = [&](LoadInst *WideLd0, LoadInst *WideLd1,
                         Value *Acc, bool Exchange,
                         Instruction *InsertAfter) {
    // Replace the reduction chain with an intrinsic call.
    Value *Args[] = { WideLd0, WideLd1, Acc };
    Function *SMLAD = nullptr;
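    // Select the 64-bit accumulating forms (smlald/smlaldx) for i64
    // reductions and the 32-bit forms (smlad/smladx) otherwise; the 'x'
    // variants exchange the halfwords of the second operand.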
    if (Exchange)
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
    else
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);

    IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                                ++BasicBlock::iterator(InsertAfter));
    Instruction *Call = Builder.CreateCall(SMLAD, Args);
    NumSMLAD++;
    return Call;
  };

  Instruction *InsertAfter = R.getRoot();
  Value *Acc = R.getAccumulator();
  if (!Acc)
    Acc = ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);

  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
  LLVM_DEBUG(dbgs() << "Root: " << *InsertAfter << "\n"
             << "Acc: " << *Acc << "\n");
  for (auto &Pair : R.getMulPairs()) {
    MulCandidate *LHSMul = Pair.first;
    MulCandidate *RHSMul = Pair.second;
    LLVM_DEBUG(dbgs() << "Muls:\n"
               << "- " << *LHSMul->Root << "\n"
               << "- " << *RHSMul->Root << "\n");
    LoadInst *BaseLHS = LHSMul->getBaseLoad();
    LoadInst *BaseRHS = RHSMul->getBaseLoad();
    LoadInst *WideLHS = WideLoads.count(BaseLHS) ?
      WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty);
    LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
      WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);

    Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
    InsertAfter = cast<Instruction>(Acc);
  }
  R.UpdateRoot(cast<Instruction>(Acc));
}
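
// Widen two sequential, sign-extended narrow loads into a single load of
// twice the width. For i16 elements on little endian, the transformation is
// roughly:
//
//   %ld0 = load i16, i16* %addr          %wide = load i32, i32* %addr
//   %ld1 = load i16, i16* %addr.1   =>   %low  = trunc i32 %wide to i16
//                                        %top  = lshr i32 %wide, 16
//                                        %high = trunc i32 %top to i16
//
// The sext users of the original loads are then rewired to the new values.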
LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];

  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());

  assert((BaseSExt && OffsetSExt)
         && "Loads should have a single, extending, user");
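
  // Move an instruction, and recursively its operands, to just before its
  // new user. Instructions that already dominate the sink, live in another
  // block, or are PHIs are left where they are.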
  std::function<void(Value*, Value*)> MoveBefore =
    [&](Value *A, Value *B) -> void {
      if (!isa<Instruction>(A) || !isa<Instruction>(B))
        return;

      auto *Source = cast<Instruction>(A);
      auto *Sink = cast<Instruction>(B);

      if (DT->dominates(Source, Sink) ||
          Source->getParent() != Sink->getParent() ||
          isa<PHINode>(Source) || isa<PHINode>(Sink))
        return;

      Source->moveBefore(Sink);
      for (auto &Op : Source->operands())
        MoveBefore(Op, Source);
    };

  // Insert the load at the point of the original dominating load.
  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
                          ++BasicBlock::iterator(DomLoad));

  // Bitcast the pointer to a wider type and create the wide load, while
  // making sure to maintain the original alignment as this prevents ldrd
  // from being generated when it could be illegal due to memory alignment.
  const unsigned AddrSpace = DomLoad->getPointerAddressSpace();
  Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
                                    LoadTy->getPointerTo(AddrSpace));
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr,
                                             Base->getAlignment());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);

  // From the wide load, create two values that equal the original two loads.
  // Loads[0] needs a trunc while Loads[1] needs a lshr and trunc.
  // TODO: Support big-endian as well.
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  BaseSExt->setOperand(0, Bottom);

  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  OffsetSExt->setOperand(0, Trunc);

  WideLoads.emplace(Base, make_unique<WidenedLoad>(Loads, WideLoad));
  return WideLoad;
}

Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

char ARMParallelDSP::ID = 0;

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                      "Transform functions to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                    "Transform functions to use DSP intrinsics", false, false)