1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
12 /// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
13 ///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
18 //===----------------------------------------------------------------------===//
19 
20 #include "ARM.h"
21 #include "ARMBaseInstrInfo.h"
22 #include "ARMBaseRegisterInfo.h"
23 #include "ARMBasicBlockInfo.h"
24 #include "ARMSubtarget.h"
25 #include "llvm/CodeGen/MachineFunctionPass.h"
26 #include "llvm/CodeGen/MachineLoopInfo.h"
27 #include "llvm/CodeGen/MachineLoopUtils.h"
28 #include "llvm/CodeGen/MachineRegisterInfo.h"
29 #include "llvm/CodeGen/Passes.h"
30 #include "llvm/CodeGen/ReachingDefAnalysis.h"
31 #include "llvm/MC/MCInstrDesc.h"
32 
33 using namespace llvm;
34 
35 #define DEBUG_TYPE "arm-low-overhead-loops"
36 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
37 
38 namespace {
39 
40   struct LowOverheadLoop {
41 
42     MachineLoop *ML = nullptr;
43     MachineFunction *MF = nullptr;
44     MachineInstr *InsertPt = nullptr;
45     MachineInstr *Start = nullptr;
46     MachineInstr *Dec = nullptr;
47     MachineInstr *End = nullptr;
48     MachineInstr *VCTP = nullptr;
49     SmallVector<MachineInstr*, 4> VPTUsers;
50     bool Revert = false;
51     bool FoundOneVCTP = false;
52     bool CannotTailPredicate = false;
53 
54     LowOverheadLoop(MachineLoop *ML) : ML(ML) {
55       MF = ML->getHeader()->getParent();
56     }
57 
58     // For now, only support one vctp instruction. If we find multiple then
59     // we shouldn't perform tail predication.
60     void addVCTP(MachineInstr *MI) {
61       if (!VCTP) {
62         VCTP = MI;
63         FoundOneVCTP = true;
64       } else
65         FoundOneVCTP = false;
66     }
67 
68     // Check that nothing else is writing to VPR and record any insts
69     // reading the VPR.
70     void ScanForVPR(MachineInstr *MI) {
71       for (auto &MO : MI->operands()) {
72         if (!MO.isReg() || MO.getReg() != ARM::VPR)
73           continue;
74         if (MO.isUse())
75           VPTUsers.push_back(MI);
76         if (MO.isDef()) {
77           CannotTailPredicate = true;
78           break;
79         }
80       }
81     }
82 
83     // If this is an MVE instruction, check that we know how to use tail
84     // predication with it.
85     void CheckTPValidity(MachineInstr *MI) {
86       if (CannotTailPredicate)
87         return;
88 
89       const MCInstrDesc &MCID = MI->getDesc();
90       uint64_t Flags = MCID.TSFlags;
91       if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
92         return;
93 
94       if ((Flags & ARMII::ValidForTailPredication) == 0) {
95         LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
96         CannotTailPredicate = true;
97       }
98     }
99 
100     bool IsTailPredicationLegal() const {
101       // For now, let's keep things really simple and only support a single
102       // block for tail predication.
103       return !Revert && FoundAllComponents() && FoundOneVCTP &&
104              !CannotTailPredicate && ML->getNumBlocks() == 1;
105     }
106 
107     // Is it safe to define LR with DLS/WLS?
108     // LR can be defined if it is the operand to start, because it's the same
109     // value, or if it's going to be equivalent to the operand to Start.
110     MachineInstr *IsSafeToDefineLR(ReachingDefAnalysis *RDA);
111 
112     // Check the branch targets are within range and we satisfy our
113     // restrictions.
114     void CheckLegality(ARMBasicBlockUtils *BBUtils, ReachingDefAnalysis *RDA,
115                        MachineLoopInfo *MLI);
116 
117     bool FoundAllComponents() const {
118       return Start && Dec && End;
119     }
120 
121     // Return the loop iteration count, or the number of elements if we're tail
122     // predicating.
123     MachineOperand &getCount() {
124       return IsTailPredicationLegal() ?
125         VCTP->getOperand(1) : Start->getOperand(0);
126     }
127 
128     unsigned getStartOpcode() const {
129       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
130       if (!IsTailPredicationLegal())
131         return IsDo ? ARM::t2DLS : ARM::t2WLS;
132 
133       switch (VCTP->getOpcode()) {
134       default:
135         llvm_unreachable("unhandled vctp opcode");
136         break;
137       case ARM::MVE_VCTP8:
138         return IsDo ? ARM::MVE_DLSTP_8 : ARM::MVE_WLSTP_8;
139       case ARM::MVE_VCTP16:
140         return IsDo ? ARM::MVE_DLSTP_16 : ARM::MVE_WLSTP_16;
141       case ARM::MVE_VCTP32:
142         return IsDo ? ARM::MVE_DLSTP_32 : ARM::MVE_WLSTP_32;
143       case ARM::MVE_VCTP64:
144         return IsDo ? ARM::MVE_DLSTP_64 : ARM::MVE_WLSTP_64;
145       }
146       return 0;
147     }
148 
149     void dump() const {
150       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
151       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
152       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
153       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
154       if (!FoundAllComponents())
155         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
156       else if (!(Start && Dec && End))
157         dbgs() << "ARM Loops: Failed to find all loop components.\n";
158     }
159   };
160 
161   class ARMLowOverheadLoops : public MachineFunctionPass {
162     MachineFunction           *MF = nullptr;
163     MachineLoopInfo           *MLI = nullptr;
164     ReachingDefAnalysis       *RDA = nullptr;
165     const ARMBaseInstrInfo    *TII = nullptr;
166     MachineRegisterInfo       *MRI = nullptr;
167     const TargetRegisterInfo  *TRI = nullptr;
168     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
169 
170   public:
171     static char ID;
172 
173     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
174 
175     void getAnalysisUsage(AnalysisUsage &AU) const override {
176       AU.setPreservesCFG();
177       AU.addRequired<MachineLoopInfo>();
178       AU.addRequired<ReachingDefAnalysis>();
179       MachineFunctionPass::getAnalysisUsage(AU);
180     }
181 
182     bool runOnMachineFunction(MachineFunction &MF) override;
183 
184     MachineFunctionProperties getRequiredProperties() const override {
185       return MachineFunctionProperties().set(
186           MachineFunctionProperties::Property::NoVRegs).set(
187           MachineFunctionProperties::Property::TracksLiveness);
188     }
189 
190     StringRef getPassName() const override {
191       return ARM_LOW_OVERHEAD_LOOPS_NAME;
192     }
193 
194   private:
195     bool ProcessLoop(MachineLoop *ML);
196 
197     bool RevertNonLoops();
198 
199     void RevertWhile(MachineInstr *MI) const;
200 
201     bool RevertLoopDec(MachineInstr *MI, bool AllowFlags = false) const;
202 
203     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
204 
205     void RemoveLoopUpdate(LowOverheadLoop &LoLoop);
206 
207     void RemoveVPTBlocks(LowOverheadLoop &LoLoop);
208 
209     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
210 
211     void Expand(LowOverheadLoop &LoLoop);
212 
213   };
214 }
215 
// Storage for the pass identifier declared in ARMLowOverheadLoops; it is the
// value handed to the MachineFunctionPass constructor.
char ARMLowOverheadLoops::ID = 0;

