1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
12 /// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
13 ///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
18 /// In addition to this, we also look for the presence of the VCTP instruction,
/// which determines whether we can generate the tail-predicated low-overhead
20 /// loop form.
21 ///
22 /// Assumptions and Dependencies:
23 /// Low-overhead loops are constructed and executed using a setup instruction:
24 /// DLS, WLS, DLSTP or WLSTP and an instruction that loops back: LE or LETP.
25 /// WLS(TP) and LE(TP) are branching instructions with a (large) limited range
26 /// but fixed polarity: WLS can only branch forwards and LE can only branch
27 /// backwards. These restrictions mean that this pass is dependent upon block
28 /// layout and block sizes, which is why it's the last pass to run. The same is
29 /// true for ConstantIslands, but this pass does not increase the size of the
30 /// basic blocks, nor does it change the CFG. Instructions are mainly removed
31 /// during the transform and pseudo instructions are replaced by real ones. In
32 /// some cases, when we have to revert to a 'normal' loop, we have to introduce
33 /// multiple instructions for a single pseudo (see RevertWhile and
34 /// RevertLoopEnd). To handle this situation, t2WhileLoopStart and t2LoopEnd
35 /// are defined to be as large as this maximum sequence of replacement
36 /// instructions.
37 ///
38 /// A note on VPR.P0 (the lane mask):
39 /// VPT, VCMP, VPNOT and VCTP won't overwrite VPR.P0 when they update it in a
40 /// "VPT Active" context (which includes low-overhead loops and vpt blocks).
41 /// They will simply "and" the result of their calculation with the current
42 /// value of VPR.P0. You can think of it like this:
43 /// \verbatim
44 /// if VPT active:    ; Between a DLSTP/LETP, or for predicated instrs
45 ///   VPR.P0 &= Value
46 /// else
47 ///   VPR.P0 = Value
48 /// \endverbatim
49 /// When we're inside the low-overhead loop (between DLSTP and LETP), we always
50 /// fall in the "VPT active" case, so we can consider that all VPR writes by
/// one of those instructions are actually an "and".
52 //===----------------------------------------------------------------------===//
53 
54 #include "ARM.h"
55 #include "ARMBaseInstrInfo.h"
56 #include "ARMBaseRegisterInfo.h"
57 #include "ARMBasicBlockInfo.h"
58 #include "ARMSubtarget.h"
59 #include "Thumb2InstrInfo.h"
60 #include "llvm/ADT/SetOperations.h"
61 #include "llvm/ADT/SmallSet.h"
62 #include "llvm/CodeGen/LivePhysRegs.h"
63 #include "llvm/CodeGen/MachineFunctionPass.h"
64 #include "llvm/CodeGen/MachineLoopInfo.h"
65 #include "llvm/CodeGen/MachineLoopUtils.h"
66 #include "llvm/CodeGen/MachineRegisterInfo.h"
67 #include "llvm/CodeGen/Passes.h"
68 #include "llvm/CodeGen/ReachingDefAnalysis.h"
69 #include "llvm/MC/MCInstrDesc.h"
70 
71 using namespace llvm;
72 
73 #define DEBUG_TYPE "arm-low-overhead-loops"
74 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
75 
76 namespace {
77 
78   using InstSet = SmallPtrSetImpl<MachineInstr *>;
79 
80   class PostOrderLoopTraversal {
81     MachineLoop &ML;
82     MachineLoopInfo &MLI;
83     SmallPtrSet<MachineBasicBlock*, 4> Visited;
84     SmallVector<MachineBasicBlock*, 4> Order;
85 
86   public:
87     PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI)
88       : ML(ML), MLI(MLI) { }
89 
90     const SmallVectorImpl<MachineBasicBlock*> &getOrder() const {
91       return Order;
92     }
93 
94     // Visit all the blocks within the loop, as well as exit blocks and any
95     // blocks properly dominating the header.
96     void ProcessLoop() {
97       std::function<void(MachineBasicBlock*)> Search = [this, &Search]
98         (MachineBasicBlock *MBB) -> void {
99         if (Visited.count(MBB))
100           return;
101 
102         Visited.insert(MBB);
103         for (auto *Succ : MBB->successors()) {
104           if (!ML.contains(Succ))
105             continue;
106           Search(Succ);
107         }
108         Order.push_back(MBB);
109       };
110 
111       // Insert exit blocks.
112       SmallVector<MachineBasicBlock*, 2> ExitBlocks;
113       ML.getExitBlocks(ExitBlocks);
114       for (auto *MBB : ExitBlocks)
115         Order.push_back(MBB);
116 
117       // Then add the loop body.
118       Search(ML.getHeader());
119 
120       // Then try the preheader and its predecessors.
121       std::function<void(MachineBasicBlock*)> GetPredecessor =
122         [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void {
123         Order.push_back(MBB);
124         if (MBB->pred_size() == 1)
125           GetPredecessor(*MBB->pred_begin());
126       };
127 
128       if (auto *Preheader = ML.getLoopPreheader())
129         GetPredecessor(Preheader);
130       else if (auto *Preheader = MLI.findLoopPreheader(&ML, true))
131         GetPredecessor(Preheader);
132     }
133   };
134 
135   struct PredicatedMI {
136     MachineInstr *MI = nullptr;
137     SetVector<MachineInstr*> Predicates;
138 
139   public:
140     PredicatedMI(MachineInstr *I, SetVector<MachineInstr *> &Preds) : MI(I) {
141       assert(I && "Instruction must not be null!");
142       Predicates.insert(Preds.begin(), Preds.end());
143     }
144   };
145 
146   // Represent a VPT block, a list of instructions that begins with a VPT/VPST
147   // and has a maximum of four proceeding instructions. All instructions within
148   // the block are predicated upon the vpr and we allow instructions to define
149   // the vpr within in the block too.
150   class VPTBlock {
151     // The predicate then instruction, which is either a VPT, or a VPST
152     // instruction.
153     std::unique_ptr<PredicatedMI> PredicateThen;
154     PredicatedMI *Divergent = nullptr;
155     SmallVector<PredicatedMI, 4> Insts;
156 
157   public:
158     VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
159       PredicateThen = std::make_unique<PredicatedMI>(MI, Preds);
160     }
161 
162     void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
163       LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
164       if (!Divergent && !set_difference(Preds, PredicateThen->Predicates).empty()) {
165         Divergent = &Insts.back();
166         LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
167       }
168       Insts.emplace_back(MI, Preds);
169       assert(Insts.size() <= 4 && "Too many instructions in VPT block!");
170     }
171 
172     // Have we found an instruction within the block which defines the vpr? If
173     // so, not all the instructions in the block will have the same predicate.
174     bool HasNonUniformPredicate() const {
175       return Divergent != nullptr;
176     }
177 
178     // Is the given instruction part of the predicate set controlling the entry
179     // to the block.
180     bool IsPredicatedOn(MachineInstr *MI) const {
181       return PredicateThen->Predicates.count(MI);
182     }
183 
184     // Returns true if this is a VPT instruction.
185     bool isVPT() const { return !isVPST(); }
186 
187     // Returns true if this is a VPST instruction.
188     bool isVPST() const {
189       return PredicateThen->MI->getOpcode() == ARM::MVE_VPST;
190     }
191 
192     // Is the given instruction the only predicate which controls the entry to
193     // the block.
194     bool IsOnlyPredicatedOn(MachineInstr *MI) const {
195       return IsPredicatedOn(MI) && PredicateThen->Predicates.size() == 1;
196     }
197 
198     unsigned size() const { return Insts.size(); }
199     SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
200     MachineInstr *getPredicateThen() const { return PredicateThen->MI; }
201     PredicatedMI *getDivergent() const { return Divergent; }
202   };
203 
204   struct LowOverheadLoop {
205 
206     MachineLoop &ML;
207     MachineLoopInfo &MLI;
208     ReachingDefAnalysis &RDA;
209     const TargetRegisterInfo &TRI;
210     MachineFunction *MF = nullptr;
211     MachineInstr *InsertPt = nullptr;
212     MachineInstr *Start = nullptr;
213     MachineInstr *Dec = nullptr;
214     MachineInstr *End = nullptr;
215     MachineInstr *VCTP = nullptr;
216     SmallPtrSet<MachineInstr*, 4> SecondaryVCTPs;
217     VPTBlock *CurrentBlock = nullptr;
218     SetVector<MachineInstr*> CurrentPredicate;
219     SmallVector<VPTBlock, 4> VPTBlocks;
220     SmallPtrSet<MachineInstr*, 4> ToRemove;
221     SmallPtrSet<MachineInstr*, 4> BlockMasksToRecompute;
222     bool Revert = false;
223     bool CannotTailPredicate = false;
224 
225     LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI,
226                     ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI)
227       : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI) {
228       MF = ML.getHeader()->getParent();
229     }
230 
231     // If this is an MVE instruction, check that we know how to use tail
232     // predication with it. Record VPT blocks and return whether the
233     // instruction is valid for tail predication.
234     bool ValidateMVEInst(MachineInstr *MI);
235 
236     void AnalyseMVEInst(MachineInstr *MI) {
237       CannotTailPredicate = !ValidateMVEInst(MI);
238     }
239 
240     bool IsTailPredicationLegal() const {
241       // For now, let's keep things really simple and only support a single
242       // block for tail predication.
243       return !Revert && FoundAllComponents() && VCTP &&
244              !CannotTailPredicate && ML.getNumBlocks() == 1;
245     }
246 
247     // Check that the predication in the loop will be equivalent once we
248     // perform the conversion. Also ensure that we can provide the number
249     // of elements to the loop start instruction.
250     bool ValidateTailPredicate(MachineInstr *StartInsertPt);
251 
252     // Check that any values available outside of the loop will be the same
253     // after tail predication conversion.
254     bool ValidateLiveOuts() const;
255 
256     // Is it safe to define LR with DLS/WLS?
257     // LR can be defined if it is the operand to start, because it's the same
258     // value, or if it's going to be equivalent to the operand to Start.
259     MachineInstr *isSafeToDefineLR();
260 
261     // Check the branch targets are within range and we satisfy our
262     // restrictions.
263     void CheckLegality(ARMBasicBlockUtils *BBUtils);
264 
265     bool FoundAllComponents() const {
266       return Start && Dec && End;
267     }
268 
269     SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
270 
271     // Return the loop iteration count, or the number of elements if we're tail
272     // predicating.
273     MachineOperand &getCount() {
274       return IsTailPredicationLegal() ?
275         VCTP->getOperand(1) : Start->getOperand(0);
276     }
277 
278     unsigned getStartOpcode() const {
279       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
280       if (!IsTailPredicationLegal())
281         return IsDo ? ARM::t2DLS : ARM::t2WLS;
282 
283       return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo);
284     }
285 
286     void dump() const {
287       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
288       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
289       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
290       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
291       if (!FoundAllComponents())
292         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
293       else if (!(Start && Dec && End))
294         dbgs() << "ARM Loops: Failed to find all loop components.\n";
295     }
296   };
297 
298   class ARMLowOverheadLoops : public MachineFunctionPass {
299     MachineFunction           *MF = nullptr;
300     MachineLoopInfo           *MLI = nullptr;
301     ReachingDefAnalysis       *RDA = nullptr;
302     const ARMBaseInstrInfo    *TII = nullptr;
303     MachineRegisterInfo       *MRI = nullptr;
304     const TargetRegisterInfo  *TRI = nullptr;
305     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
306 
307   public:
308     static char ID;
309 
310     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
311 
312     void getAnalysisUsage(AnalysisUsage &AU) const override {
313       AU.setPreservesCFG();
314       AU.addRequired<MachineLoopInfo>();
315       AU.addRequired<ReachingDefAnalysis>();
316       MachineFunctionPass::getAnalysisUsage(AU);
317     }
318 
319     bool runOnMachineFunction(MachineFunction &MF) override;
320 
321     MachineFunctionProperties getRequiredProperties() const override {
322       return MachineFunctionProperties().set(
323           MachineFunctionProperties::Property::NoVRegs).set(
324           MachineFunctionProperties::Property::TracksLiveness);
325     }
326 
327     StringRef getPassName() const override {
328       return ARM_LOW_OVERHEAD_LOOPS_NAME;
329     }
330 
331   private:
332     bool ProcessLoop(MachineLoop *ML);
333 
334     bool RevertNonLoops();
335 
336     void RevertWhile(MachineInstr *MI) const;
337 
338     bool RevertLoopDec(MachineInstr *MI) const;
339 
340     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
341 
342     void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
343 
344     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
345 
346     void Expand(LowOverheadLoop &LoLoop);
347 
348     void IterationCountDCE(LowOverheadLoop &LoLoop);
349   };
350 }
351 
// Pass identifier; handed to MachineFunctionPass by the pass constructor.
char ARMLowOverheadLoops::ID = 0;

