1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
/// The expectation is that the loop contains three pseudo instructions:
/// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
/// - t2LoopEnd - the loop latch terminator.
///
/// In addition to this, we also look for the presence of the VCTP instruction,
/// which determines whether we can generate the tail-predicated low-overhead
/// loop form.
21 ///
22 /// Assumptions and Dependencies:
23 /// Low-overhead loops are constructed and executed using a setup instruction:
24 /// DLS, WLS, DLSTP or WLSTP and an instruction that loops back: LE or LETP.
25 /// WLS(TP) and LE(TP) are branching instructions with a (large) limited range
26 /// but fixed polarity: WLS can only branch forwards and LE can only branch
27 /// backwards. These restrictions mean that this pass is dependent upon block
28 /// layout and block sizes, which is why it's the last pass to run. The same is
29 /// true for ConstantIslands, but this pass does not increase the size of the
30 /// basic blocks, nor does it change the CFG. Instructions are mainly removed
31 /// during the transform and pseudo instructions are replaced by real ones. In
32 /// some cases, when we have to revert to a 'normal' loop, we have to introduce
33 /// multiple instructions for a single pseudo (see RevertWhile and
34 /// RevertLoopEnd). To handle this situation, t2WhileLoopStart and t2LoopEnd
35 /// are defined to be as large as this maximum sequence of replacement
36 /// instructions.
37 ///
38 //===----------------------------------------------------------------------===//
39 
40 #include "ARM.h"
41 #include "ARMBaseInstrInfo.h"
42 #include "ARMBaseRegisterInfo.h"
43 #include "ARMBasicBlockInfo.h"
44 #include "ARMSubtarget.h"
45 #include "Thumb2InstrInfo.h"
46 #include "llvm/ADT/SetOperations.h"
47 #include "llvm/ADT/SmallSet.h"
48 #include "llvm/CodeGen/LivePhysRegs.h"
49 #include "llvm/CodeGen/MachineFunctionPass.h"
50 #include "llvm/CodeGen/MachineLoopInfo.h"
51 #include "llvm/CodeGen/MachineLoopUtils.h"
52 #include "llvm/CodeGen/MachineRegisterInfo.h"
53 #include "llvm/CodeGen/Passes.h"
54 #include "llvm/CodeGen/ReachingDefAnalysis.h"
55 #include "llvm/MC/MCInstrDesc.h"
56 
57 using namespace llvm;
58 
59 #define DEBUG_TYPE "arm-low-overhead-loops"
60 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
61 
62 namespace {
63 
64   using InstSet = SmallPtrSetImpl<MachineInstr *>;
65 
66   class PostOrderLoopTraversal {
67     MachineLoop &ML;
68     MachineLoopInfo &MLI;
69     SmallPtrSet<MachineBasicBlock*, 4> Visited;
70     SmallVector<MachineBasicBlock*, 4> Order;
71 
72   public:
73     PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI)
74       : ML(ML), MLI(MLI) { }
75 
76     const SmallVectorImpl<MachineBasicBlock*> &getOrder() const {
77       return Order;
78     }
79 
80     // Visit all the blocks within the loop, as well as exit blocks and any
81     // blocks properly dominating the header.
82     void ProcessLoop() {
83       std::function<void(MachineBasicBlock*)> Search = [this, &Search]
84         (MachineBasicBlock *MBB) -> void {
85         if (Visited.count(MBB))
86           return;
87 
88         Visited.insert(MBB);
89         for (auto *Succ : MBB->successors()) {
90           if (!ML.contains(Succ))
91             continue;
92           Search(Succ);
93         }
94         Order.push_back(MBB);
95       };
96 
97       // Insert exit blocks.
98       SmallVector<MachineBasicBlock*, 2> ExitBlocks;
99       ML.getExitBlocks(ExitBlocks);
100       for (auto *MBB : ExitBlocks)
101         Order.push_back(MBB);
102 
103       // Then add the loop body.
104       Search(ML.getHeader());
105 
106       // Then try the preheader and its predecessors.
107       std::function<void(MachineBasicBlock*)> GetPredecessor =
108         [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void {
109         Order.push_back(MBB);
110         if (MBB->pred_size() == 1)
111           GetPredecessor(*MBB->pred_begin());
112       };
113 
114       if (auto *Preheader = ML.getLoopPreheader())
115         GetPredecessor(Preheader);
116       else if (auto *Preheader = MLI.findLoopPreheader(&ML, true))
117         GetPredecessor(Preheader);
118     }
119   };
120 
121   struct PredicatedMI {
122     MachineInstr *MI = nullptr;
123     SetVector<MachineInstr*> Predicates;
124 
125   public:
126     PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) :
127       MI(I) { Predicates.insert(Preds.begin(), Preds.end()); }
128   };
129 
130   // Represent a VPT block, a list of instructions that begins with a VPST and
131   // has a maximum of four proceeding instructions. All instructions within the
132   // block are predicated upon the vpr and we allow instructions to define the
133   // vpr within in the block too.
134   class VPTBlock {
135     std::unique_ptr<PredicatedMI> VPST;
136     PredicatedMI *Divergent = nullptr;
137     SmallVector<PredicatedMI, 4> Insts;
138 
139   public:
140     VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
141       VPST = std::make_unique<PredicatedMI>(MI, Preds);
142     }
143 
144     void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
145       LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
146       if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) {
147         Divergent = &Insts.back();
148         LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
149       }
150       Insts.emplace_back(MI, Preds);
151       assert(Insts.size() <= 4 && "Too many instructions in VPT block!");
152     }
153 
154     // Have we found an instruction within the block which defines the vpr? If
155     // so, not all the instructions in the block will have the same predicate.
156     bool HasNonUniformPredicate() const {
157       return Divergent != nullptr;
158     }
159 
160     // Is the given instruction part of the predicate set controlling the entry
161     // to the block.
162     bool IsPredicatedOn(MachineInstr *MI) const {
163       return VPST->Predicates.count(MI);
164     }
165 
166     // Is the given instruction the only predicate which controls the entry to
167     // the block.
168     bool IsOnlyPredicatedOn(MachineInstr *MI) const {
169       return IsPredicatedOn(MI) && VPST->Predicates.size() == 1;
170     }
171 
172     unsigned size() const { return Insts.size(); }
173     SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
174     MachineInstr *getVPST() const { return VPST->MI; }
175     PredicatedMI *getDivergent() const { return Divergent; }
176   };
177 
178   struct LowOverheadLoop {
179 
180     MachineLoop &ML;
181     MachineLoopInfo &MLI;
182     ReachingDefAnalysis &RDA;
183     const TargetRegisterInfo &TRI;
184     MachineFunction *MF = nullptr;
185     MachineInstr *InsertPt = nullptr;
186     MachineInstr *Start = nullptr;
187     MachineInstr *Dec = nullptr;
188     MachineInstr *End = nullptr;
189     MachineInstr *VCTP = nullptr;
190     VPTBlock *CurrentBlock = nullptr;
191     SetVector<MachineInstr*> CurrentPredicate;
192     SmallVector<VPTBlock, 4> VPTBlocks;
193     SmallPtrSet<MachineInstr*, 4> ToRemove;
194     bool Revert = false;
195     bool CannotTailPredicate = false;
196 
197     LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI,
198                     ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI)
199       : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI) {
200       MF = ML.getHeader()->getParent();
201     }
202 
203     // If this is an MVE instruction, check that we know how to use tail
204     // predication with it. Record VPT blocks and return whether the
205     // instruction is valid for tail predication.
206     bool ValidateMVEInst(MachineInstr *MI);
207 
208     void AnalyseMVEInst(MachineInstr *MI) {
209       CannotTailPredicate = !ValidateMVEInst(MI);
210     }
211 
212     bool IsTailPredicationLegal() const {
213       // For now, let's keep things really simple and only support a single
214       // block for tail predication.
215       return !Revert && FoundAllComponents() && VCTP &&
216              !CannotTailPredicate && ML.getNumBlocks() == 1;
217     }
218 
219     // Check that the predication in the loop will be equivalent once we
220     // perform the conversion. Also ensure that we can provide the number
221     // of elements to the loop start instruction.
222     bool ValidateTailPredicate(MachineInstr *StartInsertPt);
223 
224     // Check that any values available outside of the loop will be the same
225     // after tail predication conversion.
226     bool ValidateLiveOuts() const;
227 
228     // Is it safe to define LR with DLS/WLS?
229     // LR can be defined if it is the operand to start, because it's the same
230     // value, or if it's going to be equivalent to the operand to Start.
231     MachineInstr *isSafeToDefineLR();
232 
233     // Check the branch targets are within range and we satisfy our
234     // restrictions.
235     void CheckLegality(ARMBasicBlockUtils *BBUtils);
236 
237     bool FoundAllComponents() const {
238       return Start && Dec && End;
239     }
240 
241     SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
242 
243     // Return the loop iteration count, or the number of elements if we're tail
244     // predicating.
245     MachineOperand &getCount() {
246       return IsTailPredicationLegal() ?
247         VCTP->getOperand(1) : Start->getOperand(0);
248     }
249 
250     unsigned getStartOpcode() const {
251       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
252       if (!IsTailPredicationLegal())
253         return IsDo ? ARM::t2DLS : ARM::t2WLS;
254 
255       return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo);
256     }
257 
258     void dump() const {
259       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
260       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
261       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
262       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
263       if (!FoundAllComponents())
264         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
265       else if (!(Start && Dec && End))
266         dbgs() << "ARM Loops: Failed to find all loop components.\n";
267     }
268   };
269 
270   class ARMLowOverheadLoops : public MachineFunctionPass {
271     MachineFunction           *MF = nullptr;
272     MachineLoopInfo           *MLI = nullptr;
273     ReachingDefAnalysis       *RDA = nullptr;
274     const ARMBaseInstrInfo    *TII = nullptr;
275     MachineRegisterInfo       *MRI = nullptr;
276     const TargetRegisterInfo  *TRI = nullptr;
277     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
278 
279   public:
280     static char ID;
281 
282     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
283 
284     void getAnalysisUsage(AnalysisUsage &AU) const override {
285       AU.setPreservesCFG();
286       AU.addRequired<MachineLoopInfo>();
287       AU.addRequired<ReachingDefAnalysis>();
288       MachineFunctionPass::getAnalysisUsage(AU);
289     }
290 
291     bool runOnMachineFunction(MachineFunction &MF) override;
292 
293     MachineFunctionProperties getRequiredProperties() const override {
294       return MachineFunctionProperties().set(
295           MachineFunctionProperties::Property::NoVRegs).set(
296           MachineFunctionProperties::Property::TracksLiveness);
297     }
298 
299     StringRef getPassName() const override {
300       return ARM_LOW_OVERHEAD_LOOPS_NAME;
301     }
302 
303   private:
304     bool ProcessLoop(MachineLoop *ML);
305 
306     bool RevertNonLoops();
307 
308     void RevertWhile(MachineInstr *MI) const;
309 
310     bool RevertLoopDec(MachineInstr *MI) const;
311 
312     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
313 
314     void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
315 
316     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
317 
318     void Expand(LowOverheadLoop &LoLoop);
319 
320     void IterationCountDCE(LowOverheadLoop &LoLoop);
321   };
322 }
323 
324 char ARMLowOverheadLoops::ID = 0;
325 
326 INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
327                 false, false)
328 
329 MachineInstr *LowOverheadLoop::isSafeToDefineLR() {
330   // We can define LR because LR already contains the same value.
331   if (Start->getOperand(0).getReg() == ARM::LR)
332     return Start;
333 
334   unsigned CountReg = Start->getOperand(0).getReg();
335   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
336     return MI->getOpcode() == ARM::tMOVr &&
337            MI->getOperand(0).getReg() == ARM::LR &&
338            MI->getOperand(1).getReg() == CountReg &&
339            MI->getOperand(2).getImm() == ARMCC::AL;
340    };
341 
342   MachineBasicBlock *MBB = Start->getParent();
343 
344   // Find an insertion point:
345   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
346   //   to Count before Start, we can insert at that mov.
347   if (auto *LRDef = RDA.getUniqueReachingMIDef(Start, ARM::LR))
348     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
349       return LRDef;
350 
351   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
352   //   to Count after Start, we can insert at that mov.
353   if (auto *LRDef = RDA.getLocalLiveOutMIDef(MBB, ARM::LR))
354     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
355       return LRDef;
356 
357   // We've found no suitable LR def and Start doesn't use LR directly. Can we
358   // just define LR anyway?
359   return RDA.isSafeToDefRegAt(Start, ARM::LR) ? Start : nullptr;
360 }
361 
362 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {
363   assert(VCTP && "VCTP instruction expected but is not set");
364   // All predication within the loop should be based on vctp. If the block
365   // isn't predicated on entry, check whether the vctp is within the block
366   // and that all other instructions are then predicated on it.
367   for (auto &Block : VPTBlocks) {
368     if (Block.IsPredicatedOn(VCTP))
369       continue;
370     if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI)) {
371       LLVM_DEBUG(dbgs() << "ARM Loops: Found unsupported diverging predicate: "
372                  << *Block.getDivergent()->MI);
373       return false;
374     }
375     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
376     for (auto &PredMI : Insts) {
377       if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI))
378         continue;
379       LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
380                  << " - which is predicated on:\n";
381                  for (auto *MI : PredMI.Predicates)
382                    dbgs() << "   - " << *MI);
383       return false;
384     }
385   }
386 
387   if (!ValidateLiveOuts())
388     return false;
389 
390   // For tail predication, we need to provide the number of elements, instead
391   // of the iteration count, to the loop start instruction. The number of
392   // elements is provided to the vctp instruction, so we need to check that
393   // we can use this register at InsertPt.
394   Register NumElements = VCTP->getOperand(1).getReg();
395 
396   // If the register is defined within loop, then we can't perform TP.
397   // TODO: Check whether this is just a mov of a register that would be
398   // available.
399   if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
400     LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
401     return false;
402   }
403 
404   // The element count register maybe defined after InsertPt, in which case we
405   // need to try to move either InsertPt or the def so that the [w|d]lstp can
406   // use the value.
407   // TODO: On failing to move an instruction, check if the count is provided by
408   // a mov and whether we can use the mov operand directly.
409   MachineBasicBlock *InsertBB = StartInsertPt->getParent();
410   if (!RDA.isReachingDefLiveOut(StartInsertPt, NumElements)) {
411     if (auto *ElemDef = RDA.getLocalLiveOutMIDef(InsertBB, NumElements)) {
412       if (RDA.isSafeToMoveForwards(ElemDef, StartInsertPt)) {
413         ElemDef->removeFromParent();
414         InsertBB->insert(MachineBasicBlock::iterator(StartInsertPt), ElemDef);
415         LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
416                    << *ElemDef);
417       } else if (RDA.isSafeToMoveBackwards(StartInsertPt, ElemDef)) {
418         StartInsertPt->removeFromParent();
419         InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
420                               StartInsertPt);
421         LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
422       } else {
423         LLVM_DEBUG(dbgs() << "ARM Loops: Unable to move element count to loop "
424                    << "start instruction.\n");
425         return false;
426       }
427     }
428   }
429 
430   // Especially in the case of while loops, InsertBB may not be the
431   // preheader, so we need to check that the register isn't redefined
432   // before entering the loop.
433   auto CannotProvideElements = [this](MachineBasicBlock *MBB,
434                                       Register NumElements) {
435     // NumElements is redefined in this block.
436     if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
437       return true;
438 
439     // Don't continue searching up through multiple predecessors.
440     if (MBB->pred_size() > 1)
441       return true;
442 
443     return false;
444   };
445 
446   // First, find the block that looks like the preheader.
447   MachineBasicBlock *MBB = MLI.findLoopPreheader(&ML, true);
448   if (!MBB) {
449     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find preheader.\n");
450     return false;
451   }
452 
453   // Then search backwards for a def, until we get to InsertBB.
454   while (MBB != InsertBB) {
455     if (CannotProvideElements(MBB, NumElements)) {
456       LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n");
457       return false;
458     }
459     MBB = *MBB->pred_begin();
460   }
461 
462   // Check that the value change of the element count is what we expect and
463   // that the predication will be equivalent. For this we need:
464   // NumElements = NumElements - VectorWidth. The sub will be a sub immediate
465   // and we can also allow register copies within the chain too.
466   auto IsValidSub = [](MachineInstr *MI, unsigned ExpectedVecWidth) {
467     unsigned ImmOpIdx = 0;
468     switch (MI->getOpcode()) {
469     default:
470       llvm_unreachable("unhandled sub opcode");
471     case ARM::tSUBi3:
472     case ARM::tSUBi8:
473       ImmOpIdx = 3;
474       break;
475     case ARM::t2SUBri:
476     case ARM::t2SUBri12:
477       ImmOpIdx = 2;
478       break;
479     }
480     return MI->getOperand(ImmOpIdx).getImm() == ExpectedVecWidth;
481   };
482 
483   MBB = VCTP->getParent();
484   if (auto *Def = RDA.getUniqueReachingMIDef(&MBB->back(), NumElements)) {
485     SmallPtrSet<MachineInstr*, 2> ElementChain;
486     SmallPtrSet<MachineInstr*, 2> Ignore = { VCTP };
487     unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode());
488 
489     if (RDA.isSafeToRemove(Def, ElementChain, Ignore)) {
490       bool FoundSub = false;
491 
492       for (auto *MI : ElementChain) {
493         if (isMovRegOpcode(MI->getOpcode()))
494           continue;
495 
496         if (isSubImmOpcode(MI->getOpcode())) {
497           if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth))
498             return false;
499           FoundSub = true;
500         } else
501           return false;
502       }
503 
504       LLVM_DEBUG(dbgs() << "ARM Loops: Will remove element count chain:\n";
505                  for (auto *MI : ElementChain)
506                    dbgs() << " - " << *MI);
507       ToRemove.insert(ElementChain.begin(), ElementChain.end());
508     }
509   }
510   return true;
511 }
512 
513 static bool isVectorPredicated(MachineInstr *MI) {
514   int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
515   return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR;
516 }
517 
518 static bool isRegInClass(const MachineOperand &MO,
519                          const TargetRegisterClass *Class) {
520   return MO.isReg() && MO.getReg() && Class->contains(MO.getReg());
521 }
522 
523 // MVE 'narrowing' operate on half a lane, reading from half and writing
524 // to half, which are referred to has the top and bottom half. The other
525 // half retains its previous value.
526 static bool retainsPreviousHalfElement(const MachineInstr &MI) {
527   const MCInstrDesc &MCID = MI.getDesc();
528   uint64_t Flags = MCID.TSFlags;
529   return (Flags & ARMII::RetainsPreviousHalfElement) != 0;
530 }
531 
532 // Some MVE instructions read from the top/bottom halves of their operand(s)
533 // and generate a vector result with result elements that are double the
534 // width of the input.
535 static bool producesDoubleWidthResult(const MachineInstr &MI) {
536   const MCInstrDesc &MCID = MI.getDesc();
537   uint64_t Flags = MCID.TSFlags;
538   return (Flags & ARMII::DoubleWidthResult) != 0;
539 }
540 
541 static bool isHorizontalReduction(const MachineInstr &MI) {
542   const MCInstrDesc &MCID = MI.getDesc();
543   uint64_t Flags = MCID.TSFlags;
544   return (Flags & ARMII::HorizontalReduction) != 0;
545 }
546 
547 // Can this instruction generate a non-zero result when given only zeroed
548 // operands? This allows us to know that, given operands with false bytes
549 // zeroed by masked loads, that the result will also contain zeros in those
550 // bytes.
551 static bool canGenerateNonZeros(const MachineInstr &MI) {
552 
553   // Check for instructions which can write into a larger element size,
554   // possibly writing into a previous zero'd lane.
555   if (producesDoubleWidthResult(MI))
556     return true;
557 
558   switch (MI.getOpcode()) {
559   default:
560     break;
561   // FIXME: VNEG FP and -0? I think we'll need to handle this once we allow
562   // fp16 -> fp32 vector conversions.
563   // Instructions that perform a NOT will generate 1s from 0s.
564   case ARM::MVE_VMVN:
565   case ARM::MVE_VORN:
566   // Count leading zeros will do just that!
567   case ARM::MVE_VCLZs8:
568   case ARM::MVE_VCLZs16:
569   case ARM::MVE_VCLZs32:
570     return true;
571   }
572   return false;
573 }
574 
575 
576 // Look at its register uses to see if it only can only receive zeros
577 // into its false lanes which would then produce zeros. Also check that
578 // the output register is also defined by an FalseLanesZero instruction
579 // so that if tail-predication happens, the lanes that aren't updated will
580 // still be zeros.
581 static bool producesFalseLanesZero(MachineInstr &MI,
582                                    const TargetRegisterClass *QPRs,
583                                    const ReachingDefAnalysis &RDA,
584                                    InstSet &FalseLanesZero) {
585   if (canGenerateNonZeros(MI))
586     return false;
587 
588   bool AllowScalars = isHorizontalReduction(MI);
589   for (auto &MO : MI.operands()) {
590     if (!MO.isReg() || !MO.getReg())
591       continue;
592     if (!isRegInClass(MO, QPRs) && AllowScalars)
593       continue;
594     if (auto *OpDef = RDA.getMIOperand(&MI, MO))
595       if (FalseLanesZero.count(OpDef))
596        continue;
597     return false;
598   }
599   LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI);
600   return true;
601 }
602 
603 bool LowOverheadLoop::ValidateLiveOuts() const {
604   // We want to find out if the tail-predicated version of this loop will
605   // produce the same values as the loop in its original form. For this to
606   // be true, the newly inserted implicit predication must not change the
607   // the (observable) results.
608   // We're doing this because many instructions in the loop will not be
609   // predicated and so the conversion from VPT predication to tail-predication
610   // can result in different values being produced; due to the tail-predication
611   // preventing many instructions from updating their falsely predicated
612   // lanes. This analysis assumes that all the instructions perform lane-wise
613   // operations and don't perform any exchanges.
614   // A masked load, whether through VPT or tail predication, will write zeros
615   // to any of the falsely predicated bytes. So, from the loads, we know that
616   // the false lanes are zeroed and here we're trying to track that those false
617   // lanes remain zero, or where they change, the differences are masked away
618   // by their user(s).
619   // All MVE loads and stores have to be predicated, so we know that any load
620   // operands, or stored results are equivalent already. Other explicitly
621   // predicated instructions will perform the same operation in the original
622   // loop and the tail-predicated form too. Because of this, we can insert
623   // loads, stores and other predicated instructions into our Predicated
624   // set and build from there.
625   const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID);
626   SetVector<MachineInstr *> FalseLanesUnknown;
627   SmallPtrSet<MachineInstr *, 4> FalseLanesZero;
628   SmallPtrSet<MachineInstr *, 4> Predicated;
629   MachineBasicBlock *MBB = ML.getHeader();
630 
631   for (auto &MI : *MBB) {
632     const MCInstrDesc &MCID = MI.getDesc();
633     uint64_t Flags = MCID.TSFlags;
634     if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
635       continue;
636 
637     if (isVCTP(&MI) || MI.getOpcode() == ARM::MVE_VPST)
638       continue;
639 
640     // Predicated loads will write zeros to the falsely predicated bytes of the
641     // destination register.
642     if (isVectorPredicated(&MI)) {
643       if (MI.mayLoad())
644         FalseLanesZero.insert(&MI);
645       Predicated.insert(&MI);
646       continue;
647     }
648 
649     if (MI.getNumDefs() == 0)
650       continue;
651 
652     if (!producesFalseLanesZero(MI, QPRs, RDA, FalseLanesZero)) {
653       // We require retaining and horizontal operations to operate upon zero'd
654       // false lanes to ensure the conversion doesn't change the output.
655       if (retainsPreviousHalfElement(MI) || isHorizontalReduction(MI))
656         return false;
657       // Otherwise we need to evaluate this instruction later to see whether
658       // unknown false lanes will get masked away by their user(s).
659       FalseLanesUnknown.insert(&MI);
660     } else if (!isHorizontalReduction(MI))
661       FalseLanesZero.insert(&MI);
662   }
663 
664   auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO,
665                               SmallPtrSetImpl<MachineInstr *> &Predicated) {
666     SmallPtrSet<MachineInstr *, 2> Uses;
667     RDA.getGlobalUses(MI, MO.getReg(), Uses);
668     for (auto *Use : Uses) {
669       if (Use != MI && !Predicated.count(Use))
670         return false;
671     }
672     return true;
673   };
674 
675   // Visit the unknowns in reverse so that we can start at the values being
676   // stored and then we can work towards the leaves, hopefully adding more
677   // instructions to Predicated. Successfully terminating the loop means that
678   // all the unknown values have to found to be masked by predicated user(s).
679   for (auto *MI : reverse(FalseLanesUnknown)) {
680     for (auto &MO : MI->operands()) {
681       if (!isRegInClass(MO, QPRs) || !MO.isDef())
682         continue;
683       if (!HasPredicatedUsers(MI, MO, Predicated)) {
684         LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : "
685                           << TRI.getRegAsmName(MO.getReg()) << " at " << *MI);
686         return false;
687       }
688     }
689     // Any unknown false lanes have been masked away by the user(s).
690     Predicated.insert(MI);
691   }
692 
693   // Collect Q-regs that are live in the exit blocks. We don't collect scalars
694   // because they won't be affected by lane predication.
695   SmallSet<Register, 2> LiveOuts;
696   SmallVector<MachineBasicBlock *, 2> ExitBlocks;
697   ML.getExitBlocks(ExitBlocks);
698   for (auto *MBB : ExitBlocks)
699     for (const MachineBasicBlock::RegisterMaskPair &RegMask : MBB->liveins())
700       if (QPRs->contains(RegMask.PhysReg))
701         LiveOuts.insert(RegMask.PhysReg);
702 
703   // Collect the instructions in the loop body that define the live-out values.
704   SmallPtrSet<MachineInstr *, 2> LiveMIs;
705   assert(ML.getNumBlocks() == 1 && "Expected single block loop!");
706   for (auto Reg : LiveOuts)
707     if (auto *MI = RDA.getLocalLiveOutMIDef(MBB, Reg))
708       LiveMIs.insert(MI);
709 
710   LLVM_DEBUG(dbgs() << "ARM Loops: Found loop live-outs:\n";
711              for (auto *MI : LiveMIs)
712                dbgs() << " - " << *MI);
713   // We've already validated that any VPT predication within the loop will be
714   // equivalent when we perform the predication transformation; so we know that
715   // any VPT predicated instruction is predicated upon VCTP. Any live-out
716   // instruction needs to be predicated, so check this here.
717   for (auto *MI : LiveMIs)
718     if (!isVectorPredicated(MI))
719       return false;
720 
721   return true;
722 }
723 
724 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) {
725   if (Revert)
726     return;
727 
728   if (!End->getOperand(1).isMBB())
729     report_fatal_error("Expected LoopEnd to target basic block");
730 
731   // TODO Maybe there's cases where the target doesn't have to be the header,
732   // but for now be safe and revert.
733   if (End->getOperand(1).getMBB() != ML.getHeader()) {
734     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
735     Revert = true;
736     return;
737   }
738 
739   // The WLS and LE instructions have 12-bits for the label offset. WLS
740   // requires a positive offset, while LE uses negative.
741   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML.getHeader()) ||
742       !BBUtils->isBBInRange(End, ML.getHeader(), 4094)) {
743     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
744     Revert = true;
745     return;
746   }
747 
748   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
749       (BBUtils->getOffsetOf(Start) >
750        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
751        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
752     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
753     Revert = true;
754     return;
755   }
756 
757   InsertPt = Revert ? nullptr : isSafeToDefineLR();
758   if (!InsertPt) {
759     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
760     Revert = true;
761     return;
762   } else
763     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
764 
765   if (!IsTailPredicationLegal()) {
766     LLVM_DEBUG(if (!VCTP)
767                  dbgs() << "ARM Loops: Didn't find a VCTP instruction.\n";
768                dbgs() << "ARM Loops: Tail-predication is not valid.\n");
769     return;
770   }
771 
772   assert(ML.getBlocks().size() == 1 &&
773          "Shouldn't be processing a loop with more than one block");
774   CannotTailPredicate = !ValidateTailPredicate(InsertPt);
775   LLVM_DEBUG(if (CannotTailPredicate)
776              dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
777 }
778 
779 bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) {
780   if (CannotTailPredicate)
781     return false;
782 
783   // Only support a single vctp.
784   if (isVCTP(MI) && VCTP)
785     return false;
786 
787   // Start a new vpt block when we discover a vpt.
788   if (MI->getOpcode() == ARM::MVE_VPST) {
789     VPTBlocks.emplace_back(MI, CurrentPredicate);
790     CurrentBlock = &VPTBlocks.back();
791     return true;
792   } else if (isVCTP(MI))
793     VCTP = MI;
794   else if (MI->getOpcode() == ARM::MVE_VPSEL ||
795            MI->getOpcode() == ARM::MVE_VPNOT)
796     return false;
797 
798   // TODO: Allow VPSEL and VPNOT, we currently cannot because:
799   // 1) It will use the VPR as a predicate operand, but doesn't have to be
800   //    instead a VPT block, which means we can assert while building up
801   //    the VPT block because we don't find another VPST to being a new
802   //    one.
803   // 2) VPSEL still requires a VPR operand even after tail predicating,
804   //    which means we can't remove it unless there is another
805   //    instruction, such as vcmp, that can provide the VPR def.
806 
807   bool IsUse = false;
808   bool IsDef = false;
809   const MCInstrDesc &MCID = MI->getDesc();
810   for (int i = MI->getNumOperands() - 1; i >= 0; --i) {
811     const MachineOperand &MO = MI->getOperand(i);
812     if (!MO.isReg() || MO.getReg() != ARM::VPR)
813       continue;
814 
815     if (MO.isDef()) {
816       CurrentPredicate.insert(MI);
817       IsDef = true;
818     } else if (ARM::isVpred(MCID.OpInfo[i].OperandType)) {
819       CurrentBlock->addInst(MI, CurrentPredicate);
820       IsUse = true;
821     } else {
822       LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
823       return false;
824     }
825   }
826 
827   // If we find a vpr def that is not already predicated on the vctp, we've
828   // got disjoint predicates that may not be equivalent when we do the
829   // conversion.
830   if (IsDef && !IsUse && VCTP && !isVCTP(MI)) {
831     LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI);
832     return false;
833   }
834 
835   uint64_t Flags = MCID.TSFlags;
836   if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
837     return true;
838 
839   // If we find an instruction that has been marked as not valid for tail
840   // predication, only allow the instruction if it's contained within a valid
841   // VPT block.
842   if ((Flags & ARMII::ValidForTailPredication) == 0 && !IsUse) {
843     LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
844     return false;
845   }
846 
847   // If the instruction is already explicitly predicated, then the conversion
848   // will be fine, but ensure that all memory operations are predicated.
849   return !IsUse && MI->mayLoadOrStore() ? false : true;
850 }
851 
852 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
853   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
854   if (!ST.hasLOB())
855     return false;
856 
857   MF = &mf;
858   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
859 
860   MLI = &getAnalysis<MachineLoopInfo>();
861   RDA = &getAnalysis<ReachingDefAnalysis>();
862   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
863   MRI = &MF->getRegInfo();
864   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
865   TRI = ST.getRegisterInfo();
866   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
867   BBUtils->computeAllBlockSizes();
868   BBUtils->adjustBBOffsetsAfter(&MF->front());
869 
870   bool Changed = false;
871   for (auto ML : *MLI) {
872     if (!ML->getParentLoop())
873       Changed |= ProcessLoop(ML);
874   }
875   Changed |= RevertNonLoops();
876   return Changed;
877 }
878 
879 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
880 
881   bool Changed = false;
882 
883   // Process inner loops first.
884   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
885     Changed |= ProcessLoop(*I);
886 
887   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
888              if (auto *Preheader = ML->getLoopPreheader())
889                dbgs() << " - " << Preheader->getName() << "\n";
890              else if (auto *Preheader = MLI->findLoopPreheader(ML))
891                dbgs() << " - " << Preheader->getName() << "\n";
892              else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
893                dbgs() << " - " << Preheader->getName() << "\n";
894              for (auto *MBB : ML->getBlocks())
895                dbgs() << " - " << MBB->getName() << "\n";
896             );
897 
898   // Search the given block for a loop start instruction. If one isn't found,
899   // and there's only one predecessor block, search that one too.
900   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
901     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
902     for (auto &MI : *MBB) {
903       if (isLoopStart(MI))
904         return &MI;
905     }
906     if (MBB->pred_size() == 1)
907       return SearchForStart(*MBB->pred_begin());
908     return nullptr;
909   };
910 
911   LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI);
912   // Search the preheader for the start intrinsic.
913   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
914   // with potentially multiple set.loop.iterations, so we need to enable this.
915   if (auto *Preheader = ML->getLoopPreheader())
916     LoLoop.Start = SearchForStart(Preheader);
917   else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
918     LoLoop.Start = SearchForStart(Preheader);
919   else
920     return false;
921 
922   // Find the low-overhead loop components and decide whether or not to fall
923   // back to a normal loop. Also look for a vctp instructions and decide
924   // whether we can convert that predicate using tail predication.
925   for (auto *MBB : reverse(ML->getBlocks())) {
926     for (auto &MI : *MBB) {
927       if (MI.isDebugValue())
928         continue;
929       else if (MI.getOpcode() == ARM::t2LoopDec)
930         LoLoop.Dec = &MI;
931       else if (MI.getOpcode() == ARM::t2LoopEnd)
932         LoLoop.End = &MI;
933       else if (isLoopStart(MI))
934         LoLoop.Start = &MI;
935       else if (MI.getDesc().isCall()) {
936         // TODO: Though the call will require LE to execute again, does this
937         // mean we should revert? Always executing LE hopefully should be
938         // faster than performing a sub,cmp,br or even subs,br.
939         LoLoop.Revert = true;
940         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
941       } else {
942         // Record VPR defs and build up their corresponding vpt blocks.
943         // Check we know how to tail predicate any mve instructions.
944         LoLoop.AnalyseMVEInst(&MI);
945       }
946     }
947   }
948 
949   LLVM_DEBUG(LoLoop.dump());
950   if (!LoLoop.FoundAllComponents()) {
951     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find loop start, update, end\n");
952     return false;
953   }
954 
955   // Check that the only instruction using LoopDec is LoopEnd.
956   // TODO: Check for copy chains that really have no effect.
957   SmallPtrSet<MachineInstr*, 2> Uses;
958   RDA->getReachingLocalUses(LoLoop.Dec, ARM::LR, Uses);
959   if (Uses.size() > 1 || !Uses.count(LoLoop.End)) {
960     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n");
961     LoLoop.Revert = true;
962   }
963   LoLoop.CheckLegality(BBUtils.get());
964   Expand(LoLoop);
965   return true;
966 }
967 
968 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
969 // beq that branches to the exit branch.
970 // TODO: We could also try to generate a cbz if the value in LR is also in
971 // another low register.
972 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
973   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
974   MachineBasicBlock *MBB = MI->getParent();
975   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
976                                     TII->get(ARM::t2CMPri));
977   MIB.add(MI->getOperand(0));
978   MIB.addImm(0);
979   MIB.addImm(ARMCC::AL);
980   MIB.addReg(ARM::NoRegister);
981 
982   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
983   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
984     ARM::tBcc : ARM::t2Bcc;
985 
986   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
987   MIB.add(MI->getOperand(1));   // branch target
988   MIB.addImm(ARMCC::EQ);        // condition code
989   MIB.addReg(ARM::CPSR);
990   MI->eraseFromParent();
991 }
992 
993 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const {
994   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
995   MachineBasicBlock *MBB = MI->getParent();
996   SmallPtrSet<MachineInstr*, 1> Ignore;
997   for (auto I = MachineBasicBlock::iterator(MI), E = MBB->end(); I != E; ++I) {
998     if (I->getOpcode() == ARM::t2LoopEnd) {
999       Ignore.insert(&*I);
1000       break;
1001     }
1002   }
1003 
1004   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
1005   bool SetFlags = RDA->isSafeToDefRegAt(MI, ARM::CPSR, Ignore);
1006 
1007   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1008                                     TII->get(ARM::t2SUBri));
1009   MIB.addDef(ARM::LR);
1010   MIB.add(MI->getOperand(1));
1011   MIB.add(MI->getOperand(2));
1012   MIB.addImm(ARMCC::AL);
1013   MIB.addReg(0);
1014 
1015   if (SetFlags) {
1016     MIB.addReg(ARM::CPSR);
1017     MIB->getOperand(5).setIsDef(true);
1018   } else
1019     MIB.addReg(0);
1020 
1021   MI->eraseFromParent();
1022   return SetFlags;
1023 }
1024 
1025 // Generate a subs, or sub and cmp, and a branch instead of an LE.
1026 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
1027   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
1028 
1029   MachineBasicBlock *MBB = MI->getParent();
1030   // Create cmp
1031   if (!SkipCmp) {
1032     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1033                                       TII->get(ARM::t2CMPri));
1034     MIB.addReg(ARM::LR);
1035     MIB.addImm(0);
1036     MIB.addImm(ARMCC::AL);
1037     MIB.addReg(ARM::NoRegister);
1038   }
1039 
1040   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1041   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1042     ARM::tBcc : ARM::t2Bcc;
1043 
1044   // Create bne
1045   MachineInstrBuilder MIB =
1046     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1047   MIB.add(MI->getOperand(1));   // branch target
1048   MIB.addImm(ARMCC::NE);        // condition code
1049   MIB.addReg(ARM::CPSR);
1050   MI->eraseFromParent();
1051 }
1052 
1053 // Perform dead code elimation on the loop iteration count setup expression.
1054 // If we are tail-predicating, the number of elements to be processed is the
1055 // operand of the VCTP instruction in the vector body, see getCount(), which is
1056 // register $r3 in this example:
1057 //
1058 //   $lr = big-itercount-expression
1059 //   ..
1060 //   t2DoLoopStart renamable $lr
1061 //   vector.body:
1062 //     ..
1063 //     $vpr = MVE_VCTP32 renamable $r3
1064 //     renamable $lr = t2LoopDec killed renamable $lr, 1
1065 //     t2LoopEnd renamable $lr, %vector.body
1066 //     tB %end
1067 //
1068 // What we would like achieve here is to replace the do-loop start pseudo
1069 // instruction t2DoLoopStart with:
1070 //
1071 //    $lr = MVE_DLSTP_32 killed renamable $r3
1072 //
1073 // Thus, $r3 which defines the number of elements, is written to $lr,
1074 // and then we want to delete the whole chain that used to define $lr,
1075 // see the comment below how this chain could look like.
1076 //
1077 void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
1078   if (!LoLoop.IsTailPredicationLegal())
1079     return;
1080 
1081   LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");
1082 
1083   MachineInstr *Def = RDA->getMIOperand(LoLoop.Start, 0);
1084   if (!Def) {
1085     LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
1086     return;
1087   }
1088 
1089   // Collect and remove the users of iteration count.
1090   SmallPtrSet<MachineInstr*, 4> Killed  = { LoLoop.Start, LoLoop.Dec,
1091                                             LoLoop.End, LoLoop.InsertPt };
1092   SmallPtrSet<MachineInstr*, 2> Remove;
1093   if (RDA->isSafeToRemove(Def, Remove, Killed))
1094     LoLoop.ToRemove.insert(Remove.begin(), Remove.end());
1095   else {
1096     LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
1097     return;
1098   }
1099 
1100   // Collect the dead code and the MBBs in which they reside.
1101   RDA->collectKilledOperands(Def, Killed);
1102   SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
1103   for (auto *MI : Killed)
1104     BasicBlocks.insert(MI->getParent());
1105 
1106   // Collect IT blocks in all affected basic blocks.
1107   std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
1108   for (auto *MBB : BasicBlocks) {
1109     for (auto &MI : *MBB) {
1110       if (MI.getOpcode() != ARM::t2IT)
1111         continue;
1112       RDA->getReachingLocalUses(&MI, ARM::ITSTATE, ITBlocks[&MI]);
1113     }
1114   }
1115 
1116   // If we're removing all of the instructions within an IT block, then
1117   // also remove the IT instruction.
1118   SmallPtrSet<MachineInstr*, 2> ModifiedITs;
1119   for (auto *MI : Killed) {
1120     if (MachineOperand *MO = MI->findRegisterUseOperand(ARM::ITSTATE)) {
1121       MachineInstr *IT = RDA->getMIOperand(MI, *MO);
1122       auto &CurrentBlock = ITBlocks[IT];
1123       CurrentBlock.erase(MI);
1124       if (CurrentBlock.empty())
1125         ModifiedITs.erase(IT);
1126       else
1127         ModifiedITs.insert(IT);
1128     }
1129   }
1130 
1131   // Delete the killed instructions only if we don't have any IT blocks that
1132   // need to be modified because we need to fixup the mask.
1133   // TODO: Handle cases where IT blocks are modified.
1134   if (ModifiedITs.empty()) {
1135     LLVM_DEBUG(dbgs() << "ARM Loops: Will remove iteration count:\n";
1136                for (auto *MI : Killed)
1137                  dbgs() << " - " << *MI);
1138     LoLoop.ToRemove.insert(Killed.begin(), Killed.end());
1139   } else
1140     LLVM_DEBUG(dbgs() << "ARM Loops: Would need to modify IT block(s).\n");
1141 }
1142 
1143 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
1144   LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n");
1145   // When using tail-predication, try to delete the dead code that was used to
1146   // calculate the number of loop iterations.
1147   IterationCountDCE(LoLoop);
1148 
1149   MachineInstr *InsertPt = LoLoop.InsertPt;
1150   MachineInstr *Start = LoLoop.Start;
1151   MachineBasicBlock *MBB = InsertPt->getParent();
1152   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
1153   unsigned Opc = LoLoop.getStartOpcode();
1154   MachineOperand &Count = LoLoop.getCount();
1155 
1156   MachineInstrBuilder MIB =
1157     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
1158 
1159   MIB.addDef(ARM::LR);
1160   MIB.add(Count);
1161   if (!IsDo)
1162     MIB.add(Start->getOperand(1));
1163 
1164   // If we're inserting at a mov lr, then remove it as it's redundant.
1165   if (InsertPt != Start)
1166     LoLoop.ToRemove.insert(InsertPt);
1167   LoLoop.ToRemove.insert(Start);
1168   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
1169   return &*MIB;
1170 }
1171 
1172 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
1173   auto RemovePredicate = [](MachineInstr *MI) {
1174     LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
1175     if (int PIdx = llvm::findFirstVPTPredOperandIdx(*MI)) {
1176       assert(MI->getOperand(PIdx).getImm() == ARMVCC::Then &&
1177              "Expected Then predicate!");
1178       MI->getOperand(PIdx).setImm(ARMVCC::None);
1179       MI->getOperand(PIdx+1).setReg(0);
1180     } else
1181       llvm_unreachable("trying to unpredicate a non-predicated instruction");
1182   };
1183 
1184   // There are a few scenarios which we have to fix up:
1185   // 1) A VPT block with is only predicated by the vctp and has no internal vpr
1186   //    defs.
1187   // 2) A VPT block which is only predicated by the vctp but has an internal
1188   //    vpr def.
1189   // 3) A VPT block which is predicated upon the vctp as well as another vpr
1190   //    def.
1191   // 4) A VPT block which is not predicated upon a vctp, but contains it and
1192   //    all instructions within the block are predicated upon in.
1193 
1194   for (auto &Block : LoLoop.getVPTBlocks()) {
1195     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
1196     if (Block.HasNonUniformPredicate()) {
1197       PredicatedMI *Divergent = Block.getDivergent();
1198       if (isVCTP(Divergent->MI)) {
1199         // The vctp will be removed, so the size of the vpt block needs to be
1200         // modified.
1201         uint64_t Size = getARMVPTBlockMask(Block.size() - 1);
1202         Block.getVPST()->getOperand(0).setImm(Size);
1203         LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n");
1204       } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
1205         // The VPT block has a non-uniform predicate but it's entry is guarded
1206         // only by a vctp, which means we:
1207         // - Need to remove the original vpst.
1208         // - Then need to unpredicate any following instructions, until
1209         //   we come across the divergent vpr def.
1210         // - Insert a new vpst to predicate the instruction(s) that following
1211         //   the divergent vpr def.
1212         // TODO: We could be producing more VPT blocks than necessary and could
1213         // fold the newly created one into a proceeding one.
1214         for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()),
1215              E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
1216           RemovePredicate(&*I);
1217 
1218         unsigned Size = 0;
1219         auto E = MachineBasicBlock::reverse_iterator(Divergent->MI);
1220         auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI);
1221         MachineInstr *InsertAt = nullptr;
1222         while (I != E) {
1223           InsertAt = &*I;
1224           ++Size;
1225           ++I;
1226         }
1227         MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
1228                                           InsertAt->getDebugLoc(),
1229                                           TII->get(ARM::MVE_VPST));
1230         MIB.addImm(getARMVPTBlockMask(Size));
1231         LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
1232         LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
1233         LoLoop.ToRemove.insert(Block.getVPST());
1234       }
1235     } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
1236       // A vpt block which is only predicated upon vctp and has no internal vpr
1237       // defs:
1238       // - Remove vpst.
1239       // - Unpredicate the remaining instructions.
1240       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
1241       LoLoop.ToRemove.insert(Block.getVPST());
1242       for (auto &PredMI : Insts)
1243         RemovePredicate(PredMI.MI);
1244     }
1245   }
1246   LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
1247   LoLoop.ToRemove.insert(LoLoop.VCTP);
1248 }
1249 
// Rewrite the loop pseudo instructions of LoLoop into their final form.
// - If the loop is marked for reversion, they become ordinary
//   compare/subtract/branch sequences.
// - Otherwise, the start and end pseudos are expanded into their
//   low-overhead forms, and VPT blocks are converted when tail-predicating.
//   Everything collected in ToRemove is then erased.
// In both cases the live-ins and liveness flags of the loop's blocks are
// recomputed, and RDA is reset because its cached information is now stale.
void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {

  // Combine the LoopDec and LoopEnd instructions into LE(TP). The new LE is
  // built in front of the LoopEnd pseudo; both pseudos are only queued in
  // ToRemove here and are erased later.
  auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
    MachineInstr *End = LoLoop.End;
    MachineBasicBlock *MBB = End->getParent();
    unsigned Opc = LoLoop.IsTailPredicationLegal() ?
      ARM::MVE_LETP : ARM::t2LEUpdate;
    MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
                                      TII->get(Opc));
    MIB.addDef(ARM::LR);
    MIB.add(End->getOperand(0));
    MIB.add(End->getOperand(1));
    LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
    LoLoop.ToRemove.insert(LoLoop.Dec);
    LoLoop.ToRemove.insert(End);
    return &*MIB;
  };

  // TODO: We should be able to automatically remove these branches before we
  // get here - probably by teaching analyzeBranch about the pseudo
  // instructions.
  // If there is an unconditional branch, after I, that just branches to the
  // next block, remove it.
  auto RemoveDeadBranch = [](MachineInstr *I) {
    MachineBasicBlock *BB = I->getParent();
    MachineInstr *Terminator = &BB->instr_back();
    if (Terminator->isUnconditionalBranch() && I != Terminator) {
      MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
      if (BB->isLayoutSuccessor(Succ)) {
        LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
        Terminator->eraseFromParent();
      }
    }
  };

  if (LoLoop.Revert) {
    // A while-loop start becomes a cmp + conditional branch to its exit
    // block; any other start pseudo is simply erased.
    if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
      RevertWhile(LoLoop.Start);
    else
      LoLoop.Start->eraseFromParent();
    // When the reverted decrement already sets the flags, the reverted
    // LoopEnd can skip its cmp.
    bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec);
    RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
  } else {
    LoLoop.Start = ExpandLoopStart(LoLoop);
    RemoveDeadBranch(LoLoop.Start);
    LoLoop.End = ExpandLoopEnd(LoLoop);
    RemoveDeadBranch(LoLoop.End);
    if (LoLoop.IsTailPredicationLegal())
      ConvertVPTBlocks(LoLoop);
    // Everything queued by the steps above is erased in one go at the end.
    for (auto *I : LoLoop.ToRemove) {
      LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
      I->eraseFromParent();
    }
  }

  // Recompute the live-ins of the loop's blocks in the order produced by the
  // post-order traversal. Then recompute the liveness flags, walking that
  // order backwards.
  PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
  DFS.ProcessLoop();
  const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
  for (auto *MBB : PostOrder) {
    recomputeLiveIns(*MBB);
    // FIXME: For some reason, the live-in print order is non-deterministic for
    // our tests and I can't work out why... So just sort them.
    MBB->sortUniqueLiveIns();
  }

  for (auto *MBB : reverse(PostOrder))
    recomputeLivenessFlags(*MBB);

  // We've moved, removed and inserted new instructions, so update RDA.
  RDA->reset();
}
1322 
1323 bool ARMLowOverheadLoops::RevertNonLoops() {
1324   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
1325   bool Changed = false;
1326 
1327   for (auto &MBB : *MF) {
1328     SmallVector<MachineInstr*, 4> Starts;
1329     SmallVector<MachineInstr*, 4> Decs;
1330     SmallVector<MachineInstr*, 4> Ends;
1331 
1332     for (auto &I : MBB) {
1333       if (isLoopStart(I))
1334         Starts.push_back(&I);
1335       else if (I.getOpcode() == ARM::t2LoopDec)
1336         Decs.push_back(&I);
1337       else if (I.getOpcode() == ARM::t2LoopEnd)
1338         Ends.push_back(&I);
1339     }
1340 
1341     if (Starts.empty() && Decs.empty() && Ends.empty())
1342       continue;
1343 
1344     Changed = true;
1345 
1346     for (auto *Start : Starts) {
1347       if (Start->getOpcode() == ARM::t2WhileLoopStart)
1348         RevertWhile(Start);
1349       else
1350         Start->eraseFromParent();
1351     }
1352     for (auto *Dec : Decs)
1353       RevertLoopDec(Dec);
1354 
1355     for (auto *End : Ends)
1356       RevertLoopEnd(End);
1357   }
1358   return Changed;
1359 }
1360 
1361 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
1362   return new ARMLowOverheadLoops();
1363 }
1364