// Register the pass under DEBUG_TYPE with ARM_LOW_OVERHEAD_LOOPS_NAME as its
// description.
// NOTE(review): the two trailing `false` arguments are presumably the
// "CFG only" and "is analysis" flags - confirm against the INITIALIZE_PASS
// macro definition.
INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
                false, false)
220 
221 static bool IsLoopStart(MachineInstr &MI) {
222   return MI.getOpcode() == ARM::t2DoLoopStart ||
223          MI.getOpcode() == ARM::t2WhileLoopStart;
224 }
225 
226 static bool IsVCTP(MachineInstr *MI) {
227   switch (MI->getOpcode()) {
228   default:
229     break;
230   case ARM::MVE_VCTP8:
231   case ARM::MVE_VCTP16:
232   case ARM::MVE_VCTP32:
233   case ARM::MVE_VCTP64:
234     return true;
235   }
236   return false;
237 }
238 
239 MachineInstr *LowOverheadLoop::IsSafeToDefineLR(ReachingDefAnalysis *RDA) {
240   // We can define LR because LR already contains the same value.
241   if (Start->getOperand(0).getReg() == ARM::LR)
242     return Start;
243 
244   unsigned CountReg = Start->getOperand(0).getReg();
245   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
246     return MI->getOpcode() == ARM::tMOVr &&
247            MI->getOperand(0).getReg() == ARM::LR &&
248            MI->getOperand(1).getReg() == CountReg &&
249            MI->getOperand(2).getImm() == ARMCC::AL;
250    };
251 
252   MachineBasicBlock *MBB = Start->getParent();
253 
254   // Find an insertion point:
255   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
256   //   to Count before Start, we can insert at that mov.
257   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
258   //   to Count after Start, we can insert at that mov.
259   if (auto *LRDef = RDA->getReachingMIDef(&MBB->back(), ARM::LR)) {
260     if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg))
261       return LRDef;
262   }
263 
264   // We've found no suitable LR def and Start doesn't use LR directly. Can we
265   // just define LR anyway?
266   if (!RDA->isRegUsedAfter(Start, ARM::LR))
267     return Start;
268 
269   return nullptr;
270 }
271 
272 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils,
273                                     ReachingDefAnalysis *RDA,
274                                     MachineLoopInfo *MLI) {
275   if (Revert)
276     return;
277 
278   if (!End->getOperand(1).isMBB())
279     report_fatal_error("Expected LoopEnd to target basic block");
280 
281   // TODO Maybe there's cases where the target doesn't have to be the header,
282   // but for now be safe and revert.
283   if (End->getOperand(1).getMBB() != ML->getHeader()) {
284     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
285     Revert = true;
286     return;
287   }
288 
289   // The WLS and LE instructions have 12-bits for the label offset. WLS
290   // requires a positive offset, while LE uses negative.
291   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML->getHeader()) ||
292       !BBUtils->isBBInRange(End, ML->getHeader(), 4094)) {
293     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
294     Revert = true;
295     return;
296   }
297 
298   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
299       (BBUtils->getOffsetOf(Start) >
300        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
301        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
302     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
303     Revert = true;
304     return;
305   }
306 
307   InsertPt = Revert ? nullptr : IsSafeToDefineLR(RDA);
308   if (!InsertPt) {
309     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
310     Revert = true;
311     return;
312   } else
313     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
314 
315   // For tail predication, we need to provide the number of elements, instead
316   // of the iteration count, to the loop start instruction. The number of
317   // elements is provided to the vctp instruction, so we need to check that
318   // we can use this register at InsertPt.
319   if (!IsTailPredicationLegal())
320     return;
321 
322   Register NumElements = VCTP->getOperand(1).getReg();
323 
324   // If the register is defined within loop, then we can't perform TP.
325   // TODO: Check whether this is just a mov of a register that would be
326   // available.
327   if (RDA->getReachingDef(VCTP, NumElements) >= 0) {
328     CannotTailPredicate = true;
329     return;
330   }
331 
332   // We can't perform TP if the register does not hold the same value at
333   // InsertPt as the liveout value.
334   MachineBasicBlock *InsertBB = InsertPt->getParent();
335   if  (!RDA->hasSameReachingDef(InsertPt, &InsertBB->back(),
336                                 NumElements)) {
337     CannotTailPredicate = true;
338     return;
339   }
340 
341   // Especially in the case of while loops, InsertBB may not be the
342   // preheader, so we need to check that the register isn't redefined
343   // before entering the loop.
344   auto CannotProvideElements = [&RDA](MachineBasicBlock *MBB,
345                                       Register NumElements) {
346     // NumElements is redefined in this block.
347     if (RDA->getReachingDef(&MBB->back(), NumElements) >= 0)
348       return true;
349 
350     // Don't continue searching up through multiple predecessors.
351     if (MBB->pred_size() > 1)
352       return true;
353 
354     return false;
355   };
356 
357   // First, find the block that looks like the preheader.
358   MachineBasicBlock *MBB = MLI->findLoopPreheader(ML, true);
359   if (!MBB) {
360     CannotTailPredicate = true;
361     return;
362   }
363 
364   // Then search backwards for a def, until we get to InsertBB.
365   while (MBB != InsertBB) {
366     CannotTailPredicate = CannotProvideElements(MBB, NumElements);
367     if (CannotTailPredicate)
368       return;
369     MBB = *MBB->pred_begin();
370   }
371 
372   LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication to convert:\n";
373                for (auto *MI : VPTUsers)
374                  dbgs() << " - " << *MI;);
375 }
376 
377 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
378   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
379   if (!ST.hasLOB())
380     return false;
381 
382   MF = &mf;
383   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
384 
385   MLI = &getAnalysis<MachineLoopInfo>();
386   RDA = &getAnalysis<ReachingDefAnalysis>();
387   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
388   MRI = &MF->getRegInfo();
389   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
390   TRI = ST.getRegisterInfo();
391   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
392   BBUtils->computeAllBlockSizes();
393   BBUtils->adjustBBOffsetsAfter(&MF->front());
394 
395   bool Changed = false;
396   for (auto ML : *MLI) {
397     if (!ML->getParentLoop())
398       Changed |= ProcessLoop(ML);
399   }
400   Changed |= RevertNonLoops();
401   return Changed;
402 }
403 
404 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
405 
406   bool Changed = false;
407 
408   // Process inner loops first.
409   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
410     Changed |= ProcessLoop(*I);
411 
412   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
413              if (auto *Preheader = ML->getLoopPreheader())
414                dbgs() << " - " << Preheader->getName() << "\n";
415              else if (auto *Preheader = MLI->findLoopPreheader(ML))
416                dbgs() << " - " << Preheader->getName() << "\n";
417              for (auto *MBB : ML->getBlocks())
418                dbgs() << " - " << MBB->getName() << "\n";
419             );
420 
421   // Search the given block for a loop start instruction. If one isn't found,
422   // and there's only one predecessor block, search that one too.
423   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
424     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
425     for (auto &MI : *MBB) {
426       if (IsLoopStart(MI))
427         return &MI;
428     }
429     if (MBB->pred_size() == 1)
430       return SearchForStart(*MBB->pred_begin());
431     return nullptr;
432   };
433 
434   LowOverheadLoop LoLoop(ML);
435   // Search the preheader for the start intrinsic.
436   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
437   // with potentially multiple set.loop.iterations, so we need to enable this.
438   if (auto *Preheader = ML->getLoopPreheader())
439     LoLoop.Start = SearchForStart(Preheader);
440   else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
441     LoLoop.Start = SearchForStart(Preheader);
442   else
443     return false;
444 
445   // Find the low-overhead loop components and decide whether or not to fall
446   // back to a normal loop. Also look for a vctp instructions and decide
447   // whether we can convert that predicate using tail predication.
448   for (auto *MBB : reverse(ML->getBlocks())) {
449     for (auto &MI : *MBB) {
450       if (MI.getOpcode() == ARM::t2LoopDec)
451         LoLoop.Dec = &MI;
452       else if (MI.getOpcode() == ARM::t2LoopEnd)
453         LoLoop.End = &MI;
454       else if (IsLoopStart(MI))
455         LoLoop.Start = &MI;
456       else if (IsVCTP(&MI))
457         LoLoop.addVCTP(&MI);
458       else if (MI.getDesc().isCall()) {
459         // TODO: Though the call will require LE to execute again, does this
460         // mean we should revert? Always executing LE hopefully should be
461         // faster than performing a sub,cmp,br or even subs,br.
462         LoLoop.Revert = true;
463         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
464       } else {
465         // Once we've found a vctp, record the users of vpr and check there's
466         // no more vpr defs.
467         if (LoLoop.FoundOneVCTP)
468           LoLoop.ScanForVPR(&MI);
469         // Check we know how to tail predicate any mve instructions.
470         LoLoop.CheckTPValidity(&MI);
471       }
472 
473       // We need to ensure that LR is not used or defined inbetween LoopDec and
474       // LoopEnd.
475       if (!LoLoop.Dec || LoLoop.End || LoLoop.Revert)
476         continue;
477 
478       // If we find that LR has been written or read between LoopDec and
479       // LoopEnd, expect that the decremented value is being used else where.
480       // Because this value isn't actually going to be produced until the
481       // latch, by LE, we would need to generate a real sub. The value is also
482       // likely to be copied/reloaded for use of LoopEnd - in which in case
483       // we'd need to perform an add because it gets subtracted again by LE!
484       // The other option is to then generate the other form of LE which doesn't
485       // perform the sub.
486       for (auto &MO : MI.operands()) {
487         if (MI.getOpcode() != ARM::t2LoopDec && MO.isReg() &&
488             MO.getReg() == ARM::LR) {
489           LLVM_DEBUG(dbgs() << "ARM Loops: Found LR Use/Def: " << MI);
490           LoLoop.Revert = true;
491           break;
492         }
493       }
494     }
495   }
496 
497   LLVM_DEBUG(LoLoop.dump());
498   if (!LoLoop.FoundAllComponents())
499     return false;
500 
501   LoLoop.CheckLegality(BBUtils.get(), RDA, MLI);
502   Expand(LoLoop);
503   return true;
504 }
505 
506 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
507 // beq that branches to the exit branch.
508 // TODO: We could also try to generate a cbz if the value in LR is also in
509 // another low register.
510 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
511   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
512   MachineBasicBlock *MBB = MI->getParent();
513   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
514                                     TII->get(ARM::t2CMPri));
515   MIB.add(MI->getOperand(0));
516   MIB.addImm(0);
517   MIB.addImm(ARMCC::AL);
518   MIB.addReg(ARM::NoRegister);
519 
520   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
521   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
522     ARM::tBcc : ARM::t2Bcc;
523 
524   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
525   MIB.add(MI->getOperand(1));   // branch target
526   MIB.addImm(ARMCC::EQ);        // condition code
527   MIB.addReg(ARM::CPSR);
528   MI->eraseFromParent();
529 }
530 
531 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI,
532                                         bool SetFlags) const {
533   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
534   MachineBasicBlock *MBB = MI->getParent();
535 
536   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
537   if (SetFlags &&
538       (RDA->isRegUsedAfter(MI, ARM::CPSR) ||
539        !RDA->hasSameReachingDef(MI, &MBB->back(), ARM::CPSR)))
540       SetFlags = false;
541 
542   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
543                                     TII->get(ARM::t2SUBri));
544   MIB.addDef(ARM::LR);
545   MIB.add(MI->getOperand(1));
546   MIB.add(MI->getOperand(2));
547   MIB.addImm(ARMCC::AL);
548   MIB.addReg(0);
549 
550   if (SetFlags) {
551     MIB.addReg(ARM::CPSR);
552     MIB->getOperand(5).setIsDef(true);
553   } else
554     MIB.addReg(0);
555 
556   MI->eraseFromParent();
557   return SetFlags;
558 }
559 
560 // Generate a subs, or sub and cmp, and a branch instead of an LE.
561 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
562   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
563 
564   MachineBasicBlock *MBB = MI->getParent();
565   // Create cmp
566   if (!SkipCmp) {
567     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
568                                       TII->get(ARM::t2CMPri));
569     MIB.addReg(ARM::LR);
570     MIB.addImm(0);
571     MIB.addImm(ARMCC::AL);
572     MIB.addReg(ARM::NoRegister);
573   }
574 
575   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
576   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
577     ARM::tBcc : ARM::t2Bcc;
578 
579   // Create bne
580   MachineInstrBuilder MIB =
581     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
582   MIB.add(MI->getOperand(1));   // branch target
583   MIB.addImm(ARMCC::NE);        // condition code
584   MIB.addReg(ARM::CPSR);
585   MI->eraseFromParent();
586 }
587 
588 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
589   MachineInstr *InsertPt = LoLoop.InsertPt;
590   MachineInstr *Start = LoLoop.Start;
591   MachineBasicBlock *MBB = InsertPt->getParent();
592   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
593   unsigned Opc = LoLoop.getStartOpcode();
594   MachineOperand &Count = LoLoop.getCount();
595 
596   MachineInstrBuilder MIB =
597     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
598 
599   MIB.addDef(ARM::LR);
600   MIB.add(Count);
601   if (!IsDo)
602     MIB.add(Start->getOperand(1));
603 
604   // When using tail-predication, try to delete the dead code that was used to
605   // calculate the number of loop iterations.
606   if (LoLoop.IsTailPredicationLegal()) {
607     SmallVector<MachineInstr*, 4> Killed;
608     SmallVector<MachineInstr*, 4> Dead;
609     if (auto *Def = RDA->getReachingMIDef(Start,
610                                           Start->getOperand(0).getReg())) {
611       Killed.push_back(Def);
612 
613       while (!Killed.empty()) {
614         MachineInstr *Def = Killed.back();
615         Killed.pop_back();
616         Dead.push_back(Def);
617         for (auto &MO : Def->operands()) {
618           if (!MO.isReg() || !MO.isKill())
619             continue;
620 
621           MachineInstr *Kill = RDA->getReachingMIDef(Def, MO.getReg());
622           if (Kill && RDA->getNumUses(Kill, MO.getReg()) == 1)
623             Killed.push_back(Kill);
624         }
625       }
626       for (auto *MI : Dead)
627         MI->eraseFromParent();
628     }
629   }
630 
631   // If we're inserting at a mov lr, then remove it as it's redundant.
632   if (InsertPt != Start)
633     InsertPt->eraseFromParent();
634   Start->eraseFromParent();
635   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
636   return &*MIB;
637 }
638 
639 // Goal is to optimise and clean-up these loops:
640 //
641 //   vector.body:
642 //     renamable $vpr = MVE_VCTP32 renamable $r3, 0, $noreg
643 //     renamable $r3, dead $cpsr = tSUBi8 killed renamable $r3(tied-def 0), 4
644 //     ..
645 //     $lr = MVE_DLSTP_32 renamable $r3
646 //
647 // The SUB is the old update of the loop iteration count expression, which
648 // is no longer needed. This sub is removed when the element count, which is in
649 // r3 in this example, is defined by an instruction in the loop, and it has
650 // no uses.
651 //
652 void ARMLowOverheadLoops::RemoveLoopUpdate(LowOverheadLoop &LoLoop) {
653   Register ElemCount = LoLoop.VCTP->getOperand(1).getReg();
654   MachineInstr *LastInstrInBlock = &LoLoop.VCTP->getParent()->back();
655 
656   LLVM_DEBUG(dbgs() << "ARM Loops: Trying to remove loop update stmt\n");
657 
658   if (LoLoop.ML->getNumBlocks() != 1) {
659     LLVM_DEBUG(dbgs() << "ARM Loops: single block loop expected\n");
660     return;
661   }
662 
663   LLVM_DEBUG(dbgs() << "ARM Loops: Analyzing MO: ";
664              LoLoop.VCTP->getOperand(1).dump());
665 
666   // Find the definition we are interested in removing, if there is one.
667   MachineInstr *Def = RDA->getReachingMIDef(LastInstrInBlock, ElemCount);
668   if (!Def)
669     return;
670 
671   // Bail if we define CPSR and it is not dead
672   if (!Def->registerDefIsDead(ARM::CPSR, TRI)) {
673     LLVM_DEBUG(dbgs() << "ARM Loops: CPSR is not dead\n");
674     return;
675   }
676 
677   // Bail if elemcount is used in exit blocks, i.e. if it is live-in.
678   if (isRegLiveInExitBlocks(LoLoop.ML, ElemCount)) {
679     LLVM_DEBUG(dbgs() << "ARM Loops: Elemcount is live-out, can't remove stmt\n");
680     return;
681   }
682 
683   // Bail if there are uses after this Def in the block.
684   SmallVector<MachineInstr*, 4> Uses;
685   RDA->getReachingLocalUses(Def, ElemCount, Uses);
686   if (Uses.size()) {
687     LLVM_DEBUG(dbgs() << "ARM Loops: Local uses in block, can't remove stmt\n");
688     return;
689   }
690 
691   Uses.clear();
692   RDA->getAllInstWithUseBefore(Def, ElemCount, Uses);
693 
694   // Remove Def if there are no uses, or if the only use is the VCTP
695   // instruction.
696   if (!Uses.size() || (Uses.size() == 1 && Uses[0] == LoLoop.VCTP)) {
697     LLVM_DEBUG(dbgs() << "ARM Loops: Removing loop update instruction: ";
698                Def->dump());
699     Def->eraseFromParent();
700   }
701 }
702 
703 void ARMLowOverheadLoops::RemoveVPTBlocks(LowOverheadLoop &LoLoop) {
704   LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
705   LoLoop.VCTP->eraseFromParent();
706 
707   for (auto *MI : LoLoop.VPTUsers) {
708     if (MI->getOpcode() == ARM::MVE_VPST) {
709       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *MI);
710       MI->eraseFromParent();
711     } else {
712       unsigned OpNum = MI->getNumOperands() - 1;
713       assert((MI->getOperand(OpNum).isReg() &&
714               MI->getOperand(OpNum).getReg() == ARM::VPR) &&
715              "Expected VPR");
716       assert((MI->getOperand(OpNum-1).isImm() &&
717               MI->getOperand(OpNum-1).getImm() == ARMVCC::Then) &&
718              "Expected Then predicate");
719       MI->getOperand(OpNum-1).setImm(ARMVCC::None);
720       MI->getOperand(OpNum).setReg(0);
721       LLVM_DEBUG(dbgs() << "ARM Loops: Removed predicate from: " << *MI);
722     }
723   }
724 }
725 
726 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
727 
728   // Combine the LoopDec and LoopEnd instructions into LE(TP).
729   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
730     MachineInstr *End = LoLoop.End;
731     MachineBasicBlock *MBB = End->getParent();
732     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
733       ARM::MVE_LETP : ARM::t2LEUpdate;
734     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
735                                       TII->get(Opc));
736     MIB.addDef(ARM::LR);
737     MIB.add(End->getOperand(0));
738     MIB.add(End->getOperand(1));
739     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
740 
741     LoLoop.End->eraseFromParent();
742     LoLoop.Dec->eraseFromParent();
743     return &*MIB;
744   };
745 
746   // TODO: We should be able to automatically remove these branches before we
747   // get here - probably by teaching analyzeBranch about the pseudo
748   // instructions.
749   // If there is an unconditional branch, after I, that just branches to the
750   // next block, remove it.
751   auto RemoveDeadBranch = [](MachineInstr *I) {
752     MachineBasicBlock *BB = I->getParent();
753     MachineInstr *Terminator = &BB->instr_back();
754     if (Terminator->isUnconditionalBranch() && I != Terminator) {
755       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
756       if (BB->isLayoutSuccessor(Succ)) {
757         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
758         Terminator->eraseFromParent();
759       }
760     }
761   };
762 
763   if (LoLoop.Revert) {
764     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
765       RevertWhile(LoLoop.Start);
766     else
767       LoLoop.Start->eraseFromParent();
768     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec, true);
769     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
770   } else {
771     LoLoop.Start = ExpandLoopStart(LoLoop);
772     RemoveDeadBranch(LoLoop.Start);
773     LoLoop.End = ExpandLoopEnd(LoLoop);
774     RemoveDeadBranch(LoLoop.End);
775     if (LoLoop.IsTailPredicationLegal()) {
776       RemoveLoopUpdate(LoLoop);
777       RemoveVPTBlocks(LoLoop);
778     }
779   }
780 }
781 
782 bool ARMLowOverheadLoops::RevertNonLoops() {
783   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
784   bool Changed = false;
785 
786   for (auto &MBB : *MF) {
787     SmallVector<MachineInstr*, 4> Starts;
788     SmallVector<MachineInstr*, 4> Decs;
789     SmallVector<MachineInstr*, 4> Ends;
790 
791     for (auto &I : MBB) {
792       if (IsLoopStart(I))
793         Starts.push_back(&I);
794       else if (I.getOpcode() == ARM::t2LoopDec)
795         Decs.push_back(&I);
796       else if (I.getOpcode() == ARM::t2LoopEnd)
797         Ends.push_back(&I);
798     }
799 
800     if (Starts.empty() && Decs.empty() && Ends.empty())
801       continue;
802 
803     Changed = true;
804 
805     for (auto *Start : Starts) {
806       if (Start->getOpcode() == ARM::t2WhileLoopStart)
807         RevertWhile(Start);
808       else
809         Start->eraseFromParent();
810     }
811     for (auto *Dec : Decs)
812       RevertLoopDec(Dec);
813 
814     for (auto *End : Ends)
815       RevertLoopEnd(End);
816   }
817   return Changed;
818 }
819 
820 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
821   return new ARMLowOverheadLoops();
822 }
823