1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the GCNRegPressure class.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "GCNRegPressure.h"
15 #include "AMDGPUSubtarget.h"
16 #include "SIRegisterInfo.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/CodeGen/LiveInterval.h"
19 #include "llvm/CodeGen/LiveIntervals.h"
20 #include "llvm/CodeGen/MachineInstr.h"
21 #include "llvm/CodeGen/MachineOperand.h"
22 #include "llvm/CodeGen/MachineRegisterInfo.h"
23 #include "llvm/CodeGen/RegisterPressure.h"
24 #include "llvm/CodeGen/SlotIndexes.h"
25 #include "llvm/CodeGen/TargetRegisterInfo.h"
26 #include "llvm/Config/llvm-config.h"
27 #include "llvm/MC/LaneBitmask.h"
28 #include "llvm/Support/Compiler.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <algorithm>
33 #include <cassert>
34 
35 using namespace llvm;
36 
37 #define DEBUG_TYPE "machine-scheduler"
38 
39 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
40 LLVM_DUMP_METHOD
41 void llvm::printLivesAt(SlotIndex SI,
42                         const LiveIntervals &LIS,
43                         const MachineRegisterInfo &MRI) {
44   dbgs() << "Live regs at " << SI << ": "
45          << *LIS.getInstructionFromIndex(SI);
46   unsigned Num = 0;
47   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
48     const unsigned Reg = Register::index2VirtReg(I);
49     if (!LIS.hasInterval(Reg))
50       continue;
51     const auto &LI = LIS.getInterval(Reg);
52     if (LI.hasSubRanges()) {
53       bool firstTime = true;
54       for (const auto &S : LI.subranges()) {
55         if (!S.liveAt(SI)) continue;
56         if (firstTime) {
57           dbgs() << "  " << printReg(Reg, MRI.getTargetRegisterInfo())
58                  << '\n';
59           firstTime = false;
60         }
61         dbgs() << "  " << S << '\n';
62         ++Num;
63       }
64     } else if (LI.liveAt(SI)) {
65       dbgs() << "  " << LI << '\n';
66       ++Num;
67     }
68   }
69   if (!Num) dbgs() << "  <none>\n";
70 }
71 #endif
72 
73 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
74                    const GCNRPTracker::LiveRegSet &S2) {
75   if (S1.size() != S2.size())
76     return false;
77 
78   for (const auto &P : S1) {
79     auto I = S2.find(P.first);
80     if (I == S2.end() || I->second != P.second)
81       return false;
82   }
83   return true;
84 }
85 
86 
87 ///////////////////////////////////////////////////////////////////////////////
88 // GCNRegPressure
89 
90 unsigned GCNRegPressure::getRegKind(unsigned Reg,
91                                     const MachineRegisterInfo &MRI) {
92   assert(Register::isVirtualRegister(Reg));
93   const auto RC = MRI.getRegClass(Reg);
94   auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
95   return STI->isSGPRClass(RC) ?
96     (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
97     STI->hasAGPRs(RC) ?
98       (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) :
99       (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
100 }
101 
102 void GCNRegPressure::inc(unsigned Reg,
103                          LaneBitmask PrevMask,
104                          LaneBitmask NewMask,
105                          const MachineRegisterInfo &MRI) {
106   if (NewMask == PrevMask)
107     return;
108 
109   int Sign = 1;
110   if (NewMask < PrevMask) {
111     std::swap(NewMask, PrevMask);
112     Sign = -1;
113   }
114 #ifndef NDEBUG
115   const auto MaxMask = MRI.getMaxLaneMaskForVReg(Reg);
116 #endif
117   switch (auto Kind = getRegKind(Reg, MRI)) {
118   case SGPR32:
119   case VGPR32:
120   case AGPR32:
121     assert(PrevMask.none() && NewMask == MaxMask);
122     Value[Kind] += Sign;
123     break;
124 
125   case SGPR_TUPLE:
126   case VGPR_TUPLE:
127   case AGPR_TUPLE:
128     assert(NewMask < MaxMask || NewMask == MaxMask);
129     assert(PrevMask < NewMask);
130 
131     Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
132       Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
133 
134     if (PrevMask.none()) {
135       assert(NewMask.any());
136       Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
137     }
138     break;
139 
140   default: llvm_unreachable("Unknown register kind");
141   }
142 }
143 
144 bool GCNRegPressure::less(const GCNSubtarget &ST,
145                           const GCNRegPressure& O,
146                           unsigned MaxOccupancy) const {
147   const auto SGPROcc = std::min(MaxOccupancy,
148                                 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
149   const auto VGPROcc = std::min(MaxOccupancy,
150                                 ST.getOccupancyWithNumVGPRs(getVGPRNum()));
151   const auto OtherSGPROcc = std::min(MaxOccupancy,
152                                 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
153   const auto OtherVGPROcc = std::min(MaxOccupancy,
154                                 ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
155 
156   const auto Occ = std::min(SGPROcc, VGPROcc);
157   const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
158   if (Occ != OtherOcc)
159     return Occ > OtherOcc;
160 
161   bool SGPRImportant = SGPROcc < VGPROcc;
162   const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
163 
164   // if both pressures disagree on what is more important compare vgprs
165   if (SGPRImportant != OtherSGPRImportant) {
166     SGPRImportant = false;
167   }
168 
169   // compare large regs pressure
170   bool SGPRFirst = SGPRImportant;
171   for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
172     if (SGPRFirst) {
173       auto SW = getSGPRTuplesWeight();
174       auto OtherSW = O.getSGPRTuplesWeight();
175       if (SW != OtherSW)
176         return SW < OtherSW;
177     } else {
178       auto VW = getVGPRTuplesWeight();
179       auto OtherVW = O.getVGPRTuplesWeight();
180       if (VW != OtherVW)
181         return VW < OtherVW;
182     }
183   }
184   return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
185                          (getVGPRNum() < O.getVGPRNum());
186 }
187 
188 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
189 LLVM_DUMP_METHOD
190 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
191   OS << "VGPRs: " << Value[VGPR32] << ' ';
192   OS << "AGPRs: " << Value[AGPR32];
193   if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
194   OS << ", SGPRs: " << getSGPRNum();
195   if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
196   OS << ", LVGPR WT: " << getVGPRTuplesWeight()
197      << ", LSGPR WT: " << getSGPRTuplesWeight();
198   if (ST) OS << " -> Occ: " << getOccupancy(*ST);
199   OS << '\n';
200 }
201 #endif
202 
203 static LaneBitmask getDefRegMask(const MachineOperand &MO,
204                                  const MachineRegisterInfo &MRI) {
205   assert(MO.isDef() && MO.isReg() && Register::isVirtualRegister(MO.getReg()));
206 
207   // We don't rely on read-undef flag because in case of tentative schedule
208   // tracking it isn't set correctly yet. This works correctly however since
209   // use mask has been tracked before using LIS.
210   return MO.getSubReg() == 0 ?
211     MRI.getMaxLaneMaskForVReg(MO.getReg()) :
212     MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
213 }
214 
215 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
216                                   const MachineRegisterInfo &MRI,
217                                   const LiveIntervals &LIS) {
218   assert(MO.isUse() && MO.isReg() && Register::isVirtualRegister(MO.getReg()));
219 
220   if (auto SubReg = MO.getSubReg())
221     return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
222 
223   auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
224   if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs
225     return MaxMask;
226 
227   // For a tentative schedule LIS isn't updated yet but livemask should remain
228   // the same on any schedule. Subreg defs can be reordered but they all must
229   // dominate uses anyway.
230   auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
231   return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
232 }
233 
234 static SmallVector<RegisterMaskPair, 8>
235 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
236                       const MachineRegisterInfo &MRI) {
237   SmallVector<RegisterMaskPair, 8> Res;
238   for (const auto &MO : MI.operands()) {
239     if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg()))
240       continue;
241     if (!MO.isUse() || !MO.readsReg())
242       continue;
243 
244     auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
245 
246     auto Reg = MO.getReg();
247     auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) {
248       return RM.RegUnit == Reg;
249     });
250     if (I != Res.end())
251       I->LaneMask |= UsedMask;
252     else
253       Res.push_back(RegisterMaskPair(Reg, UsedMask));
254   }
255   return Res;
256 }
257 
258 ///////////////////////////////////////////////////////////////////////////////
259 // GCNRPTracker
260 
261 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
262                                   SlotIndex SI,
263                                   const LiveIntervals &LIS,
264                                   const MachineRegisterInfo &MRI) {
265   LaneBitmask LiveMask;
266   const auto &LI = LIS.getInterval(Reg);
267   if (LI.hasSubRanges()) {
268     for (const auto &S : LI.subranges())
269       if (S.liveAt(SI)) {
270         LiveMask |= S.LaneMask;
271         assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
272                LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
273       }
274   } else if (LI.liveAt(SI)) {
275     LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
276   }
277   return LiveMask;
278 }
279 
280 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
281                                            const LiveIntervals &LIS,
282                                            const MachineRegisterInfo &MRI) {
283   GCNRPTracker::LiveRegSet LiveRegs;
284   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
285     auto Reg = Register::index2VirtReg(I);
286     if (!LIS.hasInterval(Reg))
287       continue;
288     auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
289     if (LiveMask.any())
290       LiveRegs[Reg] = LiveMask;
291   }
292   return LiveRegs;
293 }
294 
295 void GCNRPTracker::reset(const MachineInstr &MI,
296                          const LiveRegSet *LiveRegsCopy,
297                          bool After) {
298   const MachineFunction &MF = *MI.getMF();
299   MRI = &MF.getRegInfo();
300   if (LiveRegsCopy) {
301     if (&LiveRegs != LiveRegsCopy)
302       LiveRegs = *LiveRegsCopy;
303   } else {
304     LiveRegs = After ? getLiveRegsAfter(MI, LIS)
305                      : getLiveRegsBefore(MI, LIS);
306   }
307 
308   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
309 }
310 
311 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
312                                const LiveRegSet *LiveRegsCopy) {
313   GCNRPTracker::reset(MI, LiveRegsCopy, true);
314 }
315 
316 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
317   assert(MRI && "call reset first");
318 
319   LastTrackedMI = &MI;
320 
321   if (MI.isDebugInstr())
322     return;
323 
324   auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
325 
326   // calc pressure at the MI (defs + uses)
327   auto AtMIPressure = CurPressure;
328   for (const auto &U : RegUses) {
329     auto LiveMask = LiveRegs[U.RegUnit];
330     AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
331   }
332   // update max pressure
333   MaxPressure = max(AtMIPressure, MaxPressure);
334 
335   for (const auto &MO : MI.operands()) {
336     if (!MO.isReg() || !MO.isDef() ||
337         !Register::isVirtualRegister(MO.getReg()) || MO.isDead())
338       continue;
339 
340     auto Reg = MO.getReg();
341     auto I = LiveRegs.find(Reg);
342     if (I == LiveRegs.end())
343       continue;
344     auto &LiveMask = I->second;
345     auto PrevMask = LiveMask;
346     LiveMask &= ~getDefRegMask(MO, *MRI);
347     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
348     if (LiveMask.none())
349       LiveRegs.erase(I);
350   }
351   for (const auto &U : RegUses) {
352     auto &LiveMask = LiveRegs[U.RegUnit];
353     auto PrevMask = LiveMask;
354     LiveMask |= U.LaneMask;
355     CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
356   }
357   assert(CurPressure == getRegPressure(*MRI, LiveRegs));
358 }
359 
360 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
361                                  const LiveRegSet *LiveRegsCopy) {
362   MRI = &MI.getParent()->getParent()->getRegInfo();
363   LastTrackedMI = nullptr;
364   MBBEnd = MI.getParent()->end();
365   NextMI = &MI;
366   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
367   if (NextMI == MBBEnd)
368     return false;
369   GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
370   return true;
371 }
372 
373 bool GCNDownwardRPTracker::advanceBeforeNext() {
374   assert(MRI && "call reset first");
375 
376   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
377   if (NextMI == MBBEnd)
378     return false;
379 
380   SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
381   assert(SI.isValid());
382 
383   // Remove dead registers or mask bits.
384   for (auto &It : LiveRegs) {
385     const LiveInterval &LI = LIS.getInterval(It.first);
386     if (LI.hasSubRanges()) {
387       for (const auto &S : LI.subranges()) {
388         if (!S.liveAt(SI)) {
389           auto PrevMask = It.second;
390           It.second &= ~S.LaneMask;
391           CurPressure.inc(It.first, PrevMask, It.second, *MRI);
392         }
393       }
394     } else if (!LI.liveAt(SI)) {
395       auto PrevMask = It.second;
396       It.second = LaneBitmask::getNone();
397       CurPressure.inc(It.first, PrevMask, It.second, *MRI);
398     }
399     if (It.second.none())
400       LiveRegs.erase(It.first);
401   }
402 
403   MaxPressure = max(MaxPressure, CurPressure);
404 
405   return true;
406 }
407 
408 void GCNDownwardRPTracker::advanceToNext() {
409   LastTrackedMI = &*NextMI++;
410 
411   // Add new registers or mask bits.
412   for (const auto &MO : LastTrackedMI->operands()) {
413     if (!MO.isReg() || !MO.isDef())
414       continue;
415     Register Reg = MO.getReg();
416     if (!Register::isVirtualRegister(Reg))
417       continue;
418     auto &LiveMask = LiveRegs[Reg];
419     auto PrevMask = LiveMask;
420     LiveMask |= getDefRegMask(MO, *MRI);
421     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
422   }
423 
424   MaxPressure = max(MaxPressure, CurPressure);
425 }
426 
427 bool GCNDownwardRPTracker::advance() {
428   // If we have just called reset live set is actual.
429   if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
430     return false;
431   advanceToNext();
432   return true;
433 }
434 
435 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
436   while (NextMI != End)
437     if (!advance()) return false;
438   return true;
439 }
440 
441 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
442                                    MachineBasicBlock::const_iterator End,
443                                    const LiveRegSet *LiveRegsCopy) {
444   reset(*Begin, LiveRegsCopy);
445   return advance(End);
446 }
447 
448 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
449 LLVM_DUMP_METHOD
450 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
451                            const GCNRPTracker::LiveRegSet &TrackedLR,
452                            const TargetRegisterInfo *TRI) {
453   for (auto const &P : TrackedLR) {
454     auto I = LISLR.find(P.first);
455     if (I == LISLR.end()) {
456       dbgs() << "  " << printReg(P.first, TRI)
457              << ":L" << PrintLaneMask(P.second)
458              << " isn't found in LIS reported set\n";
459     }
460     else if (I->second != P.second) {
461       dbgs() << "  " << printReg(P.first, TRI)
462         << " masks doesn't match: LIS reported "
463         << PrintLaneMask(I->second)
464         << ", tracked "
465         << PrintLaneMask(P.second)
466         << '\n';
467     }
468   }
469   for (auto const &P : LISLR) {
470     auto I = TrackedLR.find(P.first);
471     if (I == TrackedLR.end()) {
472       dbgs() << "  " << printReg(P.first, TRI)
473              << ":L" << PrintLaneMask(P.second)
474              << " isn't found in tracked set\n";
475     }
476   }
477 }
478 
479 bool GCNUpwardRPTracker::isValid() const {
480   const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
481   const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
482   const auto &TrackedLR = LiveRegs;
483 
484   if (!isEqual(LISLR, TrackedLR)) {
485     dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
486               " LIS reported livesets mismatch:\n";
487     printLivesAt(SI, LIS, *MRI);
488     reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
489     return false;
490   }
491 
492   auto LISPressure = getRegPressure(*MRI, LISLR);
493   if (LISPressure != CurPressure) {
494     dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
495     CurPressure.print(dbgs());
496     dbgs() << "LIS rpt: ";
497     LISPressure.print(dbgs());
498     return false;
499   }
500   return true;
501 }
502 
503 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
504                                  const MachineRegisterInfo &MRI) {
505   const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
506   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
507     unsigned Reg = Register::index2VirtReg(I);
508     auto It = LiveRegs.find(Reg);
509     if (It != LiveRegs.end() && It->second.any())
510       OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
511          << PrintLaneMask(It->second);
512   }
513   OS << '\n';
514 }
515 #endif
516