// Register the pass under DEBUG_TYPE with its human-readable name. The two
// trailing flags are presumably INITIALIZE_PASS's cfg-only / is-analysis
// flags — TODO confirm against the macro definition.
INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
                false, false)
356 
357 MachineInstr *LowOverheadLoop::isSafeToDefineLR() {
358   // We can define LR because LR already contains the same value.
359   if (Start->getOperand(0).getReg() == ARM::LR)
360     return Start;
361 
362   unsigned CountReg = Start->getOperand(0).getReg();
363   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
364     return MI->getOpcode() == ARM::tMOVr &&
365            MI->getOperand(0).getReg() == ARM::LR &&
366            MI->getOperand(1).getReg() == CountReg &&
367            MI->getOperand(2).getImm() == ARMCC::AL;
368    };
369 
370   MachineBasicBlock *MBB = Start->getParent();
371 
372   // Find an insertion point:
373   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
374   //   to Count before Start, we can insert at that mov.
375   if (auto *LRDef = RDA.getUniqueReachingMIDef(Start, ARM::LR))
376     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
377       return LRDef;
378 
379   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
380   //   to Count after Start, we can insert at that mov.
381   if (auto *LRDef = RDA.getLocalLiveOutMIDef(MBB, ARM::LR))
382     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
383       return LRDef;
384 
385   // We've found no suitable LR def and Start doesn't use LR directly. Can we
386   // just define LR anyway?
387   return RDA.isSafeToDefRegAt(Start, ARM::LR) ? Start : nullptr;
388 }
389 
390 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {
391   assert(VCTP && "VCTP instruction expected but is not set");
392   // All predication within the loop should be based on vctp. If the block
393   // isn't predicated on entry, check whether the vctp is within the block
394   // and that all other instructions are then predicated on it.
395   for (auto &Block : VPTBlocks) {
396     if (Block.IsPredicatedOn(VCTP))
397       continue;
398     if (Block.HasNonUniformPredicate() && !isVCTP(Block.getDivergent()->MI)) {
399       LLVM_DEBUG(dbgs() << "ARM Loops: Found unsupported diverging predicate: "
400                         << *Block.getDivergent()->MI);
401       return false;
402     }
403     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
404     for (auto &PredMI : Insts) {
405       // Check the instructions in the block and only allow:
406       //   - VCTPs
407       //   - Instructions predicated on the main VCTP
408       //   - Any VCMP
409       //      - VCMPs just "and" their result with VPR.P0. Whether they are
410       //      located before/after the VCTP is irrelevant - the end result will
411       //      be the same in both cases, so there's no point in requiring them
412       //      to be located after the VCTP!
413       if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI) ||
414           VCMPOpcodeToVPT(PredMI.MI->getOpcode()) != 0)
415         continue;
416       LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
417                  << " - which is predicated on:\n";
418                  for (auto *MI : PredMI.Predicates)
419                    dbgs() << "   - " << *MI);
420       return false;
421     }
422   }
423 
424   if (!ValidateLiveOuts())
425     return false;
426 
427   // For tail predication, we need to provide the number of elements, instead
428   // of the iteration count, to the loop start instruction. The number of
429   // elements is provided to the vctp instruction, so we need to check that
430   // we can use this register at InsertPt.
431   Register NumElements = VCTP->getOperand(1).getReg();
432 
433   // If the register is defined within loop, then we can't perform TP.
434   // TODO: Check whether this is just a mov of a register that would be
435   // available.
436   if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
437     LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
438     return false;
439   }
440 
441   // The element count register maybe defined after InsertPt, in which case we
442   // need to try to move either InsertPt or the def so that the [w|d]lstp can
443   // use the value.
444   // TODO: On failing to move an instruction, check if the count is provided by
445   // a mov and whether we can use the mov operand directly.
446   MachineBasicBlock *InsertBB = StartInsertPt->getParent();
447   if (!RDA.isReachingDefLiveOut(StartInsertPt, NumElements)) {
448     if (auto *ElemDef = RDA.getLocalLiveOutMIDef(InsertBB, NumElements)) {
449       if (RDA.isSafeToMoveForwards(ElemDef, StartInsertPt)) {
450         ElemDef->removeFromParent();
451         InsertBB->insert(MachineBasicBlock::iterator(StartInsertPt), ElemDef);
452         LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
453                    << *ElemDef);
454       } else if (RDA.isSafeToMoveBackwards(StartInsertPt, ElemDef)) {
455         StartInsertPt->removeFromParent();
456         InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
457                               StartInsertPt);
458         LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
459       } else {
460         LLVM_DEBUG(dbgs() << "ARM Loops: Unable to move element count to loop "
461                    << "start instruction.\n");
462         return false;
463       }
464     }
465   }
466 
467   // Especially in the case of while loops, InsertBB may not be the
468   // preheader, so we need to check that the register isn't redefined
469   // before entering the loop.
470   auto CannotProvideElements = [this](MachineBasicBlock *MBB,
471                                       Register NumElements) {
472     // NumElements is redefined in this block.
473     if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
474       return true;
475 
476     // Don't continue searching up through multiple predecessors.
477     if (MBB->pred_size() > 1)
478       return true;
479 
480     return false;
481   };
482 
483   // First, find the block that looks like the preheader.
484   MachineBasicBlock *MBB = MLI.findLoopPreheader(&ML, true);
485   if (!MBB) {
486     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find preheader.\n");
487     return false;
488   }
489 
490   // Then search backwards for a def, until we get to InsertBB.
491   while (MBB != InsertBB) {
492     if (CannotProvideElements(MBB, NumElements)) {
493       LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n");
494       return false;
495     }
496     MBB = *MBB->pred_begin();
497   }
498 
499   // Check that the value change of the element count is what we expect and
500   // that the predication will be equivalent. For this we need:
501   // NumElements = NumElements - VectorWidth. The sub will be a sub immediate
502   // and we can also allow register copies within the chain too.
503   auto IsValidSub = [](MachineInstr *MI, int ExpectedVecWidth) {
504     return -getAddSubImmediate(*MI) == ExpectedVecWidth;
505   };
506 
507   MBB = VCTP->getParent();
508   if (auto *Def = RDA.getUniqueReachingMIDef(&MBB->back(), NumElements)) {
509     SmallPtrSet<MachineInstr*, 2> ElementChain;
510     SmallPtrSet<MachineInstr*, 2> Ignore = { VCTP };
511     unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode());
512 
513     Ignore.insert(SecondaryVCTPs.begin(), SecondaryVCTPs.end());
514 
515     if (RDA.isSafeToRemove(Def, ElementChain, Ignore)) {
516       bool FoundSub = false;
517 
518       for (auto *MI : ElementChain) {
519         if (isMovRegOpcode(MI->getOpcode()))
520           continue;
521 
522         if (isSubImmOpcode(MI->getOpcode())) {
523           if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth))
524             return false;
525           FoundSub = true;
526         } else
527           return false;
528       }
529 
530       LLVM_DEBUG(dbgs() << "ARM Loops: Will remove element count chain:\n";
531                  for (auto *MI : ElementChain)
532                    dbgs() << " - " << *MI);
533       ToRemove.insert(ElementChain.begin(), ElementChain.end());
534     }
535   }
536   return true;
537 }
538 
539 static bool isVectorPredicated(MachineInstr *MI) {
540   int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
541   return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR;
542 }
543 
544 static bool isRegInClass(const MachineOperand &MO,
545                          const TargetRegisterClass *Class) {
546   return MO.isReg() && MO.getReg() && Class->contains(MO.getReg());
547 }
548 
549 // MVE 'narrowing' operate on half a lane, reading from half and writing
550 // to half, which are referred to has the top and bottom half. The other
551 // half retains its previous value.
552 static bool retainsPreviousHalfElement(const MachineInstr &MI) {
553   const MCInstrDesc &MCID = MI.getDesc();
554   uint64_t Flags = MCID.TSFlags;
555   return (Flags & ARMII::RetainsPreviousHalfElement) != 0;
556 }
557 
558 // Some MVE instructions read from the top/bottom halves of their operand(s)
559 // and generate a vector result with result elements that are double the
560 // width of the input.
561 static bool producesDoubleWidthResult(const MachineInstr &MI) {
562   const MCInstrDesc &MCID = MI.getDesc();
563   uint64_t Flags = MCID.TSFlags;
564   return (Flags & ARMII::DoubleWidthResult) != 0;
565 }
566 
567 static bool isHorizontalReduction(const MachineInstr &MI) {
568   const MCInstrDesc &MCID = MI.getDesc();
569   uint64_t Flags = MCID.TSFlags;
570   return (Flags & ARMII::HorizontalReduction) != 0;
571 }
572 
573 // Can this instruction generate a non-zero result when given only zeroed
574 // operands? This allows us to know that, given operands with false bytes
575 // zeroed by masked loads, that the result will also contain zeros in those
576 // bytes.
577 static bool canGenerateNonZeros(const MachineInstr &MI) {
578 
579   // Check for instructions which can write into a larger element size,
580   // possibly writing into a previous zero'd lane.
581   if (producesDoubleWidthResult(MI))
582     return true;
583 
584   switch (MI.getOpcode()) {
585   default:
586     break;
587   // FIXME: VNEG FP and -0? I think we'll need to handle this once we allow
588   // fp16 -> fp32 vector conversions.
589   // Instructions that perform a NOT will generate 1s from 0s.
590   case ARM::MVE_VMVN:
591   case ARM::MVE_VORN:
592   // Count leading zeros will do just that!
593   case ARM::MVE_VCLZs8:
594   case ARM::MVE_VCLZs16:
595   case ARM::MVE_VCLZs32:
596     return true;
597   }
598   return false;
599 }
600 
601 
602 // Look at its register uses to see if it only can only receive zeros
603 // into its false lanes which would then produce zeros. Also check that
604 // the output register is also defined by an FalseLanesZero instruction
605 // so that if tail-predication happens, the lanes that aren't updated will
606 // still be zeros.
607 static bool producesFalseLanesZero(MachineInstr &MI,
608                                    const TargetRegisterClass *QPRs,
609                                    const ReachingDefAnalysis &RDA,
610                                    InstSet &FalseLanesZero) {
611   if (canGenerateNonZeros(MI))
612     return false;
613 
614   bool AllowScalars = isHorizontalReduction(MI);
615   for (auto &MO : MI.operands()) {
616     if (!MO.isReg() || !MO.getReg())
617       continue;
618     if (!isRegInClass(MO, QPRs) && AllowScalars)
619       continue;
620     if (auto *OpDef = RDA.getMIOperand(&MI, MO))
621       if (FalseLanesZero.count(OpDef))
622        continue;
623     return false;
624   }
625   LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI);
626   return true;
627 }
628 
629 bool LowOverheadLoop::ValidateLiveOuts() const {
630   // We want to find out if the tail-predicated version of this loop will
631   // produce the same values as the loop in its original form. For this to
632   // be true, the newly inserted implicit predication must not change the
633   // the (observable) results.
634   // We're doing this because many instructions in the loop will not be
635   // predicated and so the conversion from VPT predication to tail-predication
636   // can result in different values being produced; due to the tail-predication
637   // preventing many instructions from updating their falsely predicated
638   // lanes. This analysis assumes that all the instructions perform lane-wise
639   // operations and don't perform any exchanges.
640   // A masked load, whether through VPT or tail predication, will write zeros
641   // to any of the falsely predicated bytes. So, from the loads, we know that
642   // the false lanes are zeroed and here we're trying to track that those false
643   // lanes remain zero, or where they change, the differences are masked away
644   // by their user(s).
645   // All MVE loads and stores have to be predicated, so we know that any load
646   // operands, or stored results are equivalent already. Other explicitly
647   // predicated instructions will perform the same operation in the original
648   // loop and the tail-predicated form too. Because of this, we can insert
649   // loads, stores and other predicated instructions into our Predicated
650   // set and build from there.
651   const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID);
652   SetVector<MachineInstr *> FalseLanesUnknown;
653   SmallPtrSet<MachineInstr *, 4> FalseLanesZero;
654   SmallPtrSet<MachineInstr *, 4> Predicated;
655   MachineBasicBlock *MBB = ML.getHeader();
656 
657   for (auto &MI : *MBB) {
658     const MCInstrDesc &MCID = MI.getDesc();
659     uint64_t Flags = MCID.TSFlags;
660     if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
661       continue;
662 
663     if (isVCTP(&MI) || isVPTOpcode(MI.getOpcode()))
664       continue;
665 
666     // Predicated loads will write zeros to the falsely predicated bytes of the
667     // destination register.
668     if (isVectorPredicated(&MI)) {
669       if (MI.mayLoad())
670         FalseLanesZero.insert(&MI);
671       Predicated.insert(&MI);
672       continue;
673     }
674 
675     if (MI.getNumDefs() == 0)
676       continue;
677 
678     if (!producesFalseLanesZero(MI, QPRs, RDA, FalseLanesZero)) {
679       // We require retaining and horizontal operations to operate upon zero'd
680       // false lanes to ensure the conversion doesn't change the output.
681       if (retainsPreviousHalfElement(MI) || isHorizontalReduction(MI))
682         return false;
683       // Otherwise we need to evaluate this instruction later to see whether
684       // unknown false lanes will get masked away by their user(s).
685       FalseLanesUnknown.insert(&MI);
686     } else if (!isHorizontalReduction(MI))
687       FalseLanesZero.insert(&MI);
688   }
689 
690   auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO,
691                               SmallPtrSetImpl<MachineInstr *> &Predicated) {
692     SmallPtrSet<MachineInstr *, 2> Uses;
693     RDA.getGlobalUses(MI, MO.getReg(), Uses);
694     for (auto *Use : Uses) {
695       if (Use != MI && !Predicated.count(Use))
696         return false;
697     }
698     return true;
699   };
700 
701   // Visit the unknowns in reverse so that we can start at the values being
702   // stored and then we can work towards the leaves, hopefully adding more
703   // instructions to Predicated. Successfully terminating the loop means that
704   // all the unknown values have to found to be masked by predicated user(s).
705   for (auto *MI : reverse(FalseLanesUnknown)) {
706     for (auto &MO : MI->operands()) {
707       if (!isRegInClass(MO, QPRs) || !MO.isDef())
708         continue;
709       if (!HasPredicatedUsers(MI, MO, Predicated)) {
710         LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : "
711                           << TRI.getRegAsmName(MO.getReg()) << " at " << *MI);
712         return false;
713       }
714     }
715     // Any unknown false lanes have been masked away by the user(s).
716     Predicated.insert(MI);
717   }
718 
719   // Collect Q-regs that are live in the exit blocks. We don't collect scalars
720   // because they won't be affected by lane predication.
721   SmallSet<Register, 2> LiveOuts;
722   SmallVector<MachineBasicBlock *, 2> ExitBlocks;
723   ML.getExitBlocks(ExitBlocks);
724   for (auto *MBB : ExitBlocks)
725     for (const MachineBasicBlock::RegisterMaskPair &RegMask : MBB->liveins())
726       if (QPRs->contains(RegMask.PhysReg))
727         LiveOuts.insert(RegMask.PhysReg);
728 
729   // Collect the instructions in the loop body that define the live-out values.
730   SmallPtrSet<MachineInstr *, 2> LiveMIs;
731   assert(ML.getNumBlocks() == 1 && "Expected single block loop!");
732   for (auto Reg : LiveOuts)
733     if (auto *MI = RDA.getLocalLiveOutMIDef(MBB, Reg))
734       LiveMIs.insert(MI);
735 
736   LLVM_DEBUG(dbgs() << "ARM Loops: Found loop live-outs:\n";
737              for (auto *MI : LiveMIs)
738                dbgs() << " - " << *MI);
739   // We've already validated that any VPT predication within the loop will be
740   // equivalent when we perform the predication transformation; so we know that
741   // any VPT predicated instruction is predicated upon VCTP. Any live-out
742   // instruction needs to be predicated, so check this here.
743   for (auto *MI : LiveMIs)
744     if (!isVectorPredicated(MI))
745       return false;
746 
747   return true;
748 }
749 
750 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) {
751   if (Revert)
752     return;
753 
754   if (!End->getOperand(1).isMBB())
755     report_fatal_error("Expected LoopEnd to target basic block");
756 
757   // TODO Maybe there's cases where the target doesn't have to be the header,
758   // but for now be safe and revert.
759   if (End->getOperand(1).getMBB() != ML.getHeader()) {
760     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
761     Revert = true;
762     return;
763   }
764 
765   // The WLS and LE instructions have 12-bits for the label offset. WLS
766   // requires a positive offset, while LE uses negative.
767   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML.getHeader()) ||
768       !BBUtils->isBBInRange(End, ML.getHeader(), 4094)) {
769     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
770     Revert = true;
771     return;
772   }
773 
774   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
775       (BBUtils->getOffsetOf(Start) >
776        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
777        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
778     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
779     Revert = true;
780     return;
781   }
782 
783   InsertPt = Revert ? nullptr : isSafeToDefineLR();
784   if (!InsertPt) {
785     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
786     Revert = true;
787     return;
788   } else
789     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
790 
791   if (!IsTailPredicationLegal()) {
792     LLVM_DEBUG(if (!VCTP)
793                  dbgs() << "ARM Loops: Didn't find a VCTP instruction.\n";
794                dbgs() << "ARM Loops: Tail-predication is not valid.\n");
795     return;
796   }
797 
798   assert(ML.getBlocks().size() == 1 &&
799          "Shouldn't be processing a loop with more than one block");
800   CannotTailPredicate = !ValidateTailPredicate(InsertPt);
801   LLVM_DEBUG(if (CannotTailPredicate)
802              dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
803 }
804 
805 bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) {
806   if (CannotTailPredicate)
807     return false;
808 
809   if (isVCTP(MI)) {
810     // If we find another VCTP, check whether it uses the same value as the main VCTP.
811     // If it does, store it in the SecondaryVCTPs set, else refuse it.
812     if (VCTP) {
813       if (!VCTP->getOperand(1).isIdenticalTo(MI->getOperand(1)) ||
814           !RDA.hasSameReachingDef(VCTP, MI, MI->getOperand(1).getReg())) {
815         LLVM_DEBUG(dbgs() << "ARM Loops: Found VCTP with a different reaching "
816                              "definition from the main VCTP");
817         return false;
818       }
819       LLVM_DEBUG(dbgs() << "ARM Loops: Found secondary VCTP: " << *MI);
820       SecondaryVCTPs.insert(MI);
821     } else {
822       LLVM_DEBUG(dbgs() << "ARM Loops: Found 'main' VCTP: " << *MI);
823       VCTP = MI;
824     }
825   } else if (isVPTOpcode(MI->getOpcode())) {
826     if (MI->getOpcode() != ARM::MVE_VPST) {
827       assert(MI->findRegisterDefOperandIdx(ARM::VPR) != -1 &&
828              "VPT does not implicitly define VPR?!");
829       CurrentPredicate.insert(MI);
830     }
831 
832     VPTBlocks.emplace_back(MI, CurrentPredicate);
833     CurrentBlock = &VPTBlocks.back();
834     return true;
835   } else if (MI->getOpcode() == ARM::MVE_VPSEL ||
836              MI->getOpcode() == ARM::MVE_VPNOT) {
837     // TODO: Allow VPSEL and VPNOT, we currently cannot because:
838     // 1) It will use the VPR as a predicate operand, but doesn't have to be
839     //    instead a VPT block, which means we can assert while building up
840     //    the VPT block because we don't find another VPT or VPST to being a new
841     //    one.
842     // 2) VPSEL still requires a VPR operand even after tail predicating,
843     //    which means we can't remove it unless there is another
844     //    instruction, such as vcmp, that can provide the VPR def.
845     return false;
846   }
847 
848   bool IsUse = false;
849   bool IsDef = false;
850   const MCInstrDesc &MCID = MI->getDesc();
851   for (int i = MI->getNumOperands() - 1; i >= 0; --i) {
852     const MachineOperand &MO = MI->getOperand(i);
853     if (!MO.isReg() || MO.getReg() != ARM::VPR)
854       continue;
855 
856     if (MO.isDef()) {
857       CurrentPredicate.insert(MI);
858       IsDef = true;
859     } else if (ARM::isVpred(MCID.OpInfo[i].OperandType)) {
860       CurrentBlock->addInst(MI, CurrentPredicate);
861       IsUse = true;
862     } else {
863       LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
864       return false;
865     }
866   }
867 
868   // If we find a vpr def that is not already predicated on the vctp, we've
869   // got disjoint predicates that may not be equivalent when we do the
870   // conversion.
871   if (IsDef && !IsUse && VCTP && !isVCTP(MI)) {
872     LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI);
873     return false;
874   }
875 
876   uint64_t Flags = MCID.TSFlags;
877   if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
878     return true;
879 
880   // If we find an instruction that has been marked as not valid for tail
881   // predication, only allow the instruction if it's contained within a valid
882   // VPT block.
883   if ((Flags & ARMII::ValidForTailPredication) == 0 && !IsUse) {
884     LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
885     return false;
886   }
887 
888   // If the instruction is already explicitly predicated, then the conversion
889   // will be fine, but ensure that all memory operations are predicated.
890   return !IsUse && MI->mayLoadOrStore() ? false : true;
891 }
892 
893 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
894   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
895   if (!ST.hasLOB())
896     return false;
897 
898   MF = &mf;
899   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
900 
901   MLI = &getAnalysis<MachineLoopInfo>();
902   RDA = &getAnalysis<ReachingDefAnalysis>();
903   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
904   MRI = &MF->getRegInfo();
905   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
906   TRI = ST.getRegisterInfo();
907   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
908   BBUtils->computeAllBlockSizes();
909   BBUtils->adjustBBOffsetsAfter(&MF->front());
910 
911   bool Changed = false;
912   for (auto ML : *MLI) {
913     if (!ML->getParentLoop())
914       Changed |= ProcessLoop(ML);
915   }
916   Changed |= RevertNonLoops();
917   return Changed;
918 }
919 
920 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
921 
922   bool Changed = false;
923 
924   // Process inner loops first.
925   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
926     Changed |= ProcessLoop(*I);
927 
928   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
929              if (auto *Preheader = ML->getLoopPreheader())
930                dbgs() << " - " << Preheader->getName() << "\n";
931              else if (auto *Preheader = MLI->findLoopPreheader(ML))
932                dbgs() << " - " << Preheader->getName() << "\n";
933              else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
934                dbgs() << " - " << Preheader->getName() << "\n";
935              for (auto *MBB : ML->getBlocks())
936                dbgs() << " - " << MBB->getName() << "\n";
937             );
938 
939   // Search the given block for a loop start instruction. If one isn't found,
940   // and there's only one predecessor block, search that one too.
941   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
942     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
943     for (auto &MI : *MBB) {
944       if (isLoopStart(MI))
945         return &MI;
946     }
947     if (MBB->pred_size() == 1)
948       return SearchForStart(*MBB->pred_begin());
949     return nullptr;
950   };
951 
952   LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI);
953   // Search the preheader for the start intrinsic.
954   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
955   // with potentially multiple set.loop.iterations, so we need to enable this.
956   if (auto *Preheader = ML->getLoopPreheader())
957     LoLoop.Start = SearchForStart(Preheader);
958   else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
959     LoLoop.Start = SearchForStart(Preheader);
960   else
961     return false;
962 
963   // Find the low-overhead loop components and decide whether or not to fall
964   // back to a normal loop. Also look for a vctp instructions and decide
965   // whether we can convert that predicate using tail predication.
966   for (auto *MBB : reverse(ML->getBlocks())) {
967     for (auto &MI : *MBB) {
968       if (MI.isDebugValue())
969         continue;
970       else if (MI.getOpcode() == ARM::t2LoopDec)
971         LoLoop.Dec = &MI;
972       else if (MI.getOpcode() == ARM::t2LoopEnd)
973         LoLoop.End = &MI;
974       else if (isLoopStart(MI))
975         LoLoop.Start = &MI;
976       else if (MI.getDesc().isCall()) {
977         // TODO: Though the call will require LE to execute again, does this
978         // mean we should revert? Always executing LE hopefully should be
979         // faster than performing a sub,cmp,br or even subs,br.
980         LoLoop.Revert = true;
981         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
982       } else {
983         // Record VPR defs and build up their corresponding vpt blocks.
984         // Check we know how to tail predicate any mve instructions.
985         LoLoop.AnalyseMVEInst(&MI);
986       }
987     }
988   }
989 
990   LLVM_DEBUG(LoLoop.dump());
991   if (!LoLoop.FoundAllComponents()) {
992     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find loop start, update, end\n");
993     return false;
994   }
995 
996   // Check that the only instruction using LoopDec is LoopEnd.
997   // TODO: Check for copy chains that really have no effect.
998   SmallPtrSet<MachineInstr*, 2> Uses;
999   RDA->getReachingLocalUses(LoLoop.Dec, ARM::LR, Uses);
1000   if (Uses.size() > 1 || !Uses.count(LoLoop.End)) {
1001     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n");
1002     LoLoop.Revert = true;
1003   }
1004   LoLoop.CheckLegality(BBUtils.get());
1005   Expand(LoLoop);
1006   return true;
1007 }
1008 
1009 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
1010 // beq that branches to the exit branch.
1011 // TODO: We could also try to generate a cbz if the value in LR is also in
1012 // another low register.
1013 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
1014   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
1015   MachineBasicBlock *MBB = MI->getParent();
1016   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1017                                     TII->get(ARM::t2CMPri));
1018   MIB.add(MI->getOperand(0));
1019   MIB.addImm(0);
1020   MIB.addImm(ARMCC::AL);
1021   MIB.addReg(ARM::NoRegister);
1022 
1023   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1024   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1025     ARM::tBcc : ARM::t2Bcc;
1026 
1027   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1028   MIB.add(MI->getOperand(1));   // branch target
1029   MIB.addImm(ARMCC::EQ);        // condition code
1030   MIB.addReg(ARM::CPSR);
1031   MI->eraseFromParent();
1032 }
1033 
1034 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const {
1035   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
1036   MachineBasicBlock *MBB = MI->getParent();
1037   SmallPtrSet<MachineInstr*, 1> Ignore;
1038   for (auto I = MachineBasicBlock::iterator(MI), E = MBB->end(); I != E; ++I) {
1039     if (I->getOpcode() == ARM::t2LoopEnd) {
1040       Ignore.insert(&*I);
1041       break;
1042     }
1043   }
1044 
1045   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
1046   bool SetFlags = RDA->isSafeToDefRegAt(MI, ARM::CPSR, Ignore);
1047 
1048   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1049                                     TII->get(ARM::t2SUBri));
1050   MIB.addDef(ARM::LR);
1051   MIB.add(MI->getOperand(1));
1052   MIB.add(MI->getOperand(2));
1053   MIB.addImm(ARMCC::AL);
1054   MIB.addReg(0);
1055 
1056   if (SetFlags) {
1057     MIB.addReg(ARM::CPSR);
1058     MIB->getOperand(5).setIsDef(true);
1059   } else
1060     MIB.addReg(0);
1061 
1062   MI->eraseFromParent();
1063   return SetFlags;
1064 }
1065 
1066 // Generate a subs, or sub and cmp, and a branch instead of an LE.
1067 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
1068   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
1069 
1070   MachineBasicBlock *MBB = MI->getParent();
1071   // Create cmp
1072   if (!SkipCmp) {
1073     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1074                                       TII->get(ARM::t2CMPri));
1075     MIB.addReg(ARM::LR);
1076     MIB.addImm(0);
1077     MIB.addImm(ARMCC::AL);
1078     MIB.addReg(ARM::NoRegister);
1079   }
1080 
1081   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1082   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1083     ARM::tBcc : ARM::t2Bcc;
1084 
1085   // Create bne
1086   MachineInstrBuilder MIB =
1087     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1088   MIB.add(MI->getOperand(1));   // branch target
1089   MIB.addImm(ARMCC::NE);        // condition code
1090   MIB.addReg(ARM::CPSR);
1091   MI->eraseFromParent();
1092 }
1093 
1094 // Perform dead code elimation on the loop iteration count setup expression.
1095 // If we are tail-predicating, the number of elements to be processed is the
1096 // operand of the VCTP instruction in the vector body, see getCount(), which is
1097 // register $r3 in this example:
1098 //
1099 //   $lr = big-itercount-expression
1100 //   ..
1101 //   t2DoLoopStart renamable $lr
1102 //   vector.body:
1103 //     ..
1104 //     $vpr = MVE_VCTP32 renamable $r3
1105 //     renamable $lr = t2LoopDec killed renamable $lr, 1
1106 //     t2LoopEnd renamable $lr, %vector.body
1107 //     tB %end
1108 //
1109 // What we would like achieve here is to replace the do-loop start pseudo
1110 // instruction t2DoLoopStart with:
1111 //
1112 //    $lr = MVE_DLSTP_32 killed renamable $r3
1113 //
1114 // Thus, $r3 which defines the number of elements, is written to $lr,
1115 // and then we want to delete the whole chain that used to define $lr,
1116 // see the comment below how this chain could look like.
1117 //
1118 void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
1119   if (!LoLoop.IsTailPredicationLegal())
1120     return;
1121 
1122   LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");
1123 
1124   MachineInstr *Def = RDA->getMIOperand(LoLoop.Start, 0);
1125   if (!Def) {
1126     LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
1127     return;
1128   }
1129 
1130   // Collect and remove the users of iteration count.
1131   SmallPtrSet<MachineInstr*, 4> Killed  = { LoLoop.Start, LoLoop.Dec,
1132                                             LoLoop.End, LoLoop.InsertPt };
1133   SmallPtrSet<MachineInstr*, 2> Remove;
1134   if (RDA->isSafeToRemove(Def, Remove, Killed))
1135     LoLoop.ToRemove.insert(Remove.begin(), Remove.end());
1136   else {
1137     LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
1138     return;
1139   }
1140 
1141   // Collect the dead code and the MBBs in which they reside.
1142   RDA->collectKilledOperands(Def, Killed);
1143   SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
1144   for (auto *MI : Killed)
1145     BasicBlocks.insert(MI->getParent());
1146 
1147   // Collect IT blocks in all affected basic blocks.
1148   std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
1149   for (auto *MBB : BasicBlocks) {
1150     for (auto &MI : *MBB) {
1151       if (MI.getOpcode() != ARM::t2IT)
1152         continue;
1153       RDA->getReachingLocalUses(&MI, ARM::ITSTATE, ITBlocks[&MI]);
1154     }
1155   }
1156 
1157   // If we're removing all of the instructions within an IT block, then
1158   // also remove the IT instruction.
1159   SmallPtrSet<MachineInstr*, 2> ModifiedITs;
1160   for (auto *MI : Killed) {
1161     if (MachineOperand *MO = MI->findRegisterUseOperand(ARM::ITSTATE)) {
1162       MachineInstr *IT = RDA->getMIOperand(MI, *MO);
1163       auto &CurrentBlock = ITBlocks[IT];
1164       CurrentBlock.erase(MI);
1165       if (CurrentBlock.empty())
1166         ModifiedITs.erase(IT);
1167       else
1168         ModifiedITs.insert(IT);
1169     }
1170   }
1171 
1172   // Delete the killed instructions only if we don't have any IT blocks that
1173   // need to be modified because we need to fixup the mask.
1174   // TODO: Handle cases where IT blocks are modified.
1175   if (ModifiedITs.empty()) {
1176     LLVM_DEBUG(dbgs() << "ARM Loops: Will remove iteration count:\n";
1177                for (auto *MI : Killed)
1178                  dbgs() << " - " << *MI);
1179     LoLoop.ToRemove.insert(Killed.begin(), Killed.end());
1180   } else
1181     LLVM_DEBUG(dbgs() << "ARM Loops: Would need to modify IT block(s).\n");
1182 }
1183 
1184 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
1185   LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n");
1186   // When using tail-predication, try to delete the dead code that was used to
1187   // calculate the number of loop iterations.
1188   IterationCountDCE(LoLoop);
1189 
1190   MachineInstr *InsertPt = LoLoop.InsertPt;
1191   MachineInstr *Start = LoLoop.Start;
1192   MachineBasicBlock *MBB = InsertPt->getParent();
1193   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
1194   unsigned Opc = LoLoop.getStartOpcode();
1195   MachineOperand &Count = LoLoop.getCount();
1196 
1197   MachineInstrBuilder MIB =
1198     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
1199 
1200   MIB.addDef(ARM::LR);
1201   MIB.add(Count);
1202   if (!IsDo)
1203     MIB.add(Start->getOperand(1));
1204 
1205   // If we're inserting at a mov lr, then remove it as it's redundant.
1206   if (InsertPt != Start)
1207     LoLoop.ToRemove.insert(InsertPt);
1208   LoLoop.ToRemove.insert(Start);
1209   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
1210   return &*MIB;
1211 }
1212 
1213 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
1214   auto RemovePredicate = [](MachineInstr *MI) {
1215     LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
1216     if (int PIdx = llvm::findFirstVPTPredOperandIdx(*MI)) {
1217       assert(MI->getOperand(PIdx).getImm() == ARMVCC::Then &&
1218              "Expected Then predicate!");
1219       MI->getOperand(PIdx).setImm(ARMVCC::None);
1220       MI->getOperand(PIdx+1).setReg(0);
1221     } else
1222       llvm_unreachable("trying to unpredicate a non-predicated instruction");
1223   };
1224 
1225   // There are a few scenarios which we have to fix up:
1226   // 1. VPT Blocks with non-uniform predicates:
1227   //    - a. When the divergent instruction is a vctp
1228   //    - b. When the block uses a vpst, and is only predicated on the vctp
1229   //    - c. When the block uses a vpt and (optionally) contains one or more
1230   //         vctp.
1231   // 2. VPT Blocks with uniform predicates:
1232   //    - a. The block uses a vpst, and is only predicated on the vctp
1233   for (auto &Block : LoLoop.getVPTBlocks()) {
1234     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
1235     if (Block.HasNonUniformPredicate()) {
1236       PredicatedMI *Divergent = Block.getDivergent();
1237       if (isVCTP(Divergent->MI)) {
1238         // The vctp will be removed, so the block mask of the vp(s)t will need
1239         // to be recomputed.
1240         LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen());
1241       } else if (Block.isVPST() && Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
1242         // The VPT block has a non-uniform predicate but it uses a vpst and its
1243         // entry is guarded only by a vctp, which means we:
1244         // - Need to remove the original vpst.
1245         // - Then need to unpredicate any following instructions, until
1246         //   we come across the divergent vpr def.
1247         // - Insert a new vpst to predicate the instruction(s) that following
1248         //   the divergent vpr def.
1249         // TODO: We could be producing more VPT blocks than necessary and could
1250         // fold the newly created one into a proceeding one.
1251         for (auto I = ++MachineBasicBlock::iterator(Block.getPredicateThen()),
1252              E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
1253           RemovePredicate(&*I);
1254 
1255         unsigned Size = 0;
1256         auto E = MachineBasicBlock::reverse_iterator(Divergent->MI);
1257         auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI);
1258         MachineInstr *InsertAt = nullptr;
1259         while (I != E) {
1260           InsertAt = &*I;
1261           ++Size;
1262           ++I;
1263         }
1264         // Create a VPST (with a null mask for now, we'll recompute it later).
1265         MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
1266                                           InsertAt->getDebugLoc(),
1267                                           TII->get(ARM::MVE_VPST));
1268         MIB.addImm(0);
1269         LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getPredicateThen());
1270         LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
1271         LoLoop.ToRemove.insert(Block.getPredicateThen());
1272         LoLoop.BlockMasksToRecompute.insert(MIB.getInstr());
1273       }
1274       // Else, if the block uses a vpt, iterate over the block, removing the
1275       // extra VCTPs it may contain.
1276       else if (Block.isVPT()) {
1277         bool RemovedVCTP = false;
1278         for (PredicatedMI &Elt : Block.getInsts()) {
1279           MachineInstr *MI = Elt.MI;
1280           if (isVCTP(MI)) {
1281             LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *MI);
1282             LoLoop.ToRemove.insert(MI);
1283             RemovedVCTP = true;
1284             continue;
1285           }
1286         }
1287         if (RemovedVCTP)
1288           LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen());
1289       }
1290     } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP) && Block.isVPST()) {
1291       // A vpt block starting with VPST, is only predicated upon vctp and has no
1292       // internal vpr defs:
1293       // - Remove vpst.
1294       // - Unpredicate the remaining instructions.
1295       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getPredicateThen());
1296       LoLoop.ToRemove.insert(Block.getPredicateThen());
1297       for (auto &PredMI : Insts)
1298         RemovePredicate(PredMI.MI);
1299     }
1300   }
1301   LLVM_DEBUG(dbgs() << "ARM Loops: Removing remaining VCTPs...\n");
1302   // Remove the "main" VCTP
1303   LoLoop.ToRemove.insert(LoLoop.VCTP);
1304   LLVM_DEBUG(dbgs() << "    " << *LoLoop.VCTP);
1305   // Remove remaining secondary VCTPs
1306   for (MachineInstr *VCTP : LoLoop.SecondaryVCTPs) {
1307     // All VCTPs that aren't marked for removal yet should be unpredicated ones.
1308     // The predicated ones should have already been marked for removal when
1309     // visiting the VPT blocks.
1310     if (LoLoop.ToRemove.insert(VCTP).second) {
1311       assert(getVPTInstrPredicate(*VCTP) == ARMVCC::None &&
1312              "Removing Predicated VCTP without updating the block mask!");
1313       LLVM_DEBUG(dbgs() << "    " << *VCTP);
1314     }
1315   }
1316 }
1317 
1318 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
1319 
1320   // Combine the LoopDec and LoopEnd instructions into LE(TP).
1321   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
1322     MachineInstr *End = LoLoop.End;
1323     MachineBasicBlock *MBB = End->getParent();
1324     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
1325       ARM::MVE_LETP : ARM::t2LEUpdate;
1326     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
1327                                       TII->get(Opc));
1328     MIB.addDef(ARM::LR);
1329     MIB.add(End->getOperand(0));
1330     MIB.add(End->getOperand(1));
1331     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
1332     LoLoop.ToRemove.insert(LoLoop.Dec);
1333     LoLoop.ToRemove.insert(End);
1334     return &*MIB;
1335   };
1336 
1337   // TODO: We should be able to automatically remove these branches before we
1338   // get here - probably by teaching analyzeBranch about the pseudo
1339   // instructions.
1340   // If there is an unconditional branch, after I, that just branches to the
1341   // next block, remove it.
1342   auto RemoveDeadBranch = [](MachineInstr *I) {
1343     MachineBasicBlock *BB = I->getParent();
1344     MachineInstr *Terminator = &BB->instr_back();
1345     if (Terminator->isUnconditionalBranch() && I != Terminator) {
1346       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
1347       if (BB->isLayoutSuccessor(Succ)) {
1348         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
1349         Terminator->eraseFromParent();
1350       }
1351     }
1352   };
1353 
1354   if (LoLoop.Revert) {
1355     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
1356       RevertWhile(LoLoop.Start);
1357     else
1358       LoLoop.Start->eraseFromParent();
1359     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec);
1360     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
1361   } else {
1362     LoLoop.Start = ExpandLoopStart(LoLoop);
1363     RemoveDeadBranch(LoLoop.Start);
1364     LoLoop.End = ExpandLoopEnd(LoLoop);
1365     RemoveDeadBranch(LoLoop.End);
1366     if (LoLoop.IsTailPredicationLegal())
1367       ConvertVPTBlocks(LoLoop);
1368     for (auto *I : LoLoop.ToRemove) {
1369       LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
1370       I->eraseFromParent();
1371     }
1372     for (auto *I : LoLoop.BlockMasksToRecompute) {
1373       LLVM_DEBUG(dbgs() << "ARM Loops: Recomputing VPT/VPST Block Mask: " << *I);
1374       recomputeVPTBlockMask(*I);
1375       LLVM_DEBUG(dbgs() << "           ... done: " << *I);
1376     }
1377   }
1378 
1379   PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
1380   DFS.ProcessLoop();
1381   const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
1382   for (auto *MBB : PostOrder) {
1383     recomputeLiveIns(*MBB);
1384     // FIXME: For some reason, the live-in print order is non-deterministic for
1385     // our tests and I can't out why... So just sort them.
1386     MBB->sortUniqueLiveIns();
1387   }
1388 
1389   for (auto *MBB : reverse(PostOrder))
1390     recomputeLivenessFlags(*MBB);
1391 
1392   // We've moved, removed and inserted new instructions, so update RDA.
1393   RDA->reset();
1394 }
1395 
1396 bool ARMLowOverheadLoops::RevertNonLoops() {
1397   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
1398   bool Changed = false;
1399 
1400   for (auto &MBB : *MF) {
1401     SmallVector<MachineInstr*, 4> Starts;
1402     SmallVector<MachineInstr*, 4> Decs;
1403     SmallVector<MachineInstr*, 4> Ends;
1404 
1405     for (auto &I : MBB) {
1406       if (isLoopStart(I))
1407         Starts.push_back(&I);
1408       else if (I.getOpcode() == ARM::t2LoopDec)
1409         Decs.push_back(&I);
1410       else if (I.getOpcode() == ARM::t2LoopEnd)
1411         Ends.push_back(&I);
1412     }
1413 
1414     if (Starts.empty() && Decs.empty() && Ends.empty())
1415       continue;
1416 
1417     Changed = true;
1418 
1419     for (auto *Start : Starts) {
1420       if (Start->getOpcode() == ARM::t2WhileLoopStart)
1421         RevertWhile(Start);
1422       else
1423         Start->eraseFromParent();
1424     }
1425     for (auto *Dec : Decs)
1426       RevertLoopDec(Dec);
1427 
1428     for (auto *End : Ends)
1429       RevertLoopEnd(End);
1430   }
1431   return Changed;
1432 }
1433 
/// Factory for the ARM low-overhead loops pass: returns a new heap-allocated
/// ARMLowOverheadLoops instance.
FunctionPass *llvm::createARMLowOverheadLoopsPass() {
  return new ARMLowOverheadLoops();
}
1437