14ba319b5SDimitry Andric //===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
24ba319b5SDimitry Andric //
34ba319b5SDimitry Andric //                     The LLVM Compiler Infrastructure
44ba319b5SDimitry Andric //
54ba319b5SDimitry Andric // This file is distributed under the University of Illinois Open Source
64ba319b5SDimitry Andric // License. See LICENSE.TXT for details.
74ba319b5SDimitry Andric //
84ba319b5SDimitry Andric //===----------------------------------------------------------------------===//
94ba319b5SDimitry Andric //
104ba319b5SDimitry Andric /// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map on these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled and the
/// DSP extension is available.
154ba319b5SDimitry Andric //
164ba319b5SDimitry Andric //===----------------------------------------------------------------------===//
174ba319b5SDimitry Andric 
184ba319b5SDimitry Andric #include "llvm/ADT/Statistic.h"
194ba319b5SDimitry Andric #include "llvm/ADT/SmallPtrSet.h"
204ba319b5SDimitry Andric #include "llvm/Analysis/AliasAnalysis.h"
214ba319b5SDimitry Andric #include "llvm/Analysis/LoopAccessAnalysis.h"
224ba319b5SDimitry Andric #include "llvm/Analysis/LoopPass.h"
234ba319b5SDimitry Andric #include "llvm/Analysis/LoopInfo.h"
244ba319b5SDimitry Andric #include "llvm/IR/Instructions.h"
254ba319b5SDimitry Andric #include "llvm/IR/NoFolder.h"
264ba319b5SDimitry Andric #include "llvm/Transforms/Scalar.h"
274ba319b5SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h"
284ba319b5SDimitry Andric #include "llvm/Transforms/Utils/LoopUtils.h"
294ba319b5SDimitry Andric #include "llvm/Pass.h"
304ba319b5SDimitry Andric #include "llvm/PassRegistry.h"
314ba319b5SDimitry Andric #include "llvm/PassSupport.h"
324ba319b5SDimitry Andric #include "llvm/Support/Debug.h"
334ba319b5SDimitry Andric #include "llvm/IR/PatternMatch.h"
344ba319b5SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
354ba319b5SDimitry Andric #include "ARM.h"
364ba319b5SDimitry Andric #include "ARMSubtarget.h"
374ba319b5SDimitry Andric 
384ba319b5SDimitry Andric using namespace llvm;
394ba319b5SDimitry Andric using namespace PatternMatch;
404ba319b5SDimitry Andric 
414ba319b5SDimitry Andric #define DEBUG_TYPE "arm-parallel-dsp"
424ba319b5SDimitry Andric 
434ba319b5SDimitry Andric STATISTIC(NumSMLAD , "Number of smlad instructions generated");
444ba319b5SDimitry Andric 
45*b5893f02SDimitry Andric static cl::opt<bool>
46*b5893f02SDimitry Andric DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
47*b5893f02SDimitry Andric                    cl::desc("Disable the ARM Parallel DSP pass"));
48*b5893f02SDimitry Andric 
494ba319b5SDimitry Andric namespace {
504ba319b5SDimitry Andric   struct OpChain;
514ba319b5SDimitry Andric   struct BinOpChain;
524ba319b5SDimitry Andric   struct Reduction;
534ba319b5SDimitry Andric 
544ba319b5SDimitry Andric   using OpChainList     = SmallVector<std::unique_ptr<OpChain>, 8>;
554ba319b5SDimitry Andric   using ReductionList   = SmallVector<Reduction, 8>;
564ba319b5SDimitry Andric   using ValueList       = SmallVector<Value*, 8>;
574ba319b5SDimitry Andric   using MemInstList     = SmallVector<Instruction*, 8>;
584ba319b5SDimitry Andric   using PMACPair        = std::pair<BinOpChain*,BinOpChain*>;
594ba319b5SDimitry Andric   using PMACPairList    = SmallVector<PMACPair, 8>;
604ba319b5SDimitry Andric   using Instructions    = SmallVector<Instruction*,16>;
614ba319b5SDimitry Andric   using MemLocList      = SmallVector<MemoryLocation, 4>;
624ba319b5SDimitry Andric 
634ba319b5SDimitry Andric   struct OpChain {
644ba319b5SDimitry Andric     Instruction   *Root;
654ba319b5SDimitry Andric     ValueList     AllValues;
664ba319b5SDimitry Andric     MemInstList   VecLd;    // List of all load instructions.
674ba319b5SDimitry Andric     MemLocList    MemLocs;  // All memory locations read by this tree.
684ba319b5SDimitry Andric     bool          ReadOnly = true;
694ba319b5SDimitry Andric 
OpChain__anon04a4004e0111::OpChain704ba319b5SDimitry Andric     OpChain(Instruction *I, ValueList &vl) : Root(I), AllValues(vl) { }
714ba319b5SDimitry Andric     virtual ~OpChain() = default;
724ba319b5SDimitry Andric 
SetMemoryLocations__anon04a4004e0111::OpChain734ba319b5SDimitry Andric     void SetMemoryLocations() {
74*b5893f02SDimitry Andric       const auto Size = LocationSize::unknown();
754ba319b5SDimitry Andric       for (auto *V : AllValues) {
764ba319b5SDimitry Andric         if (auto *I = dyn_cast<Instruction>(V)) {
774ba319b5SDimitry Andric           if (I->mayWriteToMemory())
784ba319b5SDimitry Andric             ReadOnly = false;
794ba319b5SDimitry Andric           if (auto *Ld = dyn_cast<LoadInst>(V))
804ba319b5SDimitry Andric             MemLocs.push_back(MemoryLocation(Ld->getPointerOperand(), Size));
814ba319b5SDimitry Andric         }
824ba319b5SDimitry Andric       }
834ba319b5SDimitry Andric     }
844ba319b5SDimitry Andric 
size__anon04a4004e0111::OpChain854ba319b5SDimitry Andric     unsigned size() const { return AllValues.size(); }
864ba319b5SDimitry Andric   };
874ba319b5SDimitry Andric 
884ba319b5SDimitry Andric   // 'BinOpChain' and 'Reduction' are just some bookkeeping data structures.
894ba319b5SDimitry Andric   // 'Reduction' contains the phi-node and accumulator statement from where we
904ba319b5SDimitry Andric   // start pattern matching, and 'BinOpChain' the multiplication
914ba319b5SDimitry Andric   // instructions that are candidates for parallel execution.
924ba319b5SDimitry Andric   struct BinOpChain : public OpChain {
934ba319b5SDimitry Andric     ValueList     LHS;      // List of all (narrow) left hand operands.
944ba319b5SDimitry Andric     ValueList     RHS;      // List of all (narrow) right hand operands.
95*b5893f02SDimitry Andric     bool Exchange = false;
964ba319b5SDimitry Andric 
BinOpChain__anon04a4004e0111::BinOpChain974ba319b5SDimitry Andric     BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
984ba319b5SDimitry Andric       OpChain(I, lhs), LHS(lhs), RHS(rhs) {
994ba319b5SDimitry Andric         for (auto *V : RHS)
1004ba319b5SDimitry Andric           AllValues.push_back(V);
1014ba319b5SDimitry Andric       }
102*b5893f02SDimitry Andric 
103*b5893f02SDimitry Andric     bool AreSymmetrical(BinOpChain *Other);
1044ba319b5SDimitry Andric   };
1054ba319b5SDimitry Andric 
1064ba319b5SDimitry Andric   struct Reduction {
1074ba319b5SDimitry Andric     PHINode         *Phi;             // The Phi-node from where we start
1084ba319b5SDimitry Andric                                       // pattern matching.
1094ba319b5SDimitry Andric     Instruction     *AccIntAdd;       // The accumulating integer add statement,
1104ba319b5SDimitry Andric                                       // i.e, the reduction statement.
1114ba319b5SDimitry Andric     OpChainList     MACCandidates;    // The MAC candidates associated with
1124ba319b5SDimitry Andric                                       // this reduction statement.
113*b5893f02SDimitry Andric     PMACPairList    PMACPairs;
Reduction__anon04a4004e0111::Reduction1144ba319b5SDimitry Andric     Reduction (PHINode *P, Instruction *Acc) : Phi(P), AccIntAdd(Acc) { };
1154ba319b5SDimitry Andric   };
1164ba319b5SDimitry Andric 
1174ba319b5SDimitry Andric   class ARMParallelDSP : public LoopPass {
1184ba319b5SDimitry Andric     ScalarEvolution   *SE;
1194ba319b5SDimitry Andric     AliasAnalysis     *AA;
1204ba319b5SDimitry Andric     TargetLibraryInfo *TLI;
1214ba319b5SDimitry Andric     DominatorTree     *DT;
1224ba319b5SDimitry Andric     LoopInfo          *LI;
1234ba319b5SDimitry Andric     Loop              *L;
1244ba319b5SDimitry Andric     const DataLayout  *DL;
1254ba319b5SDimitry Andric     Module            *M;
126*b5893f02SDimitry Andric     std::map<LoadInst*, LoadInst*> LoadPairs;
127*b5893f02SDimitry Andric     std::map<LoadInst*, SmallVector<LoadInst*, 4>> SequentialLoads;
1284ba319b5SDimitry Andric 
129*b5893f02SDimitry Andric     bool RecordSequentialLoads(BasicBlock *Header);
130*b5893f02SDimitry Andric     bool InsertParallelMACs(Reduction &Reduction);
1314ba319b5SDimitry Andric     bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
132*b5893f02SDimitry Andric     void CreateParallelMACPairs(Reduction &R);
1334ba319b5SDimitry Andric     Instruction *CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
134*b5893f02SDimitry Andric                                  Instruction *Acc, bool Exchange,
135*b5893f02SDimitry Andric                                  Instruction *InsertAfter);
1364ba319b5SDimitry Andric 
1374ba319b5SDimitry Andric     /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
1384ba319b5SDimitry Andric     /// Dual performs two signed 16x16-bit multiplications. It adds the
1394ba319b5SDimitry Andric     /// products to a 32-bit accumulate operand. Optionally, the instruction can
1404ba319b5SDimitry Andric     /// exchange the halfwords of the second operand before performing the
1414ba319b5SDimitry Andric     /// arithmetic.
1424ba319b5SDimitry Andric     bool MatchSMLAD(Function &F);
1434ba319b5SDimitry Andric 
1444ba319b5SDimitry Andric   public:
1454ba319b5SDimitry Andric     static char ID;
1464ba319b5SDimitry Andric 
ARMParallelDSP()1474ba319b5SDimitry Andric     ARMParallelDSP() : LoopPass(ID) { }
1484ba319b5SDimitry Andric 
getAnalysisUsage(AnalysisUsage & AU) const1494ba319b5SDimitry Andric     void getAnalysisUsage(AnalysisUsage &AU) const override {
1504ba319b5SDimitry Andric       LoopPass::getAnalysisUsage(AU);
1514ba319b5SDimitry Andric       AU.addRequired<AssumptionCacheTracker>();
1524ba319b5SDimitry Andric       AU.addRequired<ScalarEvolutionWrapperPass>();
1534ba319b5SDimitry Andric       AU.addRequired<AAResultsWrapperPass>();
1544ba319b5SDimitry Andric       AU.addRequired<TargetLibraryInfoWrapperPass>();
1554ba319b5SDimitry Andric       AU.addRequired<LoopInfoWrapperPass>();
1564ba319b5SDimitry Andric       AU.addRequired<DominatorTreeWrapperPass>();
1574ba319b5SDimitry Andric       AU.addRequired<TargetPassConfig>();
1584ba319b5SDimitry Andric       AU.addPreserved<LoopInfoWrapperPass>();
1594ba319b5SDimitry Andric       AU.setPreservesCFG();
1604ba319b5SDimitry Andric     }
1614ba319b5SDimitry Andric 
runOnLoop(Loop * TheLoop,LPPassManager &)1624ba319b5SDimitry Andric     bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
163*b5893f02SDimitry Andric       if (DisableParallelDSP)
164*b5893f02SDimitry Andric         return false;
1654ba319b5SDimitry Andric       L = TheLoop;
1664ba319b5SDimitry Andric       SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
1674ba319b5SDimitry Andric       AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
1684ba319b5SDimitry Andric       TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
1694ba319b5SDimitry Andric       DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
1704ba319b5SDimitry Andric       LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1714ba319b5SDimitry Andric       auto &TPC = getAnalysis<TargetPassConfig>();
1724ba319b5SDimitry Andric 
1734ba319b5SDimitry Andric       BasicBlock *Header = TheLoop->getHeader();
1744ba319b5SDimitry Andric       if (!Header)
1754ba319b5SDimitry Andric         return false;
1764ba319b5SDimitry Andric 
1774ba319b5SDimitry Andric       // TODO: We assume the loop header and latch to be the same block.
1784ba319b5SDimitry Andric       // This is not a fundamental restriction, but lifting this would just
1794ba319b5SDimitry Andric       // require more work to do the transformation and then patch up the CFG.
1804ba319b5SDimitry Andric       if (Header != TheLoop->getLoopLatch()) {
1814ba319b5SDimitry Andric         LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
1824ba319b5SDimitry Andric                              "running pass ARMParallelDSP\n");
1834ba319b5SDimitry Andric         return false;
1844ba319b5SDimitry Andric       }
1854ba319b5SDimitry Andric 
1864ba319b5SDimitry Andric       Function &F = *Header->getParent();
1874ba319b5SDimitry Andric       M = F.getParent();
1884ba319b5SDimitry Andric       DL = &M->getDataLayout();
1894ba319b5SDimitry Andric 
1904ba319b5SDimitry Andric       auto &TM = TPC.getTM<TargetMachine>();
1914ba319b5SDimitry Andric       auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
1924ba319b5SDimitry Andric 
1934ba319b5SDimitry Andric       if (!ST->allowsUnalignedMem()) {
1944ba319b5SDimitry Andric         LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
1954ba319b5SDimitry Andric                              "running pass ARMParallelDSP\n");
1964ba319b5SDimitry Andric         return false;
1974ba319b5SDimitry Andric       }
1984ba319b5SDimitry Andric 
1994ba319b5SDimitry Andric       if (!ST->hasDSP()) {
2004ba319b5SDimitry Andric         LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
2014ba319b5SDimitry Andric                              "ARMParallelDSP\n");
2024ba319b5SDimitry Andric         return false;
2034ba319b5SDimitry Andric       }
2044ba319b5SDimitry Andric 
2054ba319b5SDimitry Andric       LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);
2064ba319b5SDimitry Andric       bool Changes = false;
2074ba319b5SDimitry Andric 
208*b5893f02SDimitry Andric       LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
209*b5893f02SDimitry Andric       LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");
210*b5893f02SDimitry Andric 
211*b5893f02SDimitry Andric       if (!RecordSequentialLoads(Header)) {
212*b5893f02SDimitry Andric         LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
213*b5893f02SDimitry Andric         return false;
214*b5893f02SDimitry Andric       }
215*b5893f02SDimitry Andric 
2164ba319b5SDimitry Andric       Changes = MatchSMLAD(F);
2174ba319b5SDimitry Andric       return Changes;
2184ba319b5SDimitry Andric     }
2194ba319b5SDimitry Andric   };
2204ba319b5SDimitry Andric }
2214ba319b5SDimitry Andric 
2224ba319b5SDimitry Andric // MaxBitwidth: the maximum supported bitwidth of the elements in the DSP
2234ba319b5SDimitry Andric // instructions, which is set to 16. So here we should collect all i8 and i16
2244ba319b5SDimitry Andric // narrow operations.
2254ba319b5SDimitry Andric // TODO: we currently only collect i16, and will support i8 later, so that's
2264ba319b5SDimitry Andric // why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
2274ba319b5SDimitry Andric template<unsigned MaxBitWidth>
IsNarrowSequence(Value * V,ValueList & VL)2284ba319b5SDimitry Andric static bool IsNarrowSequence(Value *V, ValueList &VL) {
2294ba319b5SDimitry Andric   LLVM_DEBUG(dbgs() << "Is narrow sequence? "; V->dump());
2304ba319b5SDimitry Andric   ConstantInt *CInt;
2314ba319b5SDimitry Andric 
2324ba319b5SDimitry Andric   if (match(V, m_ConstantInt(CInt))) {
2334ba319b5SDimitry Andric     // TODO: if a constant is used, it needs to fit within the bit width.
2344ba319b5SDimitry Andric     return false;
2354ba319b5SDimitry Andric   }
2364ba319b5SDimitry Andric 
2374ba319b5SDimitry Andric   auto *I = dyn_cast<Instruction>(V);
2384ba319b5SDimitry Andric   if (!I)
2394ba319b5SDimitry Andric    return false;
2404ba319b5SDimitry Andric 
2414ba319b5SDimitry Andric   Value *Val, *LHS, *RHS;
2424ba319b5SDimitry Andric   if (match(V, m_Trunc(m_Value(Val)))) {
2434ba319b5SDimitry Andric     if (cast<TruncInst>(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth)
2444ba319b5SDimitry Andric       return IsNarrowSequence<MaxBitWidth>(Val, VL);
2454ba319b5SDimitry Andric   } else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) {
2464ba319b5SDimitry Andric     // TODO: we need to implement sadd16/sadd8 for this, which enables to
2474ba319b5SDimitry Andric     // also do the rewrite for smlad8.ll, but it is unsupported for now.
2484ba319b5SDimitry Andric     LLVM_DEBUG(dbgs() << "No, unsupported Op:\t"; I->dump());
2494ba319b5SDimitry Andric     return false;
2504ba319b5SDimitry Andric   } else if (match(V, m_ZExtOrSExt(m_Value(Val)))) {
2514ba319b5SDimitry Andric     if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth) {
2524ba319b5SDimitry Andric       LLVM_DEBUG(dbgs() << "No, wrong SrcTy size: " <<
2534ba319b5SDimitry Andric         cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() << "\n");
2544ba319b5SDimitry Andric       return false;
2554ba319b5SDimitry Andric     }
2564ba319b5SDimitry Andric 
2574ba319b5SDimitry Andric     if (match(Val, m_Load(m_Value()))) {
2584ba319b5SDimitry Andric       LLVM_DEBUG(dbgs() << "Yes, found narrow Load:\t"; Val->dump());
2594ba319b5SDimitry Andric       VL.push_back(Val);
2604ba319b5SDimitry Andric       VL.push_back(I);
2614ba319b5SDimitry Andric       return true;
2624ba319b5SDimitry Andric     }
2634ba319b5SDimitry Andric   }
2644ba319b5SDimitry Andric   LLVM_DEBUG(dbgs() << "No, unsupported Op:\t"; I->dump());
2654ba319b5SDimitry Andric   return false;
2664ba319b5SDimitry Andric }
2674ba319b5SDimitry Andric 
2684ba319b5SDimitry Andric template<typename MemInst>
AreSequentialAccesses(MemInst * MemOp0,MemInst * MemOp1,const DataLayout & DL,ScalarEvolution & SE)2694ba319b5SDimitry Andric static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
270*b5893f02SDimitry Andric                                   const DataLayout &DL, ScalarEvolution &SE) {
2714ba319b5SDimitry Andric   if (!MemOp0->isSimple() || !MemOp1->isSimple()) {
2724ba319b5SDimitry Andric     LLVM_DEBUG(dbgs() << "No, not touching volatile access\n");
2734ba319b5SDimitry Andric     return false;
2744ba319b5SDimitry Andric   }
2754ba319b5SDimitry Andric   if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE)) {
2764ba319b5SDimitry Andric     LLVM_DEBUG(dbgs() << "OK: accesses are consecutive.\n");
2774ba319b5SDimitry Andric     return true;
2784ba319b5SDimitry Andric   }
2794ba319b5SDimitry Andric   LLVM_DEBUG(dbgs() << "No, accesses aren't consecutive.\n");
2804ba319b5SDimitry Andric   return false;
2814ba319b5SDimitry Andric }
2824ba319b5SDimitry Andric 
AreSequentialLoads(LoadInst * Ld0,LoadInst * Ld1,MemInstList & VecMem)2834ba319b5SDimitry Andric bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
2844ba319b5SDimitry Andric                                         MemInstList &VecMem) {
2854ba319b5SDimitry Andric   if (!Ld0 || !Ld1)
2864ba319b5SDimitry Andric     return false;
2874ba319b5SDimitry Andric 
2884ba319b5SDimitry Andric   LLVM_DEBUG(dbgs() << "Are consecutive loads:\n";
2894ba319b5SDimitry Andric     dbgs() << "Ld0:"; Ld0->dump();
2904ba319b5SDimitry Andric     dbgs() << "Ld1:"; Ld1->dump();
2914ba319b5SDimitry Andric   );
2924ba319b5SDimitry Andric 
2934ba319b5SDimitry Andric   if (!Ld0->hasOneUse() || !Ld1->hasOneUse()) {
2944ba319b5SDimitry Andric     LLVM_DEBUG(dbgs() << "No, load has more than one use.\n");
2954ba319b5SDimitry Andric     return false;
2964ba319b5SDimitry Andric   }
2974ba319b5SDimitry Andric 
298*b5893f02SDimitry Andric   if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
299*b5893f02SDimitry Andric     return false;
300*b5893f02SDimitry Andric 
301*b5893f02SDimitry Andric   VecMem.clear();
302*b5893f02SDimitry Andric   VecMem.push_back(Ld0);
303*b5893f02SDimitry Andric   VecMem.push_back(Ld1);
304*b5893f02SDimitry Andric   return true;
3054ba319b5SDimitry Andric }
3064ba319b5SDimitry Andric 
307*b5893f02SDimitry Andric /// Iterate through the block and record base, offset pairs of loads as well as
308*b5893f02SDimitry Andric /// maximal sequences of sequential loads.
RecordSequentialLoads(BasicBlock * Header)309*b5893f02SDimitry Andric bool ARMParallelDSP::RecordSequentialLoads(BasicBlock *Header) {
310*b5893f02SDimitry Andric   SmallVector<LoadInst*, 8> Loads;
311*b5893f02SDimitry Andric   for (auto &I : *Header) {
312*b5893f02SDimitry Andric     auto *Ld = dyn_cast<LoadInst>(&I);
313*b5893f02SDimitry Andric     if (!Ld)
314*b5893f02SDimitry Andric       continue;
315*b5893f02SDimitry Andric     Loads.push_back(Ld);
316*b5893f02SDimitry Andric   }
317*b5893f02SDimitry Andric 
318*b5893f02SDimitry Andric   std::map<LoadInst*, LoadInst*> BaseLoads;
319*b5893f02SDimitry Andric 
320*b5893f02SDimitry Andric   for (auto *Ld0 : Loads) {
321*b5893f02SDimitry Andric     for (auto *Ld1 : Loads) {
322*b5893f02SDimitry Andric       if (Ld0 == Ld1)
323*b5893f02SDimitry Andric         continue;
324*b5893f02SDimitry Andric 
325*b5893f02SDimitry Andric       if (AreSequentialAccesses<LoadInst>(Ld0, Ld1, *DL, *SE)) {
326*b5893f02SDimitry Andric         LoadPairs[Ld0] = Ld1;
327*b5893f02SDimitry Andric         if (BaseLoads.count(Ld0)) {
328*b5893f02SDimitry Andric           LoadInst *Base = BaseLoads[Ld0];
329*b5893f02SDimitry Andric           BaseLoads[Ld1] = Base;
330*b5893f02SDimitry Andric           SequentialLoads[Base].push_back(Ld1);
331*b5893f02SDimitry Andric         } else {
332*b5893f02SDimitry Andric           BaseLoads[Ld1] = Ld0;
333*b5893f02SDimitry Andric           SequentialLoads[Ld0].push_back(Ld1);
334*b5893f02SDimitry Andric         }
335*b5893f02SDimitry Andric       }
336*b5893f02SDimitry Andric     }
337*b5893f02SDimitry Andric   }
338*b5893f02SDimitry Andric   return LoadPairs.size() > 1;
339*b5893f02SDimitry Andric }
340*b5893f02SDimitry Andric 
CreateParallelMACPairs(Reduction & R)341*b5893f02SDimitry Andric void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) {
342*b5893f02SDimitry Andric   OpChainList &Candidates = R.MACCandidates;
343*b5893f02SDimitry Andric   PMACPairList &PMACPairs = R.PMACPairs;
3444ba319b5SDimitry Andric   const unsigned Elems = Candidates.size();
3454ba319b5SDimitry Andric 
3464ba319b5SDimitry Andric   if (Elems < 2)
347*b5893f02SDimitry Andric     return;
3484ba319b5SDimitry Andric 
349*b5893f02SDimitry Andric   auto CanPair = [&](BinOpChain *PMul0, BinOpChain *PMul1) {
350*b5893f02SDimitry Andric     if (!PMul0->AreSymmetrical(PMul1))
351*b5893f02SDimitry Andric       return false;
352*b5893f02SDimitry Andric 
353*b5893f02SDimitry Andric     // The first elements of each vector should be loads with sexts. If we
354*b5893f02SDimitry Andric     // find that its two pairs of consecutive loads, then these can be
355*b5893f02SDimitry Andric     // transformed into two wider loads and the users can be replaced with
356*b5893f02SDimitry Andric     // DSP intrinsics.
357*b5893f02SDimitry Andric     for (unsigned x = 0; x < PMul0->LHS.size(); x += 2) {
358*b5893f02SDimitry Andric       auto *Ld0 = dyn_cast<LoadInst>(PMul0->LHS[x]);
359*b5893f02SDimitry Andric       auto *Ld1 = dyn_cast<LoadInst>(PMul1->LHS[x]);
360*b5893f02SDimitry Andric       auto *Ld2 = dyn_cast<LoadInst>(PMul0->RHS[x]);
361*b5893f02SDimitry Andric       auto *Ld3 = dyn_cast<LoadInst>(PMul1->RHS[x]);
362*b5893f02SDimitry Andric 
363*b5893f02SDimitry Andric       if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
364*b5893f02SDimitry Andric         return false;
365*b5893f02SDimitry Andric 
366*b5893f02SDimitry Andric       LLVM_DEBUG(dbgs() << "Looking at operands " << x << ":\n"
367*b5893f02SDimitry Andric                  << "\t Ld0: " << *Ld0 << "\n"
368*b5893f02SDimitry Andric                  << "\t Ld1: " << *Ld1 << "\n"
369*b5893f02SDimitry Andric                  << "and operands " << x + 2 << ":\n"
370*b5893f02SDimitry Andric                  << "\t Ld2: " << *Ld2 << "\n"
371*b5893f02SDimitry Andric                  << "\t Ld3: " << *Ld3 << "\n");
372*b5893f02SDimitry Andric 
373*b5893f02SDimitry Andric       if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
374*b5893f02SDimitry Andric         if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
375*b5893f02SDimitry Andric           LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
376*b5893f02SDimitry Andric           PMACPairs.push_back(std::make_pair(PMul0, PMul1));
377*b5893f02SDimitry Andric           return true;
378*b5893f02SDimitry Andric         } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
379*b5893f02SDimitry Andric           LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
380*b5893f02SDimitry Andric           LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
381*b5893f02SDimitry Andric           PMul1->Exchange = true;
382*b5893f02SDimitry Andric           PMACPairs.push_back(std::make_pair(PMul0, PMul1));
383*b5893f02SDimitry Andric           return true;
384*b5893f02SDimitry Andric         }
385*b5893f02SDimitry Andric       } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
386*b5893f02SDimitry Andric                  AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
387*b5893f02SDimitry Andric         LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
388*b5893f02SDimitry Andric         LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
389*b5893f02SDimitry Andric         LLVM_DEBUG(dbgs() << "    and swapping muls\n");
390*b5893f02SDimitry Andric         PMul0->Exchange = true;
391*b5893f02SDimitry Andric         // Only the second operand can be exchanged, so swap the muls.
392*b5893f02SDimitry Andric         PMACPairs.push_back(std::make_pair(PMul1, PMul0));
393*b5893f02SDimitry Andric         return true;
394*b5893f02SDimitry Andric       }
395*b5893f02SDimitry Andric     }
396*b5893f02SDimitry Andric     return false;
397*b5893f02SDimitry Andric   };
398*b5893f02SDimitry Andric 
399*b5893f02SDimitry Andric   SmallPtrSet<const Instruction*, 4> Paired;
400*b5893f02SDimitry Andric   for (unsigned i = 0; i < Elems; ++i) {
4014ba319b5SDimitry Andric     BinOpChain *PMul0 = static_cast<BinOpChain*>(Candidates[i].get());
402*b5893f02SDimitry Andric     if (Paired.count(PMul0->Root))
403*b5893f02SDimitry Andric       continue;
404*b5893f02SDimitry Andric 
405*b5893f02SDimitry Andric     for (unsigned j = 0; j < Elems; ++j) {
406*b5893f02SDimitry Andric       if (i == j)
407*b5893f02SDimitry Andric         continue;
408*b5893f02SDimitry Andric 
409*b5893f02SDimitry Andric       BinOpChain *PMul1 = static_cast<BinOpChain*>(Candidates[j].get());
410*b5893f02SDimitry Andric       if (Paired.count(PMul1->Root))
411*b5893f02SDimitry Andric         continue;
412*b5893f02SDimitry Andric 
4134ba319b5SDimitry Andric       const Instruction *Mul0 = PMul0->Root;
4144ba319b5SDimitry Andric       const Instruction *Mul1 = PMul1->Root;
4154ba319b5SDimitry Andric       if (Mul0 == Mul1)
4164ba319b5SDimitry Andric         continue;
4174ba319b5SDimitry Andric 
418*b5893f02SDimitry Andric       assert(PMul0 != PMul1 && "expected different chains");
419*b5893f02SDimitry Andric 
4204ba319b5SDimitry Andric       LLVM_DEBUG(dbgs() << "\nCheck parallel muls:\n";
4214ba319b5SDimitry Andric                  dbgs() << "- "; Mul0->dump();
4224ba319b5SDimitry Andric                  dbgs() << "- "; Mul1->dump());
4234ba319b5SDimitry Andric 
4244ba319b5SDimitry Andric       LLVM_DEBUG(dbgs() << "OK: mul operands list match:\n");
425*b5893f02SDimitry Andric       if (CanPair(PMul0, PMul1)) {
426*b5893f02SDimitry Andric         Paired.insert(Mul0);
427*b5893f02SDimitry Andric         Paired.insert(Mul1);
428*b5893f02SDimitry Andric         break;
4294ba319b5SDimitry Andric       }
4304ba319b5SDimitry Andric     }
4314ba319b5SDimitry Andric   }
4324ba319b5SDimitry Andric }
4334ba319b5SDimitry Andric 
InsertParallelMACs(Reduction & Reduction)434*b5893f02SDimitry Andric bool ARMParallelDSP::InsertParallelMACs(Reduction &Reduction) {
4354ba319b5SDimitry Andric   Instruction *Acc = Reduction.Phi;
4364ba319b5SDimitry Andric   Instruction *InsertAfter = Reduction.AccIntAdd;
4374ba319b5SDimitry Andric 
438*b5893f02SDimitry Andric   for (auto &Pair : Reduction.PMACPairs) {
439*b5893f02SDimitry Andric     BinOpChain *PMul0 = Pair.first;
440*b5893f02SDimitry Andric     BinOpChain *PMul1 = Pair.second;
4414ba319b5SDimitry Andric     LLVM_DEBUG(dbgs() << "Found parallel MACs!!\n";
442*b5893f02SDimitry Andric                dbgs() << "- "; PMul0->Root->dump();
443*b5893f02SDimitry Andric                dbgs() << "- "; PMul1->Root->dump());
444*b5893f02SDimitry Andric 
445*b5893f02SDimitry Andric     auto *VecLd0 = cast<LoadInst>(PMul0->VecLd[0]);
446*b5893f02SDimitry Andric     auto *VecLd1 = cast<LoadInst>(PMul1->VecLd[0]);
447*b5893f02SDimitry Andric     Acc = CreateSMLADCall(VecLd0, VecLd1, Acc, PMul1->Exchange, InsertAfter);
4484ba319b5SDimitry Andric     InsertAfter = Acc;
4494ba319b5SDimitry Andric   }
4504ba319b5SDimitry Andric 
4514ba319b5SDimitry Andric   if (Acc != Reduction.Phi) {
4524ba319b5SDimitry Andric     LLVM_DEBUG(dbgs() << "Replace Accumulate: "; Acc->dump());
4534ba319b5SDimitry Andric     Reduction.AccIntAdd->replaceAllUsesWith(Acc);
4544ba319b5SDimitry Andric     return true;
4554ba319b5SDimitry Andric   }
4564ba319b5SDimitry Andric   return false;
4574ba319b5SDimitry Andric }
4584ba319b5SDimitry Andric 
MatchReductions(Function & F,Loop * TheLoop,BasicBlock * Header,ReductionList & Reductions)4594ba319b5SDimitry Andric static void MatchReductions(Function &F, Loop *TheLoop, BasicBlock *Header,
4604ba319b5SDimitry Andric                             ReductionList &Reductions) {
4614ba319b5SDimitry Andric   RecurrenceDescriptor RecDesc;
4624ba319b5SDimitry Andric   const bool HasFnNoNaNAttr =
4634ba319b5SDimitry Andric     F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
4644ba319b5SDimitry Andric   const BasicBlock *Latch = TheLoop->getLoopLatch();
4654ba319b5SDimitry Andric 
4664ba319b5SDimitry Andric   // We need a preheader as getIncomingValueForBlock assumes there is one.
4674ba319b5SDimitry Andric   if (!TheLoop->getLoopPreheader()) {
4684ba319b5SDimitry Andric     LLVM_DEBUG(dbgs() << "No preheader found, bailing out\n");
4694ba319b5SDimitry Andric     return;
4704ba319b5SDimitry Andric   }
4714ba319b5SDimitry Andric 
4724ba319b5SDimitry Andric   for (PHINode &Phi : Header->phis()) {
4734ba319b5SDimitry Andric     const auto *Ty = Phi.getType();
474*b5893f02SDimitry Andric     if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
4754ba319b5SDimitry Andric       continue;
4764ba319b5SDimitry Andric 
4774ba319b5SDimitry Andric     const bool IsReduction =
4784ba319b5SDimitry Andric       RecurrenceDescriptor::AddReductionVar(&Phi,
4794ba319b5SDimitry Andric                                             RecurrenceDescriptor::RK_IntegerAdd,
4804ba319b5SDimitry Andric                                             TheLoop, HasFnNoNaNAttr, RecDesc);
4814ba319b5SDimitry Andric     if (!IsReduction)
4824ba319b5SDimitry Andric       continue;
4834ba319b5SDimitry Andric 
4844ba319b5SDimitry Andric     Instruction *Acc = dyn_cast<Instruction>(Phi.getIncomingValueForBlock(Latch));
4854ba319b5SDimitry Andric     if (!Acc)
4864ba319b5SDimitry Andric       continue;
4874ba319b5SDimitry Andric 
4884ba319b5SDimitry Andric     Reductions.push_back(Reduction(&Phi, Acc));
4894ba319b5SDimitry Andric   }
4904ba319b5SDimitry Andric 
4914ba319b5SDimitry Andric   LLVM_DEBUG(
4924ba319b5SDimitry Andric     dbgs() << "\nAccumulating integer additions (reductions) found:\n";
4934ba319b5SDimitry Andric     for (auto &R : Reductions) {
4944ba319b5SDimitry Andric       dbgs() << "-  "; R.Phi->dump();
4954ba319b5SDimitry Andric       dbgs() << "-> "; R.AccIntAdd->dump();
4964ba319b5SDimitry Andric     }
4974ba319b5SDimitry Andric   );
4984ba319b5SDimitry Andric }
4994ba319b5SDimitry Andric 
AddMACCandidate(OpChainList & Candidates,Instruction * Mul,Value * MulOp0,Value * MulOp1)5004ba319b5SDimitry Andric static void AddMACCandidate(OpChainList &Candidates,
501*b5893f02SDimitry Andric                             Instruction *Mul,
502*b5893f02SDimitry Andric                             Value *MulOp0, Value *MulOp1) {
5034ba319b5SDimitry Andric   LLVM_DEBUG(dbgs() << "OK, found acc mul:\t"; Mul->dump());
504*b5893f02SDimitry Andric   assert(Mul->getOpcode() == Instruction::Mul &&
505*b5893f02SDimitry Andric          "expected mul instruction");
5064ba319b5SDimitry Andric   ValueList LHS;
5074ba319b5SDimitry Andric   ValueList RHS;
5084ba319b5SDimitry Andric   if (IsNarrowSequence<16>(MulOp0, LHS) &&
5094ba319b5SDimitry Andric       IsNarrowSequence<16>(MulOp1, RHS)) {
5104ba319b5SDimitry Andric     LLVM_DEBUG(dbgs() << "OK, found narrow mul: "; Mul->dump());
5114ba319b5SDimitry Andric     Candidates.push_back(make_unique<BinOpChain>(Mul, LHS, RHS));
5124ba319b5SDimitry Andric   }
5134ba319b5SDimitry Andric }
5144ba319b5SDimitry Andric 
MatchParallelMACSequences(Reduction & R,OpChainList & Candidates)5154ba319b5SDimitry Andric static void MatchParallelMACSequences(Reduction &R,
5164ba319b5SDimitry Andric                                       OpChainList &Candidates) {
517*b5893f02SDimitry Andric   Instruction *Acc = R.AccIntAdd;
518*b5893f02SDimitry Andric   LLVM_DEBUG(dbgs() << "\n- Analysing:\t" << *Acc);
5194ba319b5SDimitry Andric 
520*b5893f02SDimitry Andric   // Returns false to signal the search should be stopped.
521*b5893f02SDimitry Andric   std::function<bool(Value*)> Match =
522*b5893f02SDimitry Andric     [&Candidates, &Match](Value *V) -> bool {
523*b5893f02SDimitry Andric 
524*b5893f02SDimitry Andric     auto *I = dyn_cast<Instruction>(V);
525*b5893f02SDimitry Andric     if (!I)
526*b5893f02SDimitry Andric       return false;
527*b5893f02SDimitry Andric 
528*b5893f02SDimitry Andric     switch (I->getOpcode()) {
529*b5893f02SDimitry Andric     case Instruction::Add:
530*b5893f02SDimitry Andric       if (Match(I->getOperand(0)) || (Match(I->getOperand(1))))
531*b5893f02SDimitry Andric         return true;
532*b5893f02SDimitry Andric       break;
533*b5893f02SDimitry Andric     case Instruction::Mul: {
534*b5893f02SDimitry Andric       Value *MulOp0 = I->getOperand(0);
535*b5893f02SDimitry Andric       Value *MulOp1 = I->getOperand(1);
536*b5893f02SDimitry Andric       if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1))
537*b5893f02SDimitry Andric         AddMACCandidate(Candidates, I, MulOp0, MulOp1);
538*b5893f02SDimitry Andric       return false;
5394ba319b5SDimitry Andric     }
540*b5893f02SDimitry Andric     case Instruction::SExt:
541*b5893f02SDimitry Andric       return Match(I->getOperand(0));
5424ba319b5SDimitry Andric     }
543*b5893f02SDimitry Andric     return false;
544*b5893f02SDimitry Andric   };
5454ba319b5SDimitry Andric 
546*b5893f02SDimitry Andric   while (Match (Acc));
547*b5893f02SDimitry Andric   LLVM_DEBUG(dbgs() << "Finished matching MAC sequences, found "
548*b5893f02SDimitry Andric              << Candidates.size() << " candidates.\n");
5494ba319b5SDimitry Andric }
5504ba319b5SDimitry Andric 
// Collects every instruction in the header block that may read memory into
// Reads and every one that may write memory into Writes. The MAC-chain
// instructions themselves are not excluded.
AliasCandidates(BasicBlock * Header,Instructions & Reads,Instructions & Writes)5534ba319b5SDimitry Andric static void AliasCandidates(BasicBlock *Header, Instructions &Reads,
5544ba319b5SDimitry Andric                             Instructions &Writes) {
5554ba319b5SDimitry Andric   for (auto &I : *Header) {
5564ba319b5SDimitry Andric     if (I.mayReadFromMemory())
5574ba319b5SDimitry Andric       Reads.push_back(&I);
5584ba319b5SDimitry Andric     if (I.mayWriteToMemory())
5594ba319b5SDimitry Andric       Writes.push_back(&I);
5604ba319b5SDimitry Andric   }
5614ba319b5SDimitry Andric }
5624ba319b5SDimitry Andric 
5634ba319b5SDimitry Andric // Check whether statements in the basic block that write to memory alias with
5644ba319b5SDimitry Andric // the memory locations accessed by the MAC-chains.
5654ba319b5SDimitry Andric // TODO: we need the read statements when we accept more complicated chains.
AreAliased(AliasAnalysis * AA,Instructions & Reads,Instructions & Writes,OpChainList & MACCandidates)5664ba319b5SDimitry Andric static bool AreAliased(AliasAnalysis *AA, Instructions &Reads,
5674ba319b5SDimitry Andric                        Instructions &Writes, OpChainList &MACCandidates) {
5684ba319b5SDimitry Andric   LLVM_DEBUG(dbgs() << "Alias checks:\n");
5694ba319b5SDimitry Andric   for (auto &MAC : MACCandidates) {
5704ba319b5SDimitry Andric     LLVM_DEBUG(dbgs() << "mul: "; MAC->Root->dump());
5714ba319b5SDimitry Andric 
5724ba319b5SDimitry Andric     // At the moment, we allow only simple chains that only consist of reads,
5734ba319b5SDimitry Andric     // accumulate their result with an integer add, and thus that don't write
5744ba319b5SDimitry Andric     // memory, and simply bail if they do.
5754ba319b5SDimitry Andric     if (!MAC->ReadOnly)
5764ba319b5SDimitry Andric       return true;
5774ba319b5SDimitry Andric 
5784ba319b5SDimitry Andric     // Now for all writes in the basic block, check that they don't alias with
5794ba319b5SDimitry Andric     // the memory locations accessed by our MAC-chain:
5804ba319b5SDimitry Andric     for (auto *I : Writes) {
5814ba319b5SDimitry Andric       LLVM_DEBUG(dbgs() << "- "; I->dump());
5824ba319b5SDimitry Andric       assert(MAC->MemLocs.size() >= 2 && "expecting at least 2 memlocs");
5834ba319b5SDimitry Andric       for (auto &MemLoc : MAC->MemLocs) {
5844ba319b5SDimitry Andric         if (isModOrRefSet(intersectModRef(AA->getModRefInfo(I, MemLoc),
5854ba319b5SDimitry Andric                                           ModRefInfo::ModRef))) {
5864ba319b5SDimitry Andric           LLVM_DEBUG(dbgs() << "Yes, aliases found\n");
5874ba319b5SDimitry Andric           return true;
5884ba319b5SDimitry Andric         }
5894ba319b5SDimitry Andric       }
5904ba319b5SDimitry Andric     }
5914ba319b5SDimitry Andric   }
5924ba319b5SDimitry Andric 
5934ba319b5SDimitry Andric   LLVM_DEBUG(dbgs() << "OK: no aliases found!\n");
5944ba319b5SDimitry Andric   return false;
5954ba319b5SDimitry Andric }
5964ba319b5SDimitry Andric 
CheckMACMemory(OpChainList & Candidates)5974ba319b5SDimitry Andric static bool CheckMACMemory(OpChainList &Candidates) {
5984ba319b5SDimitry Andric   for (auto &C : Candidates) {
5994ba319b5SDimitry Andric     // A mul has 2 operands, and a narrow op consist of sext and a load; thus
6004ba319b5SDimitry Andric     // we expect at least 4 items in this operand value list.
6014ba319b5SDimitry Andric     if (C->size() < 4) {
6024ba319b5SDimitry Andric       LLVM_DEBUG(dbgs() << "Operand list too short.\n");
6034ba319b5SDimitry Andric       return false;
6044ba319b5SDimitry Andric     }
6054ba319b5SDimitry Andric     C->SetMemoryLocations();
6064ba319b5SDimitry Andric     ValueList &LHS = static_cast<BinOpChain*>(C.get())->LHS;
6074ba319b5SDimitry Andric     ValueList &RHS = static_cast<BinOpChain*>(C.get())->RHS;
6084ba319b5SDimitry Andric 
6094ba319b5SDimitry Andric     // Use +=2 to skip over the expected extend instructions.
6104ba319b5SDimitry Andric     for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
6114ba319b5SDimitry Andric       if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
6124ba319b5SDimitry Andric         return false;
6134ba319b5SDimitry Andric     }
6144ba319b5SDimitry Andric   }
6154ba319b5SDimitry Andric   return true;
6164ba319b5SDimitry Andric }
6174ba319b5SDimitry Andric 
// Loop pass that identifies integer add reductions of pairs of 16-bit
// multiplications.
6204ba319b5SDimitry Andric // To use SMLAD:
6214ba319b5SDimitry Andric // 1) we first need to find integer add reduction PHIs,
6224ba319b5SDimitry Andric // 2) then from the PHI, look for this pattern:
6234ba319b5SDimitry Andric //
6244ba319b5SDimitry Andric // acc0 = phi i32 [0, %entry], [%acc1, %loop.body]
6254ba319b5SDimitry Andric // ld0 = load i16
6264ba319b5SDimitry Andric // sext0 = sext i16 %ld0 to i32
6274ba319b5SDimitry Andric // ld1 = load i16
6284ba319b5SDimitry Andric // sext1 = sext i16 %ld1 to i32
6294ba319b5SDimitry Andric // mul0 = mul %sext0, %sext1
6304ba319b5SDimitry Andric // ld2 = load i16
6314ba319b5SDimitry Andric // sext2 = sext i16 %ld2 to i32
6324ba319b5SDimitry Andric // ld3 = load i16
6334ba319b5SDimitry Andric // sext3 = sext i16 %ld3 to i32
6344ba319b5SDimitry Andric // mul1 = mul i32 %sext2, %sext3
6354ba319b5SDimitry Andric // add0 = add i32 %mul0, %acc0
6364ba319b5SDimitry Andric // acc1 = add i32 %add0, %mul1
6374ba319b5SDimitry Andric //
6384ba319b5SDimitry Andric // Which can be selected to:
6394ba319b5SDimitry Andric //
// ldr r0
// ldr r1
6424ba319b5SDimitry Andric // smlad r2, r0, r1, r2
6434ba319b5SDimitry Andric //
6444ba319b5SDimitry Andric // If constants are used instead of loads, these will need to be hoisted
6454ba319b5SDimitry Andric // out and into a register.
6464ba319b5SDimitry Andric //
6474ba319b5SDimitry Andric // If loop invariants are used instead of loads, these need to be packed
6484ba319b5SDimitry Andric // before the loop begins.
6494ba319b5SDimitry Andric //
MatchSMLAD(Function & F)6504ba319b5SDimitry Andric bool ARMParallelDSP::MatchSMLAD(Function &F) {
6514ba319b5SDimitry Andric   BasicBlock *Header = L->getHeader();
6524ba319b5SDimitry Andric   LLVM_DEBUG(dbgs() << "= Matching SMLAD =\n";
6534ba319b5SDimitry Andric              dbgs() << "Header block:\n"; Header->dump();
6544ba319b5SDimitry Andric              dbgs() << "Loop info:\n\n"; L->dump());
6554ba319b5SDimitry Andric 
6564ba319b5SDimitry Andric   bool Changed = false;
6574ba319b5SDimitry Andric   ReductionList Reductions;
6584ba319b5SDimitry Andric   MatchReductions(F, L, Header, Reductions);
6594ba319b5SDimitry Andric 
6604ba319b5SDimitry Andric   for (auto &R : Reductions) {
6614ba319b5SDimitry Andric     OpChainList MACCandidates;
6624ba319b5SDimitry Andric     MatchParallelMACSequences(R, MACCandidates);
6634ba319b5SDimitry Andric     if (!CheckMACMemory(MACCandidates))
6644ba319b5SDimitry Andric       continue;
6654ba319b5SDimitry Andric 
6664ba319b5SDimitry Andric     R.MACCandidates = std::move(MACCandidates);
6674ba319b5SDimitry Andric 
6684ba319b5SDimitry Andric     LLVM_DEBUG(dbgs() << "MAC candidates:\n";
6694ba319b5SDimitry Andric       for (auto &M : R.MACCandidates)
6704ba319b5SDimitry Andric         M->Root->dump();
6714ba319b5SDimitry Andric       dbgs() << "\n";);
6724ba319b5SDimitry Andric   }
6734ba319b5SDimitry Andric 
6744ba319b5SDimitry Andric   // Collect all instructions that may read or write memory. Our alias
6754ba319b5SDimitry Andric   // analysis checks bail out if any of these instructions aliases with an
6764ba319b5SDimitry Andric   // instruction from the MAC-chain.
6774ba319b5SDimitry Andric   Instructions Reads, Writes;
6784ba319b5SDimitry Andric   AliasCandidates(Header, Reads, Writes);
6794ba319b5SDimitry Andric 
6804ba319b5SDimitry Andric   for (auto &R : Reductions) {
6814ba319b5SDimitry Andric     if (AreAliased(AA, Reads, Writes, R.MACCandidates))
6824ba319b5SDimitry Andric       return false;
683*b5893f02SDimitry Andric     CreateParallelMACPairs(R);
684*b5893f02SDimitry Andric     Changed |= InsertParallelMACs(R);
6854ba319b5SDimitry Andric   }
6864ba319b5SDimitry Andric 
6874ba319b5SDimitry Andric   LLVM_DEBUG(if (Changed) dbgs() << "Header block:\n"; Header->dump(););
6884ba319b5SDimitry Andric   return Changed;
6894ba319b5SDimitry Andric }
6904ba319b5SDimitry Andric 
CreateLoadIns(IRBuilder<NoFolder> & IRB,LoadInst & BaseLoad,const Type * LoadTy)691*b5893f02SDimitry Andric static LoadInst *CreateLoadIns(IRBuilder<NoFolder> &IRB, LoadInst &BaseLoad,
692*b5893f02SDimitry Andric                                const Type *LoadTy) {
693*b5893f02SDimitry Andric   const unsigned AddrSpace = BaseLoad.getPointerAddressSpace();
6944ba319b5SDimitry Andric 
695*b5893f02SDimitry Andric   Value *VecPtr = IRB.CreateBitCast(BaseLoad.getPointerOperand(),
696*b5893f02SDimitry Andric                                     LoadTy->getPointerTo(AddrSpace));
697*b5893f02SDimitry Andric   return IRB.CreateAlignedLoad(VecPtr, BaseLoad.getAlignment());
6984ba319b5SDimitry Andric }
6994ba319b5SDimitry Andric 
CreateSMLADCall(LoadInst * VecLd0,LoadInst * VecLd1,Instruction * Acc,bool Exchange,Instruction * InsertAfter)7004ba319b5SDimitry Andric Instruction *ARMParallelDSP::CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
701*b5893f02SDimitry Andric                                              Instruction *Acc, bool Exchange,
7024ba319b5SDimitry Andric                                              Instruction *InsertAfter) {
703*b5893f02SDimitry Andric   LLVM_DEBUG(dbgs() << "Create SMLAD intrinsic using:\n"
704*b5893f02SDimitry Andric              << "- " << *VecLd0 << "\n"
705*b5893f02SDimitry Andric              << "- " << *VecLd1 << "\n"
706*b5893f02SDimitry Andric              << "- " << *Acc << "\n"
707*b5893f02SDimitry Andric              << "Exchange: " << Exchange << "\n");
7084ba319b5SDimitry Andric 
7094ba319b5SDimitry Andric   IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
7104ba319b5SDimitry Andric                               ++BasicBlock::iterator(InsertAfter));
7114ba319b5SDimitry Andric 
7124ba319b5SDimitry Andric   // Replace the reduction chain with an intrinsic call
713*b5893f02SDimitry Andric   const Type *Ty = IntegerType::get(M->getContext(), 32);
714*b5893f02SDimitry Andric   LoadInst *NewLd0 = CreateLoadIns(Builder, VecLd0[0], Ty);
715*b5893f02SDimitry Andric   LoadInst *NewLd1 = CreateLoadIns(Builder, VecLd1[0], Ty);
716*b5893f02SDimitry Andric   Value* Args[] = { NewLd0, NewLd1, Acc };
717*b5893f02SDimitry Andric   Function *SMLAD = nullptr;
718*b5893f02SDimitry Andric   if (Exchange)
719*b5893f02SDimitry Andric     SMLAD = Acc->getType()->isIntegerTy(32) ?
720*b5893f02SDimitry Andric       Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
721*b5893f02SDimitry Andric       Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
722*b5893f02SDimitry Andric   else
723*b5893f02SDimitry Andric     SMLAD = Acc->getType()->isIntegerTy(32) ?
724*b5893f02SDimitry Andric       Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
725*b5893f02SDimitry Andric       Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
7264ba319b5SDimitry Andric   CallInst *Call = Builder.CreateCall(SMLAD, Args);
7274ba319b5SDimitry Andric   NumSMLAD++;
7284ba319b5SDimitry Andric   return Call;
7294ba319b5SDimitry Andric }
7304ba319b5SDimitry Andric 
731*b5893f02SDimitry Andric // Compare the value lists in Other to this chain.
AreSymmetrical(BinOpChain * Other)732*b5893f02SDimitry Andric bool BinOpChain::AreSymmetrical(BinOpChain *Other) {
733*b5893f02SDimitry Andric   // Element-by-element comparison of Value lists returning true if they are
734*b5893f02SDimitry Andric   // instructions with the same opcode or constants with the same value.
735*b5893f02SDimitry Andric   auto CompareValueList = [](const ValueList &VL0,
736*b5893f02SDimitry Andric                              const ValueList &VL1) {
737*b5893f02SDimitry Andric     if (VL0.size() != VL1.size()) {
738*b5893f02SDimitry Andric       LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
739*b5893f02SDimitry Andric                         << VL0.size() << " != " << VL1.size() << "\n");
740*b5893f02SDimitry Andric       return false;
741*b5893f02SDimitry Andric     }
742*b5893f02SDimitry Andric 
743*b5893f02SDimitry Andric     const unsigned Pairs = VL0.size();
744*b5893f02SDimitry Andric     LLVM_DEBUG(dbgs() << "Number of operand pairs: " << Pairs << "\n");
745*b5893f02SDimitry Andric 
746*b5893f02SDimitry Andric     for (unsigned i = 0; i < Pairs; ++i) {
747*b5893f02SDimitry Andric       const Value *V0 = VL0[i];
748*b5893f02SDimitry Andric       const Value *V1 = VL1[i];
749*b5893f02SDimitry Andric       const auto *Inst0 = dyn_cast<Instruction>(V0);
750*b5893f02SDimitry Andric       const auto *Inst1 = dyn_cast<Instruction>(V1);
751*b5893f02SDimitry Andric 
752*b5893f02SDimitry Andric       LLVM_DEBUG(dbgs() << "Pair " << i << ":\n";
753*b5893f02SDimitry Andric                 dbgs() << "mul1: "; V0->dump();
754*b5893f02SDimitry Andric                 dbgs() << "mul2: "; V1->dump());
755*b5893f02SDimitry Andric 
756*b5893f02SDimitry Andric       if (!Inst0 || !Inst1)
757*b5893f02SDimitry Andric         return false;
758*b5893f02SDimitry Andric 
759*b5893f02SDimitry Andric       if (Inst0->isSameOperationAs(Inst1)) {
760*b5893f02SDimitry Andric         LLVM_DEBUG(dbgs() << "OK: same operation found!\n");
761*b5893f02SDimitry Andric         continue;
762*b5893f02SDimitry Andric       }
763*b5893f02SDimitry Andric 
764*b5893f02SDimitry Andric       const APInt *C0, *C1;
765*b5893f02SDimitry Andric       if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) && C0 == C1))
766*b5893f02SDimitry Andric         return false;
767*b5893f02SDimitry Andric     }
768*b5893f02SDimitry Andric 
769*b5893f02SDimitry Andric     LLVM_DEBUG(dbgs() << "OK: found symmetrical operand lists.\n");
770*b5893f02SDimitry Andric     return true;
771*b5893f02SDimitry Andric   };
772*b5893f02SDimitry Andric 
773*b5893f02SDimitry Andric   return CompareValueList(LHS, Other->LHS) &&
774*b5893f02SDimitry Andric          CompareValueList(RHS, Other->RHS);
775*b5893f02SDimitry Andric }
776*b5893f02SDimitry Andric 
createARMParallelDSPPass()7774ba319b5SDimitry Andric Pass *llvm::createARMParallelDSPPass() {
7784ba319b5SDimitry Andric   return new ARMParallelDSP();
7794ba319b5SDimitry Andric }
7804ba319b5SDimitry Andric 
7814ba319b5SDimitry Andric char ARMParallelDSP::ID = 0;
7824ba319b5SDimitry Andric 
7834ba319b5SDimitry Andric INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
7844ba319b5SDimitry Andric                 "Transform loops to use DSP intrinsics", false, false)
7854ba319b5SDimitry Andric INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
7864ba319b5SDimitry Andric                 "Transform loops to use DSP intrinsics", false, false)
787