1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
/// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
/// In addition to this, we also look for the presence of the VCTP instruction,
/// which determines whether we can generate the tail-predicated low-overhead
/// loop form.
21 ///
22 /// Assumptions and Dependencies:
23 /// Low-overhead loops are constructed and executed using a setup instruction:
24 /// DLS, WLS, DLSTP or WLSTP and an instruction that loops back: LE or LETP.
25 /// WLS(TP) and LE(TP) are branching instructions with a (large) limited range
26 /// but fixed polarity: WLS can only branch forwards and LE can only branch
27 /// backwards. These restrictions mean that this pass is dependent upon block
28 /// layout and block sizes, which is why it's the last pass to run. The same is
29 /// true for ConstantIslands, but this pass does not increase the size of the
30 /// basic blocks, nor does it change the CFG. Instructions are mainly removed
31 /// during the transform and pseudo instructions are replaced by real ones. In
32 /// some cases, when we have to revert to a 'normal' loop, we have to introduce
33 /// multiple instructions for a single pseudo (see RevertWhile and
34 /// RevertLoopEnd). To handle this situation, t2WhileLoopStart and t2LoopEnd
35 /// are defined to be as large as this maximum sequence of replacement
36 /// instructions.
37 ///
38 //===----------------------------------------------------------------------===//
39 
40 #include "ARM.h"
41 #include "ARMBaseInstrInfo.h"
42 #include "ARMBaseRegisterInfo.h"
43 #include "ARMBasicBlockInfo.h"
44 #include "ARMSubtarget.h"
45 #include "Thumb2InstrInfo.h"
46 #include "llvm/ADT/SetOperations.h"
47 #include "llvm/ADT/SmallSet.h"
48 #include "llvm/CodeGen/LivePhysRegs.h"
49 #include "llvm/CodeGen/MachineFunctionPass.h"
50 #include "llvm/CodeGen/MachineLoopInfo.h"
51 #include "llvm/CodeGen/MachineLoopUtils.h"
52 #include "llvm/CodeGen/MachineRegisterInfo.h"
53 #include "llvm/CodeGen/Passes.h"
54 #include "llvm/CodeGen/ReachingDefAnalysis.h"
55 #include "llvm/MC/MCInstrDesc.h"
56 
57 using namespace llvm;
58 
59 #define DEBUG_TYPE "arm-low-overhead-loops"
60 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
61 
62 namespace {
63 
64   using InstSet = SmallPtrSetImpl<MachineInstr *>;
65 
66   class PostOrderLoopTraversal {
67     MachineLoop &ML;
68     MachineLoopInfo &MLI;
69     SmallPtrSet<MachineBasicBlock*, 4> Visited;
70     SmallVector<MachineBasicBlock*, 4> Order;
71 
72   public:
73     PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI)
74       : ML(ML), MLI(MLI) { }
75 
76     const SmallVectorImpl<MachineBasicBlock*> &getOrder() const {
77       return Order;
78     }
79 
80     // Visit all the blocks within the loop, as well as exit blocks and any
81     // blocks properly dominating the header.
82     void ProcessLoop() {
83       std::function<void(MachineBasicBlock*)> Search = [this, &Search]
84         (MachineBasicBlock *MBB) -> void {
85         if (Visited.count(MBB))
86           return;
87 
88         Visited.insert(MBB);
89         for (auto *Succ : MBB->successors()) {
90           if (!ML.contains(Succ))
91             continue;
92           Search(Succ);
93         }
94         Order.push_back(MBB);
95       };
96 
97       // Insert exit blocks.
98       SmallVector<MachineBasicBlock*, 2> ExitBlocks;
99       ML.getExitBlocks(ExitBlocks);
100       for (auto *MBB : ExitBlocks)
101         Order.push_back(MBB);
102 
103       // Then add the loop body.
104       Search(ML.getHeader());
105 
106       // Then try the preheader and its predecessors.
107       std::function<void(MachineBasicBlock*)> GetPredecessor =
108         [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void {
109         Order.push_back(MBB);
110         if (MBB->pred_size() == 1)
111           GetPredecessor(*MBB->pred_begin());
112       };
113 
114       if (auto *Preheader = ML.getLoopPreheader())
115         GetPredecessor(Preheader);
116       else if (auto *Preheader = MLI.findLoopPreheader(&ML, true))
117         GetPredecessor(Preheader);
118     }
119   };
120 
121   struct PredicatedMI {
122     MachineInstr *MI = nullptr;
123     SetVector<MachineInstr*> Predicates;
124 
125   public:
126     PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) :
127       MI(I) { Predicates.insert(Preds.begin(), Preds.end()); }
128   };
129 
130   // Represent a VPT block, a list of instructions that begins with a VPST and
131   // has a maximum of four proceeding instructions. All instructions within the
132   // block are predicated upon the vpr and we allow instructions to define the
133   // vpr within in the block too.
134   class VPTBlock {
135     std::unique_ptr<PredicatedMI> VPST;
136     PredicatedMI *Divergent = nullptr;
137     SmallVector<PredicatedMI, 4> Insts;
138 
139   public:
140     VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
141       VPST = std::make_unique<PredicatedMI>(MI, Preds);
142     }
143 
144     void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
145       LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
146       if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) {
147         Divergent = &Insts.back();
148         LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
149       }
150       Insts.emplace_back(MI, Preds);
151       assert(Insts.size() <= 4 && "Too many instructions in VPT block!");
152     }
153 
154     // Have we found an instruction within the block which defines the vpr? If
155     // so, not all the instructions in the block will have the same predicate.
156     bool HasNonUniformPredicate() const {
157       return Divergent != nullptr;
158     }
159 
160     // Is the given instruction part of the predicate set controlling the entry
161     // to the block.
162     bool IsPredicatedOn(MachineInstr *MI) const {
163       return VPST->Predicates.count(MI);
164     }
165 
166     // Is the given instruction the only predicate which controls the entry to
167     // the block.
168     bool IsOnlyPredicatedOn(MachineInstr *MI) const {
169       return IsPredicatedOn(MI) && VPST->Predicates.size() == 1;
170     }
171 
172     unsigned size() const { return Insts.size(); }
173     SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
174     MachineInstr *getVPST() const { return VPST->MI; }
175     PredicatedMI *getDivergent() const { return Divergent; }
176   };
177 
178   struct LowOverheadLoop {
179 
180     MachineLoop &ML;
181     MachineLoopInfo &MLI;
182     ReachingDefAnalysis &RDA;
183     const TargetRegisterInfo &TRI;
184     MachineFunction *MF = nullptr;
185     MachineInstr *InsertPt = nullptr;
186     MachineInstr *Start = nullptr;
187     MachineInstr *Dec = nullptr;
188     MachineInstr *End = nullptr;
189     MachineInstr *VCTP = nullptr;
190     VPTBlock *CurrentBlock = nullptr;
191     SetVector<MachineInstr*> CurrentPredicate;
192     SmallVector<VPTBlock, 4> VPTBlocks;
193     SmallPtrSet<MachineInstr*, 4> ToRemove;
194     bool Revert = false;
195     bool CannotTailPredicate = false;
196 
197     LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI,
198                     ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI)
199       : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI) {
200       MF = ML.getHeader()->getParent();
201     }
202 
203     // If this is an MVE instruction, check that we know how to use tail
204     // predication with it. Record VPT blocks and return whether the
205     // instruction is valid for tail predication.
206     bool ValidateMVEInst(MachineInstr *MI);
207 
208     void AnalyseMVEInst(MachineInstr *MI) {
209       CannotTailPredicate = !ValidateMVEInst(MI);
210     }
211 
212     bool IsTailPredicationLegal() const {
213       // For now, let's keep things really simple and only support a single
214       // block for tail predication.
215       return !Revert && FoundAllComponents() && VCTP &&
216              !CannotTailPredicate && ML.getNumBlocks() == 1;
217     }
218 
219     // Check that the predication in the loop will be equivalent once we
220     // perform the conversion. Also ensure that we can provide the number
221     // of elements to the loop start instruction.
222     bool ValidateTailPredicate(MachineInstr *StartInsertPt);
223 
224     // Check that any values available outside of the loop will be the same
225     // after tail predication conversion.
226     bool ValidateLiveOuts() const;
227 
228     // Is it safe to define LR with DLS/WLS?
229     // LR can be defined if it is the operand to start, because it's the same
230     // value, or if it's going to be equivalent to the operand to Start.
231     MachineInstr *isSafeToDefineLR();
232 
233     // Check the branch targets are within range and we satisfy our
234     // restrictions.
235     void CheckLegality(ARMBasicBlockUtils *BBUtils);
236 
237     bool FoundAllComponents() const {
238       return Start && Dec && End;
239     }
240 
241     SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
242 
243     // Return the loop iteration count, or the number of elements if we're tail
244     // predicating.
245     MachineOperand &getCount() {
246       return IsTailPredicationLegal() ?
247         VCTP->getOperand(1) : Start->getOperand(0);
248     }
249 
250     unsigned getStartOpcode() const {
251       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
252       if (!IsTailPredicationLegal())
253         return IsDo ? ARM::t2DLS : ARM::t2WLS;
254 
255       return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo);
256     }
257 
258     void dump() const {
259       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
260       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
261       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
262       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
263       if (!FoundAllComponents())
264         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
265       else if (!(Start && Dec && End))
266         dbgs() << "ARM Loops: Failed to find all loop components.\n";
267     }
268   };
269 
270   class ARMLowOverheadLoops : public MachineFunctionPass {
271     MachineFunction           *MF = nullptr;
272     MachineLoopInfo           *MLI = nullptr;
273     ReachingDefAnalysis       *RDA = nullptr;
274     const ARMBaseInstrInfo    *TII = nullptr;
275     MachineRegisterInfo       *MRI = nullptr;
276     const TargetRegisterInfo  *TRI = nullptr;
277     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
278 
279   public:
280     static char ID;
281 
282     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
283 
284     void getAnalysisUsage(AnalysisUsage &AU) const override {
285       AU.setPreservesCFG();
286       AU.addRequired<MachineLoopInfo>();
287       AU.addRequired<ReachingDefAnalysis>();
288       MachineFunctionPass::getAnalysisUsage(AU);
289     }
290 
291     bool runOnMachineFunction(MachineFunction &MF) override;
292 
293     MachineFunctionProperties getRequiredProperties() const override {
294       return MachineFunctionProperties().set(
295           MachineFunctionProperties::Property::NoVRegs).set(
296           MachineFunctionProperties::Property::TracksLiveness);
297     }
298 
299     StringRef getPassName() const override {
300       return ARM_LOW_OVERHEAD_LOOPS_NAME;
301     }
302 
303   private:
304     bool ProcessLoop(MachineLoop *ML);
305 
306     bool RevertNonLoops();
307 
308     void RevertWhile(MachineInstr *MI) const;
309 
310     bool RevertLoopDec(MachineInstr *MI) const;
311 
312     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
313 
314     void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
315 
316     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
317 
318     void Expand(LowOverheadLoop &LoLoop);
319 
320     void IterationCountDCE(LowOverheadLoop &LoLoop);
321   };
322 }
323 
324 char ARMLowOverheadLoops::ID = 0;
325 
326 INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
327                 false, false)
328 
329 MachineInstr *LowOverheadLoop::isSafeToDefineLR() {
330   // We can define LR because LR already contains the same value.
331   if (Start->getOperand(0).getReg() == ARM::LR)
332     return Start;
333 
334   unsigned CountReg = Start->getOperand(0).getReg();
335   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
336     return MI->getOpcode() == ARM::tMOVr &&
337            MI->getOperand(0).getReg() == ARM::LR &&
338            MI->getOperand(1).getReg() == CountReg &&
339            MI->getOperand(2).getImm() == ARMCC::AL;
340    };
341 
342   MachineBasicBlock *MBB = Start->getParent();
343 
344   // Find an insertion point:
345   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
346   //   to Count before Start, we can insert at that mov.
347   if (auto *LRDef = RDA.getUniqueReachingMIDef(Start, ARM::LR))
348     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
349       return LRDef;
350 
351   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
352   //   to Count after Start, we can insert at that mov.
353   if (auto *LRDef = RDA.getLocalLiveOutMIDef(MBB, ARM::LR))
354     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
355       return LRDef;
356 
357   // We've found no suitable LR def and Start doesn't use LR directly. Can we
358   // just define LR anyway?
359   return RDA.isSafeToDefRegAt(Start, ARM::LR) ? Start : nullptr;
360 }
361 
362 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {
363   assert(VCTP && "VCTP instruction expected but is not set");
364   // All predication within the loop should be based on vctp. If the block
365   // isn't predicated on entry, check whether the vctp is within the block
366   // and that all other instructions are then predicated on it.
367   for (auto &Block : VPTBlocks) {
368     if (Block.IsPredicatedOn(VCTP))
369       continue;
370     if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI)) {
371       LLVM_DEBUG(dbgs() << "ARM Loops: Found unsupported diverging predicate: "
372                  << *Block.getDivergent()->MI);
373       return false;
374     }
375     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
376     for (auto &PredMI : Insts) {
377       if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI))
378         continue;
379       LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
380                  << " - which is predicated on:\n";
381                  for (auto *MI : PredMI.Predicates)
382                    dbgs() << "   - " << *MI);
383       return false;
384     }
385   }
386 
387   if (!ValidateLiveOuts())
388     return false;
389 
390   // For tail predication, we need to provide the number of elements, instead
391   // of the iteration count, to the loop start instruction. The number of
392   // elements is provided to the vctp instruction, so we need to check that
393   // we can use this register at InsertPt.
394   Register NumElements = VCTP->getOperand(1).getReg();
395 
396   // If the register is defined within loop, then we can't perform TP.
397   // TODO: Check whether this is just a mov of a register that would be
398   // available.
399   if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
400     LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
401     return false;
402   }
403 
404   // The element count register maybe defined after InsertPt, in which case we
405   // need to try to move either InsertPt or the def so that the [w|d]lstp can
406   // use the value.
407   // TODO: On failing to move an instruction, check if the count is provided by
408   // a mov and whether we can use the mov operand directly.
409   MachineBasicBlock *InsertBB = StartInsertPt->getParent();
410   if (!RDA.isReachingDefLiveOut(StartInsertPt, NumElements)) {
411     if (auto *ElemDef = RDA.getLocalLiveOutMIDef(InsertBB, NumElements)) {
412       if (RDA.isSafeToMoveForwards(ElemDef, StartInsertPt)) {
413         ElemDef->removeFromParent();
414         InsertBB->insert(MachineBasicBlock::iterator(StartInsertPt), ElemDef);
415         LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
416                    << *ElemDef);
417       } else if (RDA.isSafeToMoveBackwards(StartInsertPt, ElemDef)) {
418         StartInsertPt->removeFromParent();
419         InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
420                               StartInsertPt);
421         LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
422       } else {
423         LLVM_DEBUG(dbgs() << "ARM Loops: Unable to move element count to loop "
424                    << "start instruction.\n");
425         return false;
426       }
427     }
428   }
429 
430   // Especially in the case of while loops, InsertBB may not be the
431   // preheader, so we need to check that the register isn't redefined
432   // before entering the loop.
433   auto CannotProvideElements = [this](MachineBasicBlock *MBB,
434                                       Register NumElements) {
435     // NumElements is redefined in this block.
436     if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
437       return true;
438 
439     // Don't continue searching up through multiple predecessors.
440     if (MBB->pred_size() > 1)
441       return true;
442 
443     return false;
444   };
445 
446   // First, find the block that looks like the preheader.
447   MachineBasicBlock *MBB = MLI.findLoopPreheader(&ML, true);
448   if (!MBB) {
449     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find preheader.\n");
450     return false;
451   }
452 
453   // Then search backwards for a def, until we get to InsertBB.
454   while (MBB != InsertBB) {
455     if (CannotProvideElements(MBB, NumElements)) {
456       LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n");
457       return false;
458     }
459     MBB = *MBB->pred_begin();
460   }
461 
462   // Check that the value change of the element count is what we expect and
463   // that the predication will be equivalent. For this we need:
464   // NumElements = NumElements - VectorWidth. The sub will be a sub immediate
465   // and we can also allow register copies within the chain too.
466   auto IsValidSub = [](MachineInstr *MI, int ExpectedVecWidth) {
467     return -getAddSubImmediate(*MI) == ExpectedVecWidth;
468   };
469 
470   MBB = VCTP->getParent();
471   if (auto *Def = RDA.getUniqueReachingMIDef(&MBB->back(), NumElements)) {
472     SmallPtrSet<MachineInstr*, 2> ElementChain;
473     SmallPtrSet<MachineInstr*, 2> Ignore = { VCTP };
474     unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode());
475 
476     if (RDA.isSafeToRemove(Def, ElementChain, Ignore)) {
477       bool FoundSub = false;
478 
479       for (auto *MI : ElementChain) {
480         if (isMovRegOpcode(MI->getOpcode()))
481           continue;
482 
483         if (isSubImmOpcode(MI->getOpcode())) {
484           if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth))
485             return false;
486           FoundSub = true;
487         } else
488           return false;
489       }
490 
491       LLVM_DEBUG(dbgs() << "ARM Loops: Will remove element count chain:\n";
492                  for (auto *MI : ElementChain)
493                    dbgs() << " - " << *MI);
494       ToRemove.insert(ElementChain.begin(), ElementChain.end());
495     }
496   }
497   return true;
498 }
499 
500 static bool isVectorPredicated(MachineInstr *MI) {
501   int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
502   return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR;
503 }
504 
505 static bool isRegInClass(const MachineOperand &MO,
506                          const TargetRegisterClass *Class) {
507   return MO.isReg() && MO.getReg() && Class->contains(MO.getReg());
508 }
509 
510 // MVE 'narrowing' operate on half a lane, reading from half and writing
511 // to half, which are referred to has the top and bottom half. The other
512 // half retains its previous value.
513 static bool retainsPreviousHalfElement(const MachineInstr &MI) {
514   const MCInstrDesc &MCID = MI.getDesc();
515   uint64_t Flags = MCID.TSFlags;
516   return (Flags & ARMII::RetainsPreviousHalfElement) != 0;
517 }
518 
519 // Some MVE instructions read from the top/bottom halves of their operand(s)
520 // and generate a vector result with result elements that are double the
521 // width of the input.
522 static bool producesDoubleWidthResult(const MachineInstr &MI) {
523   const MCInstrDesc &MCID = MI.getDesc();
524   uint64_t Flags = MCID.TSFlags;
525   return (Flags & ARMII::DoubleWidthResult) != 0;
526 }
527 
528 static bool isHorizontalReduction(const MachineInstr &MI) {
529   const MCInstrDesc &MCID = MI.getDesc();
530   uint64_t Flags = MCID.TSFlags;
531   return (Flags & ARMII::HorizontalReduction) != 0;
532 }
533 
534 // Can this instruction generate a non-zero result when given only zeroed
535 // operands? This allows us to know that, given operands with false bytes
536 // zeroed by masked loads, that the result will also contain zeros in those
537 // bytes.
538 static bool canGenerateNonZeros(const MachineInstr &MI) {
539 
540   // Check for instructions which can write into a larger element size,
541   // possibly writing into a previous zero'd lane.
542   if (producesDoubleWidthResult(MI))
543     return true;
544 
545   switch (MI.getOpcode()) {
546   default:
547     break;
548   // FIXME: VNEG FP and -0? I think we'll need to handle this once we allow
549   // fp16 -> fp32 vector conversions.
550   // Instructions that perform a NOT will generate 1s from 0s.
551   case ARM::MVE_VMVN:
552   case ARM::MVE_VORN:
553   // Count leading zeros will do just that!
554   case ARM::MVE_VCLZs8:
555   case ARM::MVE_VCLZs16:
556   case ARM::MVE_VCLZs32:
557     return true;
558   }
559   return false;
560 }
561 
562 
563 // Look at its register uses to see if it only can only receive zeros
564 // into its false lanes which would then produce zeros. Also check that
565 // the output register is also defined by an FalseLanesZero instruction
566 // so that if tail-predication happens, the lanes that aren't updated will
567 // still be zeros.
568 static bool producesFalseLanesZero(MachineInstr &MI,
569                                    const TargetRegisterClass *QPRs,
570                                    const ReachingDefAnalysis &RDA,
571                                    InstSet &FalseLanesZero) {
572   if (canGenerateNonZeros(MI))
573     return false;
574 
575   bool AllowScalars = isHorizontalReduction(MI);
576   for (auto &MO : MI.operands()) {
577     if (!MO.isReg() || !MO.getReg())
578       continue;
579     if (!isRegInClass(MO, QPRs) && AllowScalars)
580       continue;
581     if (auto *OpDef = RDA.getMIOperand(&MI, MO))
582       if (FalseLanesZero.count(OpDef))
583        continue;
584     return false;
585   }
586   LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI);
587   return true;
588 }
589 
590 bool LowOverheadLoop::ValidateLiveOuts() const {
591   // We want to find out if the tail-predicated version of this loop will
592   // produce the same values as the loop in its original form. For this to
593   // be true, the newly inserted implicit predication must not change the
594   // the (observable) results.
595   // We're doing this because many instructions in the loop will not be
596   // predicated and so the conversion from VPT predication to tail-predication
597   // can result in different values being produced; due to the tail-predication
598   // preventing many instructions from updating their falsely predicated
599   // lanes. This analysis assumes that all the instructions perform lane-wise
600   // operations and don't perform any exchanges.
601   // A masked load, whether through VPT or tail predication, will write zeros
602   // to any of the falsely predicated bytes. So, from the loads, we know that
603   // the false lanes are zeroed and here we're trying to track that those false
604   // lanes remain zero, or where they change, the differences are masked away
605   // by their user(s).
606   // All MVE loads and stores have to be predicated, so we know that any load
607   // operands, or stored results are equivalent already. Other explicitly
608   // predicated instructions will perform the same operation in the original
609   // loop and the tail-predicated form too. Because of this, we can insert
610   // loads, stores and other predicated instructions into our Predicated
611   // set and build from there.
612   const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID);
613   SetVector<MachineInstr *> FalseLanesUnknown;
614   SmallPtrSet<MachineInstr *, 4> FalseLanesZero;
615   SmallPtrSet<MachineInstr *, 4> Predicated;
616   MachineBasicBlock *MBB = ML.getHeader();
617 
618   for (auto &MI : *MBB) {
619     const MCInstrDesc &MCID = MI.getDesc();
620     uint64_t Flags = MCID.TSFlags;
621     if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
622       continue;
623 
624     if (isVCTP(&MI) || MI.getOpcode() == ARM::MVE_VPST)
625       continue;
626 
627     // Predicated loads will write zeros to the falsely predicated bytes of the
628     // destination register.
629     if (isVectorPredicated(&MI)) {
630       if (MI.mayLoad())
631         FalseLanesZero.insert(&MI);
632       Predicated.insert(&MI);
633       continue;
634     }
635 
636     if (MI.getNumDefs() == 0)
637       continue;
638 
639     if (!producesFalseLanesZero(MI, QPRs, RDA, FalseLanesZero)) {
640       // We require retaining and horizontal operations to operate upon zero'd
641       // false lanes to ensure the conversion doesn't change the output.
642       if (retainsPreviousHalfElement(MI) || isHorizontalReduction(MI))
643         return false;
644       // Otherwise we need to evaluate this instruction later to see whether
645       // unknown false lanes will get masked away by their user(s).
646       FalseLanesUnknown.insert(&MI);
647     } else if (!isHorizontalReduction(MI))
648       FalseLanesZero.insert(&MI);
649   }
650 
651   auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO,
652                               SmallPtrSetImpl<MachineInstr *> &Predicated) {
653     SmallPtrSet<MachineInstr *, 2> Uses;
654     RDA.getGlobalUses(MI, MO.getReg(), Uses);
655     for (auto *Use : Uses) {
656       if (Use != MI && !Predicated.count(Use))
657         return false;
658     }
659     return true;
660   };
661 
662   // Visit the unknowns in reverse so that we can start at the values being
663   // stored and then we can work towards the leaves, hopefully adding more
664   // instructions to Predicated. Successfully terminating the loop means that
665   // all the unknown values have to found to be masked by predicated user(s).
666   for (auto *MI : reverse(FalseLanesUnknown)) {
667     for (auto &MO : MI->operands()) {
668       if (!isRegInClass(MO, QPRs) || !MO.isDef())
669         continue;
670       if (!HasPredicatedUsers(MI, MO, Predicated)) {
671         LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : "
672                           << TRI.getRegAsmName(MO.getReg()) << " at " << *MI);
673         return false;
674       }
675     }
676     // Any unknown false lanes have been masked away by the user(s).
677     Predicated.insert(MI);
678   }
679 
680   // Collect Q-regs that are live in the exit blocks. We don't collect scalars
681   // because they won't be affected by lane predication.
682   SmallSet<Register, 2> LiveOuts;
683   SmallVector<MachineBasicBlock *, 2> ExitBlocks;
684   ML.getExitBlocks(ExitBlocks);
685   for (auto *MBB : ExitBlocks)
686     for (const MachineBasicBlock::RegisterMaskPair &RegMask : MBB->liveins())
687       if (QPRs->contains(RegMask.PhysReg))
688         LiveOuts.insert(RegMask.PhysReg);
689 
690   // Collect the instructions in the loop body that define the live-out values.
691   SmallPtrSet<MachineInstr *, 2> LiveMIs;
692   assert(ML.getNumBlocks() == 1 && "Expected single block loop!");
693   for (auto Reg : LiveOuts)
694     if (auto *MI = RDA.getLocalLiveOutMIDef(MBB, Reg))
695       LiveMIs.insert(MI);
696 
697   LLVM_DEBUG(dbgs() << "ARM Loops: Found loop live-outs:\n";
698              for (auto *MI : LiveMIs)
699                dbgs() << " - " << *MI);
700   // We've already validated that any VPT predication within the loop will be
701   // equivalent when we perform the predication transformation; so we know that
702   // any VPT predicated instruction is predicated upon VCTP. Any live-out
703   // instruction needs to be predicated, so check this here.
704   for (auto *MI : LiveMIs)
705     if (!isVectorPredicated(MI))
706       return false;
707 
708   return true;
709 }
710 
711 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) {
712   if (Revert)
713     return;
714 
715   if (!End->getOperand(1).isMBB())
716     report_fatal_error("Expected LoopEnd to target basic block");
717 
718   // TODO Maybe there's cases where the target doesn't have to be the header,
719   // but for now be safe and revert.
720   if (End->getOperand(1).getMBB() != ML.getHeader()) {
721     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
722     Revert = true;
723     return;
724   }
725 
726   // The WLS and LE instructions have 12-bits for the label offset. WLS
727   // requires a positive offset, while LE uses negative.
728   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML.getHeader()) ||
729       !BBUtils->isBBInRange(End, ML.getHeader(), 4094)) {
730     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
731     Revert = true;
732     return;
733   }
734 
735   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
736       (BBUtils->getOffsetOf(Start) >
737        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
738        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
739     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
740     Revert = true;
741     return;
742   }
743 
744   InsertPt = Revert ? nullptr : isSafeToDefineLR();
745   if (!InsertPt) {
746     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
747     Revert = true;
748     return;
749   } else
750     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
751 
752   if (!IsTailPredicationLegal()) {
753     LLVM_DEBUG(if (!VCTP)
754                  dbgs() << "ARM Loops: Didn't find a VCTP instruction.\n";
755                dbgs() << "ARM Loops: Tail-predication is not valid.\n");
756     return;
757   }
758 
759   assert(ML.getBlocks().size() == 1 &&
760          "Shouldn't be processing a loop with more than one block");
761   CannotTailPredicate = !ValidateTailPredicate(InsertPt);
762   LLVM_DEBUG(if (CannotTailPredicate)
763              dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
764 }
765 
766 bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) {
767   if (CannotTailPredicate)
768     return false;
769 
770   // Only support a single vctp.
771   if (isVCTP(MI) && VCTP)
772     return false;
773 
774   // Start a new vpt block when we discover a vpt.
775   if (MI->getOpcode() == ARM::MVE_VPST) {
776     VPTBlocks.emplace_back(MI, CurrentPredicate);
777     CurrentBlock = &VPTBlocks.back();
778     return true;
779   } else if (isVCTP(MI))
780     VCTP = MI;
781   else if (MI->getOpcode() == ARM::MVE_VPSEL ||
782            MI->getOpcode() == ARM::MVE_VPNOT)
783     return false;
784 
785   // TODO: Allow VPSEL and VPNOT, we currently cannot because:
786   // 1) It will use the VPR as a predicate operand, but doesn't have to be
787   //    instead a VPT block, which means we can assert while building up
788   //    the VPT block because we don't find another VPST to being a new
789   //    one.
790   // 2) VPSEL still requires a VPR operand even after tail predicating,
791   //    which means we can't remove it unless there is another
792   //    instruction, such as vcmp, that can provide the VPR def.
793 
794   bool IsUse = false;
795   bool IsDef = false;
796   const MCInstrDesc &MCID = MI->getDesc();
797   for (int i = MI->getNumOperands() - 1; i >= 0; --i) {
798     const MachineOperand &MO = MI->getOperand(i);
799     if (!MO.isReg() || MO.getReg() != ARM::VPR)
800       continue;
801 
802     if (MO.isDef()) {
803       CurrentPredicate.insert(MI);
804       IsDef = true;
805     } else if (ARM::isVpred(MCID.OpInfo[i].OperandType)) {
806       CurrentBlock->addInst(MI, CurrentPredicate);
807       IsUse = true;
808     } else {
809       LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
810       return false;
811     }
812   }
813 
814   // If we find a vpr def that is not already predicated on the vctp, we've
815   // got disjoint predicates that may not be equivalent when we do the
816   // conversion.
817   if (IsDef && !IsUse && VCTP && !isVCTP(MI)) {
818     LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI);
819     return false;
820   }
821 
822   uint64_t Flags = MCID.TSFlags;
823   if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
824     return true;
825 
826   // If we find an instruction that has been marked as not valid for tail
827   // predication, only allow the instruction if it's contained within a valid
828   // VPT block.
829   if ((Flags & ARMII::ValidForTailPredication) == 0 && !IsUse) {
830     LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
831     return false;
832   }
833 
834   // If the instruction is already explicitly predicated, then the conversion
835   // will be fine, but ensure that all memory operations are predicated.
836   return !IsUse && MI->mayLoadOrStore() ? false : true;
837 }
838 
839 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
840   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
841   if (!ST.hasLOB())
842     return false;
843 
844   MF = &mf;
845   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
846 
847   MLI = &getAnalysis<MachineLoopInfo>();
848   RDA = &getAnalysis<ReachingDefAnalysis>();
849   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
850   MRI = &MF->getRegInfo();
851   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
852   TRI = ST.getRegisterInfo();
853   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
854   BBUtils->computeAllBlockSizes();
855   BBUtils->adjustBBOffsetsAfter(&MF->front());
856 
857   bool Changed = false;
858   for (auto ML : *MLI) {
859     if (!ML->getParentLoop())
860       Changed |= ProcessLoop(ML);
861   }
862   Changed |= RevertNonLoops();
863   return Changed;
864 }
865 
866 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
867 
868   bool Changed = false;
869 
870   // Process inner loops first.
871   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
872     Changed |= ProcessLoop(*I);
873 
874   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
875              if (auto *Preheader = ML->getLoopPreheader())
876                dbgs() << " - " << Preheader->getName() << "\n";
877              else if (auto *Preheader = MLI->findLoopPreheader(ML))
878                dbgs() << " - " << Preheader->getName() << "\n";
879              else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
880                dbgs() << " - " << Preheader->getName() << "\n";
881              for (auto *MBB : ML->getBlocks())
882                dbgs() << " - " << MBB->getName() << "\n";
883             );
884 
885   // Search the given block for a loop start instruction. If one isn't found,
886   // and there's only one predecessor block, search that one too.
887   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
888     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
889     for (auto &MI : *MBB) {
890       if (isLoopStart(MI))
891         return &MI;
892     }
893     if (MBB->pred_size() == 1)
894       return SearchForStart(*MBB->pred_begin());
895     return nullptr;
896   };
897 
898   LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI);
899   // Search the preheader for the start intrinsic.
900   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
901   // with potentially multiple set.loop.iterations, so we need to enable this.
902   if (auto *Preheader = ML->getLoopPreheader())
903     LoLoop.Start = SearchForStart(Preheader);
904   else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
905     LoLoop.Start = SearchForStart(Preheader);
906   else
907     return false;
908 
909   // Find the low-overhead loop components and decide whether or not to fall
910   // back to a normal loop. Also look for a vctp instructions and decide
911   // whether we can convert that predicate using tail predication.
912   for (auto *MBB : reverse(ML->getBlocks())) {
913     for (auto &MI : *MBB) {
914       if (MI.isDebugValue())
915         continue;
916       else if (MI.getOpcode() == ARM::t2LoopDec)
917         LoLoop.Dec = &MI;
918       else if (MI.getOpcode() == ARM::t2LoopEnd)
919         LoLoop.End = &MI;
920       else if (isLoopStart(MI))
921         LoLoop.Start = &MI;
922       else if (MI.getDesc().isCall()) {
923         // TODO: Though the call will require LE to execute again, does this
924         // mean we should revert? Always executing LE hopefully should be
925         // faster than performing a sub,cmp,br or even subs,br.
926         LoLoop.Revert = true;
927         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
928       } else {
929         // Record VPR defs and build up their corresponding vpt blocks.
930         // Check we know how to tail predicate any mve instructions.
931         LoLoop.AnalyseMVEInst(&MI);
932       }
933     }
934   }
935 
936   LLVM_DEBUG(LoLoop.dump());
937   if (!LoLoop.FoundAllComponents()) {
938     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find loop start, update, end\n");
939     return false;
940   }
941 
942   // Check that the only instruction using LoopDec is LoopEnd.
943   // TODO: Check for copy chains that really have no effect.
944   SmallPtrSet<MachineInstr*, 2> Uses;
945   RDA->getReachingLocalUses(LoLoop.Dec, ARM::LR, Uses);
946   if (Uses.size() > 1 || !Uses.count(LoLoop.End)) {
947     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n");
948     LoLoop.Revert = true;
949   }
950   LoLoop.CheckLegality(BBUtils.get());
951   Expand(LoLoop);
952   return true;
953 }
954 
955 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
956 // beq that branches to the exit branch.
957 // TODO: We could also try to generate a cbz if the value in LR is also in
958 // another low register.
959 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
960   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
961   MachineBasicBlock *MBB = MI->getParent();
962   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
963                                     TII->get(ARM::t2CMPri));
964   MIB.add(MI->getOperand(0));
965   MIB.addImm(0);
966   MIB.addImm(ARMCC::AL);
967   MIB.addReg(ARM::NoRegister);
968 
969   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
970   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
971     ARM::tBcc : ARM::t2Bcc;
972 
973   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
974   MIB.add(MI->getOperand(1));   // branch target
975   MIB.addImm(ARMCC::EQ);        // condition code
976   MIB.addReg(ARM::CPSR);
977   MI->eraseFromParent();
978 }
979 
980 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const {
981   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
982   MachineBasicBlock *MBB = MI->getParent();
983   SmallPtrSet<MachineInstr*, 1> Ignore;
984   for (auto I = MachineBasicBlock::iterator(MI), E = MBB->end(); I != E; ++I) {
985     if (I->getOpcode() == ARM::t2LoopEnd) {
986       Ignore.insert(&*I);
987       break;
988     }
989   }
990 
991   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
992   bool SetFlags = RDA->isSafeToDefRegAt(MI, ARM::CPSR, Ignore);
993 
994   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
995                                     TII->get(ARM::t2SUBri));
996   MIB.addDef(ARM::LR);
997   MIB.add(MI->getOperand(1));
998   MIB.add(MI->getOperand(2));
999   MIB.addImm(ARMCC::AL);
1000   MIB.addReg(0);
1001 
1002   if (SetFlags) {
1003     MIB.addReg(ARM::CPSR);
1004     MIB->getOperand(5).setIsDef(true);
1005   } else
1006     MIB.addReg(0);
1007 
1008   MI->eraseFromParent();
1009   return SetFlags;
1010 }
1011 
1012 // Generate a subs, or sub and cmp, and a branch instead of an LE.
1013 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
1014   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
1015 
1016   MachineBasicBlock *MBB = MI->getParent();
1017   // Create cmp
1018   if (!SkipCmp) {
1019     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1020                                       TII->get(ARM::t2CMPri));
1021     MIB.addReg(ARM::LR);
1022     MIB.addImm(0);
1023     MIB.addImm(ARMCC::AL);
1024     MIB.addReg(ARM::NoRegister);
1025   }
1026 
1027   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1028   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1029     ARM::tBcc : ARM::t2Bcc;
1030 
1031   // Create bne
1032   MachineInstrBuilder MIB =
1033     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1034   MIB.add(MI->getOperand(1));   // branch target
1035   MIB.addImm(ARMCC::NE);        // condition code
1036   MIB.addReg(ARM::CPSR);
1037   MI->eraseFromParent();
1038 }
1039 
1040 // Perform dead code elimation on the loop iteration count setup expression.
1041 // If we are tail-predicating, the number of elements to be processed is the
1042 // operand of the VCTP instruction in the vector body, see getCount(), which is
1043 // register $r3 in this example:
1044 //
1045 //   $lr = big-itercount-expression
1046 //   ..
1047 //   t2DoLoopStart renamable $lr
1048 //   vector.body:
1049 //     ..
1050 //     $vpr = MVE_VCTP32 renamable $r3
1051 //     renamable $lr = t2LoopDec killed renamable $lr, 1
1052 //     t2LoopEnd renamable $lr, %vector.body
1053 //     tB %end
1054 //
1055 // What we would like achieve here is to replace the do-loop start pseudo
1056 // instruction t2DoLoopStart with:
1057 //
1058 //    $lr = MVE_DLSTP_32 killed renamable $r3
1059 //
1060 // Thus, $r3 which defines the number of elements, is written to $lr,
1061 // and then we want to delete the whole chain that used to define $lr,
1062 // see the comment below how this chain could look like.
1063 //
1064 void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
1065   if (!LoLoop.IsTailPredicationLegal())
1066     return;
1067 
1068   LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");
1069 
1070   MachineInstr *Def = RDA->getMIOperand(LoLoop.Start, 0);
1071   if (!Def) {
1072     LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
1073     return;
1074   }
1075 
1076   // Collect and remove the users of iteration count.
1077   SmallPtrSet<MachineInstr*, 4> Killed  = { LoLoop.Start, LoLoop.Dec,
1078                                             LoLoop.End, LoLoop.InsertPt };
1079   SmallPtrSet<MachineInstr*, 2> Remove;
1080   if (RDA->isSafeToRemove(Def, Remove, Killed))
1081     LoLoop.ToRemove.insert(Remove.begin(), Remove.end());
1082   else {
1083     LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
1084     return;
1085   }
1086 
1087   // Collect the dead code and the MBBs in which they reside.
1088   RDA->collectKilledOperands(Def, Killed);
1089   SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
1090   for (auto *MI : Killed)
1091     BasicBlocks.insert(MI->getParent());
1092 
1093   // Collect IT blocks in all affected basic blocks.
1094   std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
1095   for (auto *MBB : BasicBlocks) {
1096     for (auto &MI : *MBB) {
1097       if (MI.getOpcode() != ARM::t2IT)
1098         continue;
1099       RDA->getReachingLocalUses(&MI, ARM::ITSTATE, ITBlocks[&MI]);
1100     }
1101   }
1102 
1103   // If we're removing all of the instructions within an IT block, then
1104   // also remove the IT instruction.
1105   SmallPtrSet<MachineInstr*, 2> ModifiedITs;
1106   for (auto *MI : Killed) {
1107     if (MachineOperand *MO = MI->findRegisterUseOperand(ARM::ITSTATE)) {
1108       MachineInstr *IT = RDA->getMIOperand(MI, *MO);
1109       auto &CurrentBlock = ITBlocks[IT];
1110       CurrentBlock.erase(MI);
1111       if (CurrentBlock.empty())
1112         ModifiedITs.erase(IT);
1113       else
1114         ModifiedITs.insert(IT);
1115     }
1116   }
1117 
1118   // Delete the killed instructions only if we don't have any IT blocks that
1119   // need to be modified because we need to fixup the mask.
1120   // TODO: Handle cases where IT blocks are modified.
1121   if (ModifiedITs.empty()) {
1122     LLVM_DEBUG(dbgs() << "ARM Loops: Will remove iteration count:\n";
1123                for (auto *MI : Killed)
1124                  dbgs() << " - " << *MI);
1125     LoLoop.ToRemove.insert(Killed.begin(), Killed.end());
1126   } else
1127     LLVM_DEBUG(dbgs() << "ARM Loops: Would need to modify IT block(s).\n");
1128 }
1129 
1130 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
1131   LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n");
1132   // When using tail-predication, try to delete the dead code that was used to
1133   // calculate the number of loop iterations.
1134   IterationCountDCE(LoLoop);
1135 
1136   MachineInstr *InsertPt = LoLoop.InsertPt;
1137   MachineInstr *Start = LoLoop.Start;
1138   MachineBasicBlock *MBB = InsertPt->getParent();
1139   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
1140   unsigned Opc = LoLoop.getStartOpcode();
1141   MachineOperand &Count = LoLoop.getCount();
1142 
1143   MachineInstrBuilder MIB =
1144     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
1145 
1146   MIB.addDef(ARM::LR);
1147   MIB.add(Count);
1148   if (!IsDo)
1149     MIB.add(Start->getOperand(1));
1150 
1151   // If we're inserting at a mov lr, then remove it as it's redundant.
1152   if (InsertPt != Start)
1153     LoLoop.ToRemove.insert(InsertPt);
1154   LoLoop.ToRemove.insert(Start);
1155   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
1156   return &*MIB;
1157 }
1158 
1159 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
1160   auto RemovePredicate = [](MachineInstr *MI) {
1161     LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
1162     if (int PIdx = llvm::findFirstVPTPredOperandIdx(*MI)) {
1163       assert(MI->getOperand(PIdx).getImm() == ARMVCC::Then &&
1164              "Expected Then predicate!");
1165       MI->getOperand(PIdx).setImm(ARMVCC::None);
1166       MI->getOperand(PIdx+1).setReg(0);
1167     } else
1168       llvm_unreachable("trying to unpredicate a non-predicated instruction");
1169   };
1170 
1171   // There are a few scenarios which we have to fix up:
1172   // 1) A VPT block with is only predicated by the vctp and has no internal vpr
1173   //    defs.
1174   // 2) A VPT block which is only predicated by the vctp but has an internal
1175   //    vpr def.
1176   // 3) A VPT block which is predicated upon the vctp as well as another vpr
1177   //    def.
1178   // 4) A VPT block which is not predicated upon a vctp, but contains it and
1179   //    all instructions within the block are predicated upon in.
1180 
1181   for (auto &Block : LoLoop.getVPTBlocks()) {
1182     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
1183     if (Block.HasNonUniformPredicate()) {
1184       PredicatedMI *Divergent = Block.getDivergent();
1185       if (isVCTP(Divergent->MI)) {
1186         // The vctp will be removed, so the size of the vpt block needs to be
1187         // modified.
1188         uint64_t Size = (uint64_t)getARMVPTBlockMask(Block.size() - 1);
1189         Block.getVPST()->getOperand(0).setImm(Size);
1190         LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n");
1191       } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
1192         // The VPT block has a non-uniform predicate but it's entry is guarded
1193         // only by a vctp, which means we:
1194         // - Need to remove the original vpst.
1195         // - Then need to unpredicate any following instructions, until
1196         //   we come across the divergent vpr def.
1197         // - Insert a new vpst to predicate the instruction(s) that following
1198         //   the divergent vpr def.
1199         // TODO: We could be producing more VPT blocks than necessary and could
1200         // fold the newly created one into a proceeding one.
1201         for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()),
1202              E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
1203           RemovePredicate(&*I);
1204 
1205         unsigned Size = 0;
1206         auto E = MachineBasicBlock::reverse_iterator(Divergent->MI);
1207         auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI);
1208         MachineInstr *InsertAt = nullptr;
1209         while (I != E) {
1210           InsertAt = &*I;
1211           ++Size;
1212           ++I;
1213         }
1214         MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
1215                                           InsertAt->getDebugLoc(),
1216                                           TII->get(ARM::MVE_VPST));
1217         MIB.addImm((uint64_t)getARMVPTBlockMask(Size));
1218         LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
1219         LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
1220         LoLoop.ToRemove.insert(Block.getVPST());
1221       }
1222     } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
1223       // A vpt block which is only predicated upon vctp and has no internal vpr
1224       // defs:
1225       // - Remove vpst.
1226       // - Unpredicate the remaining instructions.
1227       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
1228       LoLoop.ToRemove.insert(Block.getVPST());
1229       for (auto &PredMI : Insts)
1230         RemovePredicate(PredMI.MI);
1231     }
1232   }
1233   LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
1234   LoLoop.ToRemove.insert(LoLoop.VCTP);
1235 }
1236 
1237 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
1238 
1239   // Combine the LoopDec and LoopEnd instructions into LE(TP).
1240   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
1241     MachineInstr *End = LoLoop.End;
1242     MachineBasicBlock *MBB = End->getParent();
1243     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
1244       ARM::MVE_LETP : ARM::t2LEUpdate;
1245     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
1246                                       TII->get(Opc));
1247     MIB.addDef(ARM::LR);
1248     MIB.add(End->getOperand(0));
1249     MIB.add(End->getOperand(1));
1250     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
1251     LoLoop.ToRemove.insert(LoLoop.Dec);
1252     LoLoop.ToRemove.insert(End);
1253     return &*MIB;
1254   };
1255 
1256   // TODO: We should be able to automatically remove these branches before we
1257   // get here - probably by teaching analyzeBranch about the pseudo
1258   // instructions.
1259   // If there is an unconditional branch, after I, that just branches to the
1260   // next block, remove it.
1261   auto RemoveDeadBranch = [](MachineInstr *I) {
1262     MachineBasicBlock *BB = I->getParent();
1263     MachineInstr *Terminator = &BB->instr_back();
1264     if (Terminator->isUnconditionalBranch() && I != Terminator) {
1265       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
1266       if (BB->isLayoutSuccessor(Succ)) {
1267         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
1268         Terminator->eraseFromParent();
1269       }
1270     }
1271   };
1272 
1273   if (LoLoop.Revert) {
1274     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
1275       RevertWhile(LoLoop.Start);
1276     else
1277       LoLoop.Start->eraseFromParent();
1278     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec);
1279     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
1280   } else {
1281     LoLoop.Start = ExpandLoopStart(LoLoop);
1282     RemoveDeadBranch(LoLoop.Start);
1283     LoLoop.End = ExpandLoopEnd(LoLoop);
1284     RemoveDeadBranch(LoLoop.End);
1285     if (LoLoop.IsTailPredicationLegal())
1286       ConvertVPTBlocks(LoLoop);
1287     for (auto *I : LoLoop.ToRemove) {
1288       LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
1289       I->eraseFromParent();
1290     }
1291   }
1292 
1293   PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
1294   DFS.ProcessLoop();
1295   const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
1296   for (auto *MBB : PostOrder) {
1297     recomputeLiveIns(*MBB);
1298     // FIXME: For some reason, the live-in print order is non-deterministic for
1299     // our tests and I can't out why... So just sort them.
1300     MBB->sortUniqueLiveIns();
1301   }
1302 
1303   for (auto *MBB : reverse(PostOrder))
1304     recomputeLivenessFlags(*MBB);
1305 
1306   // We've moved, removed and inserted new instructions, so update RDA.
1307   RDA->reset();
1308 }
1309 
1310 bool ARMLowOverheadLoops::RevertNonLoops() {
1311   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
1312   bool Changed = false;
1313 
1314   for (auto &MBB : *MF) {
1315     SmallVector<MachineInstr*, 4> Starts;
1316     SmallVector<MachineInstr*, 4> Decs;
1317     SmallVector<MachineInstr*, 4> Ends;
1318 
1319     for (auto &I : MBB) {
1320       if (isLoopStart(I))
1321         Starts.push_back(&I);
1322       else if (I.getOpcode() == ARM::t2LoopDec)
1323         Decs.push_back(&I);
1324       else if (I.getOpcode() == ARM::t2LoopEnd)
1325         Ends.push_back(&I);
1326     }
1327 
1328     if (Starts.empty() && Decs.empty() && Ends.empty())
1329       continue;
1330 
1331     Changed = true;
1332 
1333     for (auto *Start : Starts) {
1334       if (Start->getOpcode() == ARM::t2WhileLoopStart)
1335         RevertWhile(Start);
1336       else
1337         Start->eraseFromParent();
1338     }
1339     for (auto *Dec : Decs)
1340       RevertLoopDec(Dec);
1341 
1342     for (auto *End : Ends)
1343       RevertLoopEnd(End);
1344   }
1345   return Changed;
1346 }
1347 
1348 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
1349   return new ARMLowOverheadLoops();
1350 }
1351