//===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/PassSupport.h"
#include "llvm/Support/Debug.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "ARM.h"
#include "ARMSubtarget.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");

static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

namespace {
  struct MulCandidate;
  class Reduction;

  using MulCandList = SmallVector<std::unique_ptr<MulCandidate>, 8>;
  using MemInstList = SmallVectorImpl<LoadInst*>;
  using MulPairList = SmallVector<std::pair<MulCandidate*, MulCandidate*>, 8>;

  // 'MulCandidate' holds a multiplication instruction that is a candidate
  // for parallel execution.
  struct MulCandidate {
    Instruction   *Root;
    Value*        LHS;
    Value*        RHS;
    bool          Exchange = false;
    bool          ReadOnly = true;
    SmallVector<LoadInst*, 2> VecLd;    // Container for loads to widen.

    MulCandidate(Instruction *I, Value *lhs, Value *rhs) :
      Root(I), LHS(lhs), RHS(rhs) { }

    bool HasTwoLoadInputs() const {
      return isa<LoadInst>(LHS) && isa<LoadInst>(RHS);
    }

    LoadInst *getBaseLoad() const {
      return cast<LoadInst>(LHS);
    }
  };

  /// Represent a sequence of multiply-accumulate operations with the aim to
  /// perform the multiplications in parallel.
  class Reduction {
    Instruction     *Root = nullptr;
    Value           *Acc = nullptr;
    MulCandList     Muls;
    MulPairList     MulPairs;
    SmallPtrSet<Instruction*, 4> Adds;

  public:
    Reduction() = delete;

    Reduction(Instruction *Add) : Root(Add) { }

    /// Record an Add instruction that is part of this reduction.
    void InsertAdd(Instruction *I) { Adds.insert(I); }

    /// Record a MulCandidate, rooted at a Mul instruction, that is part of
    /// this reduction.
    void InsertMul(Instruction *I, Value *LHS, Value *RHS) {
      Muls.push_back(std::make_unique<MulCandidate>(I, LHS, RHS));
    }

    /// Add the incoming accumulator value; return true if a value had not
    /// already been added. Returning false signals to the user that this
    /// reduction already has a value to initialise the accumulator.
    bool InsertAcc(Value *V) {
      if (Acc)
        return false;
      Acc = V;
      return true;
    }

    /// Record two MulCandidates, rooted at muls, that can be executed as a
    /// single parallel operation.
    void AddMulPair(MulCandidate *Mul0, MulCandidate *Mul1) {
      MulPairs.push_back(std::make_pair(Mul0, Mul1));
    }

    /// Return true if enough mul operations are found that can be executed in
    /// parallel.
    bool CreateParallelPairs();

    /// Return the add instruction which is the root of the reduction.
    Instruction *getRoot() { return Root; }

    bool is64Bit() const { return Root->getType()->isIntegerTy(64); }

    /// Return the incoming value to be accumulated. This may be null.
    Value *getAccumulator() { return Acc; }

    /// Return the set of adds that comprise the reduction.
    SmallPtrSetImpl<Instruction*> &getAdds() { return Adds; }

    /// Return the MulCandidates, rooted at mul instructions, that comprise
    /// the reduction.
    MulCandList &getMuls() { return Muls; }

    /// Return the MulCandidates, rooted at mul instructions, that have been
    /// paired for parallel execution.
    MulPairList &getMulPairs() { return MulPairs; }

    /// To finalise, replace the uses of the root with the intrinsic call.
    void UpdateRoot(Instruction *SMLAD) {
      Root->replaceAllUsesWith(SMLAD);
    }
  };

  class WidenedLoad {
    LoadInst *NewLd = nullptr;
    SmallVector<LoadInst*, 4> Loads;

  public:
    WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
      : NewLd(Wide) {
      for (auto *I : Lds)
        Loads.push_back(I);
    }
    LoadInst *getLoad() {
      return NewLd;
    }
  };

  class ARMParallelDSP : public LoopPass {
    ScalarEvolution   *SE;
    AliasAnalysis     *AA;
    TargetLibraryInfo *TLI;
    DominatorTree     *DT;
    LoopInfo          *LI;
    Loop              *L;
    const DataLayout  *DL;
    Module            *M;
    std::map<LoadInst*, LoadInst*> LoadPairs;
    SmallPtrSet<LoadInst*, 4> OffsetLoads;
    std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

    template<unsigned>
    bool IsNarrowSequence(Value *V, Value *&Src);

    bool RecordMemoryOps(BasicBlock *BB);
    void InsertParallelMACs(Reduction &Reduction);
    bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
    LoadInst* CreateWideLoad(MemInstList &Loads, IntegerType *LoadTy);
    bool CreateParallelPairs(Reduction &R);

    /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
    /// Dual performs two signed 16x16-bit multiplications. It adds the
    /// products to a 32-bit accumulate operand. Optionally, the instruction can
    /// exchange the halfwords of the second operand before performing the
    /// arithmetic.
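    ///
    /// For example, writing the halfword lanes of a register as [hi:lo]:
    ///   smlad  Rd, Rn, Rm, Ra : Rd = Ra + Rn[lo]*Rm[lo] + Rn[hi]*Rm[hi]
    ///   smladx Rd, Rn, Rm, Ra : Rd = Ra + Rn[lo]*Rm[hi] + Rn[hi]*Rm[lo]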
    bool MatchSMLAD(Loop *L);

  public:
    static char ID;

    ARMParallelDSP() : LoopPass(ID) { }

    bool doInitialization(Loop *L, LPPassManager &LPM) override {
      LoadPairs.clear();
      WideLoads.clear();
      return true;
    }

    void getAnalysisUsage(AnalysisUsage &AU) const override {
      LoopPass::getAnalysisUsage(AU);
      AU.addRequired<AssumptionCacheTracker>();
      AU.addRequired<ScalarEvolutionWrapperPass>();
      AU.addRequired<AAResultsWrapperPass>();
      AU.addRequired<TargetLibraryInfoWrapperPass>();
      AU.addRequired<LoopInfoWrapperPass>();
      AU.addRequired<DominatorTreeWrapperPass>();
      AU.addRequired<TargetPassConfig>();
      AU.addPreserved<LoopInfoWrapperPass>();
      AU.setPreservesCFG();
    }

    bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
      if (DisableParallelDSP)
        return false;
      if (skipLoop(TheLoop))
        return false;

      L = TheLoop;
      SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
      AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
      TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
      DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
      LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
      auto &TPC = getAnalysis<TargetPassConfig>();

      BasicBlock *Header = TheLoop->getHeader();
      if (!Header)
        return false;

      // TODO: We assume the loop header and latch to be the same block.
      // This is not a fundamental restriction, but lifting this would just
      // require more work to do the transformation and then patch up the CFG.
      if (Header != TheLoop->getLoopLatch()) {
        LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      if (!TheLoop->getLoopPreheader())
        InsertPreheaderForLoop(L, DT, LI, nullptr, true);

      Function &F = *Header->getParent();
      M = F.getParent();
      DL = &M->getDataLayout();

      auto &TM = TPC.getTM<TargetMachine>();
      auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

      if (!ST->allowsUnalignedMem()) {
        LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      if (!ST->hasDSP()) {
        LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      if (!ST->isLittle()) {
        LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                          << "ARMParallelDSP\n");
        return false;
      }

      LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);

      LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
      LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");

      if (!RecordMemoryOps(Header)) {
        LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
        return false;
      }

      bool Changes = MatchSMLAD(L);
      return Changes;
    }
  };
}

template<typename MemInst>
static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
                                  const DataLayout &DL, ScalarEvolution &SE) {
  if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE))
    return true;
  return false;
}

bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!Ld0 || !Ld1)
    return false;

  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;

  LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
    dbgs() << "Ld0:"; Ld0->dump();
    dbgs() << "Ld1:"; Ld1->dump();
  );

  VecMem.clear();
  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
  return true;
}

// MaxBitWidth: the maximum supported bitwidth of the elements in the DSP
// instructions, which is set to 16. So here we should collect all i8 and i16
// narrow operations.
// TODO: we currently only collect i16, and will support i8 later, so that's
// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
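//
// For example (illustrative IR), given:
//   %ld   = load i16, i16* %addr
//   %sext = sext i16 %ld to i32
// IsNarrowSequence<16>(%sext, Src) succeeds and sets Src to %ld, provided %ld
// has been recorded as part of a consecutive load pair.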
template<unsigned MaxBitWidth>
bool ARMParallelDSP::IsNarrowSequence(Value *V, Value *&Src) {
  if (auto *SExt = dyn_cast<SExtInst>(V)) {
    if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
      return false;

    if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) {
      // Check that this load could be paired.
      if (!LoadPairs.count(Ld) && !OffsetLoads.count(Ld))
        return false;

      Src = Ld;
      return true;
    }
  }
  return false;
}

/// Iterate through the block and record base, offset pairs of loads which can
/// be widened into a single load.
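///
/// For example (illustrative IR), two adjacent i16 loads:
///   %ld0   = load i16, i16* %addr
///   %addr1 = getelementptr i16, i16* %addr, i32 1
///   %ld1   = load i16, i16* %addr1
/// are recorded as the pair (%ld0 -> %ld1), with %ld0 as the base load and
/// %ld1 as the offset load, provided both are simple, only used by sign
/// extends, and no aliasing write prevents them from being merged.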
bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
  SmallVector<LoadInst*, 8> Loads;
  SmallVector<Instruction*, 8> Writes;

  // Collect loads and instructions that may write to memory. For now we only
  // record loads which are simple, sign-extended and have a single user.
  // TODO: Allow zero-extended loads.
  for (auto &I : *BB) {
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
    auto *Ld = dyn_cast<LoadInst>(&I);
    if (!Ld || !Ld->isSimple() ||
        !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
      continue;
    Loads.push_back(Ld);
  }

  using InstSet = std::set<Instruction*>;
  using DepMap = std::map<Instruction*, InstSet>;
  DepMap RAWDeps;

  // Record any writes that may alias a load.
  const auto Size = LocationSize::unknown();
  for (auto Read : Loads) {
    for (auto Write : Writes) {
      MemoryLocation ReadLoc =
        MemoryLocation(Read->getPointerOperand(), Size);

      if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
          ModRefInfo::ModRef)))
        continue;
      if (DT->dominates(Write, Read))
        RAWDeps[Read].insert(Write);
    }
  }

  // Check that there isn't a write between the two loads which would prevent
  // them from being safely merged.
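  //
  // For example (illustrative), pairing is rejected here:
  //   %ld0 = load i16, i16* %a      ; base load
  //   store i16 %x, i16* %p         ; may alias the second load
  //   %ld1 = load i16, i16* %a1     ; offset load
  // because widening would effectively hoist %ld1 above the store.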
  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
    LoadInst *Dominator = DT->dominates(Base, Offset) ? Base : Offset;
    LoadInst *Dominated = DT->dominates(Base, Offset) ? Offset : Base;

    if (RAWDeps.count(Dominated)) {
      InstSet &WritesBefore = RAWDeps[Dominated];

      for (auto Before : WritesBefore) {
        // We can't move the second load backward, past a write, to merge
        // with the first load.
        if (DT->dominates(Dominator, Before))
          return false;
      }
    }
    return true;
  };

  // Record base, offset load pairs.
  for (auto *Base : Loads) {
    for (auto *Offset : Loads) {
      if (Base == Offset)
        continue;

      if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
          SafeToPair(Base, Offset)) {
        LoadPairs[Base] = Offset;
        OffsetLoads.insert(Offset);
        break;
      }
    }
  }

  LLVM_DEBUG(if (!LoadPairs.empty()) {
               dbgs() << "Consecutive load pairs:\n";
               for (auto &MapIt : LoadPairs) {
                 dbgs() << *MapIt.first << ", " << *MapIt.second << "\n";
               }
             });
  return LoadPairs.size() > 1;
}

// Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
// multiplications.
// To use SMLAD:
// 1) we first need to find an integer add, then look for this pattern:
//
// acc0 = ...
// ld0 = load i16
// sext0 = sext i16 %ld0 to i32
// ld1 = load i16
// sext1 = sext i16 %ld1 to i32
// mul0 = mul i32 %sext0, %sext1
// ld2 = load i16
// sext2 = sext i16 %ld2 to i32
// ld3 = load i16
// sext3 = sext i16 %ld3 to i32
// mul1 = mul i32 %sext2, %sext3
// add0 = add i32 %mul0, %acc0
// acc1 = add i32 %add0, %mul1
//
// Which can be selected to:
//
// ldr r0
// ldr r1
// smlad r2, r0, r1, r2
//
// If constants are used instead of loads, these will need to be hoisted
// out and into a register.
//
// If loop invariants are used instead of loads, these need to be packed
// before the loop begins.
//
bool ARMParallelDSP::MatchSMLAD(Loop *L) {
  // Search recursively back through the operands to find a tree of values that
  // form a multiply-accumulate chain. The search records the Add and Mul
  // instructions that form the reduction and allows us to find a single value
  // to be used as the initial input to the accumulator.
  std::function<bool(Value*, Reduction&)> Search = [&]
    (Value *V, Reduction &R) -> bool {

    // If we find a non-instruction, try to use it as the initial accumulator
    // value. This may have already been found during the search in which case
    // this function will return false, signaling a failed search.
    auto *I = dyn_cast<Instruction>(V);
    if (!I)
      return R.InsertAcc(V);

    switch (I->getOpcode()) {
    default:
      break;
    case Instruction::PHI:
      // Could be the accumulator value.
      return R.InsertAcc(V);
    case Instruction::Add: {
      // Adds should be adding together two muls, or another add and a mul to
      // be within the mac chain. One of the operands may also be the
      // accumulator value at which point we should stop searching.
      bool ValidLHS = Search(I->getOperand(0), R);
      bool ValidRHS = Search(I->getOperand(1), R);
      if (!ValidLHS && !ValidRHS)
        return false;
      else if (ValidLHS && ValidRHS) {
        R.InsertAdd(I);
        return true;
      } else {
        R.InsertAdd(I);
        return R.InsertAcc(I);
      }
    }
    case Instruction::Mul: {
      Value *MulOp0 = I->getOperand(0);
      Value *MulOp1 = I->getOperand(1);
      if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1)) {
        Value *LHS = nullptr;
        Value *RHS = nullptr;
        if (IsNarrowSequence<16>(MulOp0, LHS) &&
            IsNarrowSequence<16>(MulOp1, RHS)) {
          R.InsertMul(I, LHS, RHS);
          return true;
        }
      }
      return false;
    }
    case Instruction::SExt:
      return Search(I->getOperand(0), R);
    }
    return false;
  };

  bool Changed = false;
  SmallPtrSet<Instruction*, 4> AllAdds;
  BasicBlock *Latch = L->getLoopLatch();

  for (Instruction &I : reverse(*Latch)) {
    if (I.getOpcode() != Instruction::Add)
      continue;

    if (AllAdds.count(&I))
      continue;

    const auto *Ty = I.getType();
    if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
      continue;

    Reduction R(&I);
    if (!Search(&I, R))
      continue;

    if (!CreateParallelPairs(R))
      continue;

    InsertParallelMACs(R);
    Changed = true;
    AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
  }

  return Changed;
}

bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {

  // Not enough mul operations to make a pair.
  if (R.getMuls().size() < 2)
    return false;

  // Check that the muls operate directly upon sign extended loads.
  for (auto &MulCand : R.getMuls()) {
    if (!MulCand->HasTwoLoadInputs())
      return false;
  }

  auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) {
    // The first elements of each vector should be loads with sexts. If we
    // find that these are two pairs of consecutive loads, then the loads can
    // be transformed into two wider loads and their users can be replaced
    // with DSP intrinsics.
    auto Ld0 = static_cast<LoadInst*>(PMul0->LHS);
    auto Ld1 = static_cast<LoadInst*>(PMul1->LHS);
    auto Ld2 = static_cast<LoadInst*>(PMul0->RHS);
    auto Ld3 = static_cast<LoadInst*>(PMul1->RHS);

    LLVM_DEBUG(dbgs() << "Loads:\n"
               << " - " << *Ld0 << "\n"
               << " - " << *Ld1 << "\n"
               << " - " << *Ld2 << "\n"
               << " - " << *Ld3 << "\n");

    if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
      if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        R.AddMulPair(PMul0, PMul1);
        return true;
      } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
        PMul1->Exchange = true;
        R.AddMulPair(PMul0, PMul1);
        return true;
      }
    } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
               AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
      LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
      LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
      LLVM_DEBUG(dbgs() << "    and swapping muls\n");
      PMul0->Exchange = true;
      // Only the second operand can be exchanged, so swap the muls.
      R.AddMulPair(PMul1, PMul0);
      return true;
    }
    return false;
  };

  MulCandList &Muls = R.getMuls();
  const unsigned Elems = Muls.size();
  SmallPtrSet<const Instruction*, 4> Paired;
  for (unsigned i = 0; i < Elems; ++i) {
    MulCandidate *PMul0 = static_cast<MulCandidate*>(Muls[i].get());
    if (Paired.count(PMul0->Root))
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      MulCandidate *PMul1 = static_cast<MulCandidate*>(Muls[j].get());
      if (Paired.count(PMul1->Root))
        continue;

      const Instruction *Mul0 = PMul0->Root;
      const Instruction *Mul1 = PMul1->Root;
      if (Mul0 == Mul1)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(R, PMul0, PMul1)) {
        Paired.insert(Mul0);
        Paired.insert(Mul1);
        break;
      }
    }
  }
  return !R.getMulPairs().empty();
}

void ARMParallelDSP::InsertParallelMACs(Reduction &R) {

  auto CreateSMLAD = [&](LoadInst* WideLd0, LoadInst *WideLd1,
                         Value *Acc, bool Exchange,
                         Instruction *InsertAfter) {
    // Replace the reduction chain with an intrinsic call.
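    //
    // For a 32-bit accumulator this produces, illustratively:
    //   %res = call i32 @llvm.arm.smlad(i32 %WideLd0, i32 %WideLd1, i32 %Acc)
    // with the smladx, smlald and smlaldx variants selected below as needed.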

    Value* Args[] = { WideLd0, WideLd1, Acc };
    Function *SMLAD = nullptr;
    if (Exchange)
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
    else
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);

    IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                                ++BasicBlock::iterator(InsertAfter));
    Instruction *Call = Builder.CreateCall(SMLAD, Args);
    NumSMLAD++;
    return Call;
  };

  Instruction *InsertAfter = R.getRoot();
  Value *Acc = R.getAccumulator();
  if (!Acc)
    Acc = ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);

  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
  LLVM_DEBUG(dbgs() << "Root: " << *InsertAfter << "\n"
             << "Acc: " << *Acc << "\n");
  for (auto &Pair : R.getMulPairs()) {
    MulCandidate *LHSMul = Pair.first;
    MulCandidate *RHSMul = Pair.second;
    LLVM_DEBUG(dbgs() << "Muls:\n"
               << "- " << *LHSMul->Root << "\n"
               << "- " << *RHSMul->Root << "\n");
    LoadInst *BaseLHS = LHSMul->getBaseLoad();
    LoadInst *BaseRHS = RHSMul->getBaseLoad();
    LoadInst *WideLHS = WideLoads.count(BaseLHS) ?
      WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty);
    LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
      WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);

    Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
    InsertAfter = cast<Instruction>(Acc);
  }
  R.UpdateRoot(cast<Instruction>(Acc));
}

LoadInst* ARMParallelDSP::CreateWideLoad(MemInstList &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];

  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());

  assert((BaseSExt && OffsetSExt)
         && "Loads should have a single, extending, user");

  std::function<void(Value*, Value*)> MoveBefore =
    [&](Value *A, Value *B) -> void {
      if (!isa<Instruction>(A) || !isa<Instruction>(B))
        return;

      auto *Source = cast<Instruction>(A);
      auto *Sink = cast<Instruction>(B);

      if (DT->dominates(Source, Sink) ||
          Source->getParent() != Sink->getParent() ||
          isa<PHINode>(Source) || isa<PHINode>(Sink))
        return;

      Source->moveBefore(Sink);
      for (auto &Op : Source->operands())
        MoveBefore(Op, Source);
    };

  // Insert the load at the point of the original dominating load.
  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
                          ++BasicBlock::iterator(DomLoad));

  // Bitcast the pointer to a wider type and create the wide load, while making
  // sure to maintain the original alignment as this prevents ldrd from being
  // generated when it could be illegal due to memory alignment.
  const unsigned AddrSpace = DomLoad->getPointerAddressSpace();
  Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
                                    LoadTy->getPointerTo(AddrSpace));
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr,
                                             Base->getAlignment());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);

  // From the wide load, create two values that equal the original two loads.
  // Loads[0] needs a trunc while Loads[1] needs an lshr and a trunc.
  // TODO: Support big-endian as well.
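  //
  // Illustrative IR when two i16 loads have been widened into %wide (i32):
  //   %bottom = trunc i32 %wide to i16      ; feeds the sext of Loads[0]
  //   %shr    = lshr i32 %wide, 16
  //   %top    = trunc i32 %shr to i16       ; feeds the sext of Loads[1]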
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  BaseSExt->setOperand(0, Bottom);

  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  OffsetSExt->setOperand(0, Trunc);

  WideLoads.emplace(std::make_pair(Base,
                                   std::make_unique<WidenedLoad>(Loads, WideLoad)));
  return WideLoad;
}

Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

char ARMParallelDSP::ID = 0;

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                "Transform loops to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                "Transform loops to use DSP intrinsics", false, false)