1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "GCNRegPressure.h"
10 #include "AMDGPUSubtarget.h"
11 #include "SIRegisterInfo.h"
12 #include "llvm/ADT/SmallVector.h"
13 #include "llvm/CodeGen/LiveInterval.h"
14 #include "llvm/CodeGen/LiveIntervals.h"
15 #include "llvm/CodeGen/MachineInstr.h"
16 #include "llvm/CodeGen/MachineOperand.h"
17 #include "llvm/CodeGen/MachineRegisterInfo.h"
18 #include "llvm/CodeGen/RegisterPressure.h"
19 #include "llvm/CodeGen/SlotIndexes.h"
20 #include "llvm/CodeGen/TargetRegisterInfo.h"
21 #include "llvm/Config/llvm-config.h"
22 #include "llvm/MC/LaneBitmask.h"
23 #include "llvm/Support/Compiler.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include <algorithm>
28 #include <cassert>
29 
30 using namespace llvm;
31 
32 #define DEBUG_TYPE "machine-scheduler"
33 
34 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
35 LLVM_DUMP_METHOD
36 void llvm::printLivesAt(SlotIndex SI,
37                         const LiveIntervals &LIS,
38                         const MachineRegisterInfo &MRI) {
39   dbgs() << "Live regs at " << SI << ": "
40          << *LIS.getInstructionFromIndex(SI);
41   unsigned Num = 0;
42   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
43     const unsigned Reg = TargetRegisterInfo::index2VirtReg(I);
44     if (!LIS.hasInterval(Reg))
45       continue;
46     const auto &LI = LIS.getInterval(Reg);
47     if (LI.hasSubRanges()) {
48       bool firstTime = true;
49       for (const auto &S : LI.subranges()) {
50         if (!S.liveAt(SI)) continue;
51         if (firstTime) {
52           dbgs() << "  " << printReg(Reg, MRI.getTargetRegisterInfo())
53                  << '\n';
54           firstTime = false;
55         }
56         dbgs() << "  " << S << '\n';
57         ++Num;
58       }
59     } else if (LI.liveAt(SI)) {
60       dbgs() << "  " << LI << '\n';
61       ++Num;
62     }
63   }
64   if (!Num) dbgs() << "  <none>\n";
65 }
66 #endif
67 
68 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
69                    const GCNRPTracker::LiveRegSet &S2) {
70   if (S1.size() != S2.size())
71     return false;
72 
73   for (const auto &P : S1) {
74     auto I = S2.find(P.first);
75     if (I == S2.end() || I->second != P.second)
76       return false;
77   }
78   return true;
79 }
80 
81 
82 ///////////////////////////////////////////////////////////////////////////////
83 // GCNRegPressure
84 
85 unsigned GCNRegPressure::getRegKind(unsigned Reg,
86                                     const MachineRegisterInfo &MRI) {
87   assert(TargetRegisterInfo::isVirtualRegister(Reg));
88   const auto RC = MRI.getRegClass(Reg);
89   auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
90   return STI->isSGPRClass(RC) ?
91     (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
92     STI->hasAGPRs(RC) ?
93       (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) :
94       (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
95 }
96 
97 void GCNRegPressure::inc(unsigned Reg,
98                          LaneBitmask PrevMask,
99                          LaneBitmask NewMask,
100                          const MachineRegisterInfo &MRI) {
101   if (NewMask == PrevMask)
102     return;
103 
104   int Sign = 1;
105   if (NewMask < PrevMask) {
106     std::swap(NewMask, PrevMask);
107     Sign = -1;
108   }
109 #ifndef NDEBUG
110   const auto MaxMask = MRI.getMaxLaneMaskForVReg(Reg);
111 #endif
112   switch (auto Kind = getRegKind(Reg, MRI)) {
113   case SGPR32:
114   case VGPR32:
115   case AGPR32:
116     assert(PrevMask.none() && NewMask == MaxMask);
117     Value[Kind] += Sign;
118     break;
119 
120   case SGPR_TUPLE:
121   case VGPR_TUPLE:
122   case AGPR_TUPLE:
123     assert(NewMask < MaxMask || NewMask == MaxMask);
124     assert(PrevMask < NewMask);
125 
126     Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
127       Sign * (~PrevMask & NewMask).getNumLanes();
128 
129     if (PrevMask.none()) {
130       assert(NewMask.any());
131       Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
132     }
133     break;
134 
135   default: llvm_unreachable("Unknown register kind");
136   }
137 }
138 
139 bool GCNRegPressure::less(const GCNSubtarget &ST,
140                           const GCNRegPressure& O,
141                           unsigned MaxOccupancy) const {
142   const auto SGPROcc = std::min(MaxOccupancy,
143                                 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
144   const auto VGPROcc = std::min(MaxOccupancy,
145                                 ST.getOccupancyWithNumVGPRs(getVGPRNum()));
146   const auto OtherSGPROcc = std::min(MaxOccupancy,
147                                 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
148   const auto OtherVGPROcc = std::min(MaxOccupancy,
149                                 ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
150 
151   const auto Occ = std::min(SGPROcc, VGPROcc);
152   const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
153   if (Occ != OtherOcc)
154     return Occ > OtherOcc;
155 
156   bool SGPRImportant = SGPROcc < VGPROcc;
157   const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
158 
159   // if both pressures disagree on what is more important compare vgprs
160   if (SGPRImportant != OtherSGPRImportant) {
161     SGPRImportant = false;
162   }
163 
164   // compare large regs pressure
165   bool SGPRFirst = SGPRImportant;
166   for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
167     if (SGPRFirst) {
168       auto SW = getSGPRTuplesWeight();
169       auto OtherSW = O.getSGPRTuplesWeight();
170       if (SW != OtherSW)
171         return SW < OtherSW;
172     } else {
173       auto VW = getVGPRTuplesWeight();
174       auto OtherVW = O.getVGPRTuplesWeight();
175       if (VW != OtherVW)
176         return VW < OtherVW;
177     }
178   }
179   return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
180                          (getVGPRNum() < O.getVGPRNum());
181 }
182 
183 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
184 LLVM_DUMP_METHOD
185 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
186   OS << "VGPRs: " << Value[VGPR32] << ' ';
187   OS << "AGPRs: " << Value[AGPR32];
188   if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
189   OS << ", SGPRs: " << getSGPRNum();
190   if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
191   OS << ", LVGPR WT: " << getVGPRTuplesWeight()
192      << ", LSGPR WT: " << getSGPRTuplesWeight();
193   if (ST) OS << " -> Occ: " << getOccupancy(*ST);
194   OS << '\n';
195 }
196 #endif
197 
198 static LaneBitmask getDefRegMask(const MachineOperand &MO,
199                                  const MachineRegisterInfo &MRI) {
200   assert(MO.isDef() && MO.isReg() &&
201     TargetRegisterInfo::isVirtualRegister(MO.getReg()));
202 
203   // We don't rely on read-undef flag because in case of tentative schedule
204   // tracking it isn't set correctly yet. This works correctly however since
205   // use mask has been tracked before using LIS.
206   return MO.getSubReg() == 0 ?
207     MRI.getMaxLaneMaskForVReg(MO.getReg()) :
208     MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
209 }
210 
211 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
212                                   const MachineRegisterInfo &MRI,
213                                   const LiveIntervals &LIS) {
214   assert(MO.isUse() && MO.isReg() &&
215          TargetRegisterInfo::isVirtualRegister(MO.getReg()));
216 
217   if (auto SubReg = MO.getSubReg())
218     return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
219 
220   auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
221   if (MaxMask == LaneBitmask::getLane(0)) // cannot have subregs
222     return MaxMask;
223 
224   // For a tentative schedule LIS isn't updated yet but livemask should remain
225   // the same on any schedule. Subreg defs can be reordered but they all must
226   // dominate uses anyway.
227   auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
228   return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
229 }
230 
231 static SmallVector<RegisterMaskPair, 8>
232 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
233                       const MachineRegisterInfo &MRI) {
234   SmallVector<RegisterMaskPair, 8> Res;
235   for (const auto &MO : MI.operands()) {
236     if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg()))
237       continue;
238     if (!MO.isUse() || !MO.readsReg())
239       continue;
240 
241     auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
242 
243     auto Reg = MO.getReg();
244     auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) {
245       return RM.RegUnit == Reg;
246     });
247     if (I != Res.end())
248       I->LaneMask |= UsedMask;
249     else
250       Res.push_back(RegisterMaskPair(Reg, UsedMask));
251   }
252   return Res;
253 }
254 
255 ///////////////////////////////////////////////////////////////////////////////
256 // GCNRPTracker
257 
258 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
259                                   SlotIndex SI,
260                                   const LiveIntervals &LIS,
261                                   const MachineRegisterInfo &MRI) {
262   LaneBitmask LiveMask;
263   const auto &LI = LIS.getInterval(Reg);
264   if (LI.hasSubRanges()) {
265     for (const auto &S : LI.subranges())
266       if (S.liveAt(SI)) {
267         LiveMask |= S.LaneMask;
268         assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
269                LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
270       }
271   } else if (LI.liveAt(SI)) {
272     LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
273   }
274   return LiveMask;
275 }
276 
277 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
278                                            const LiveIntervals &LIS,
279                                            const MachineRegisterInfo &MRI) {
280   GCNRPTracker::LiveRegSet LiveRegs;
281   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
282     auto Reg = TargetRegisterInfo::index2VirtReg(I);
283     if (!LIS.hasInterval(Reg))
284       continue;
285     auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
286     if (LiveMask.any())
287       LiveRegs[Reg] = LiveMask;
288   }
289   return LiveRegs;
290 }
291 
292 void GCNRPTracker::reset(const MachineInstr &MI,
293                          const LiveRegSet *LiveRegsCopy,
294                          bool After) {
295   const MachineFunction &MF = *MI.getMF();
296   MRI = &MF.getRegInfo();
297   if (LiveRegsCopy) {
298     if (&LiveRegs != LiveRegsCopy)
299       LiveRegs = *LiveRegsCopy;
300   } else {
301     LiveRegs = After ? getLiveRegsAfter(MI, LIS)
302                      : getLiveRegsBefore(MI, LIS);
303   }
304 
305   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
306 }
307 
308 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
309                                const LiveRegSet *LiveRegsCopy) {
310   GCNRPTracker::reset(MI, LiveRegsCopy, true);
311 }
312 
313 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
314   assert(MRI && "call reset first");
315 
316   LastTrackedMI = &MI;
317 
318   if (MI.isDebugInstr())
319     return;
320 
321   auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
322 
323   // calc pressure at the MI (defs + uses)
324   auto AtMIPressure = CurPressure;
325   for (const auto &U : RegUses) {
326     auto LiveMask = LiveRegs[U.RegUnit];
327     AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
328   }
329   // update max pressure
330   MaxPressure = max(AtMIPressure, MaxPressure);
331 
332   for (const auto &MO : MI.defs()) {
333     if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg()) ||
334          MO.isDead())
335       continue;
336 
337     auto Reg = MO.getReg();
338     auto I = LiveRegs.find(Reg);
339     if (I == LiveRegs.end())
340       continue;
341     auto &LiveMask = I->second;
342     auto PrevMask = LiveMask;
343     LiveMask &= ~getDefRegMask(MO, *MRI);
344     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
345     if (LiveMask.none())
346       LiveRegs.erase(I);
347   }
348   for (const auto &U : RegUses) {
349     auto &LiveMask = LiveRegs[U.RegUnit];
350     auto PrevMask = LiveMask;
351     LiveMask |= U.LaneMask;
352     CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
353   }
354   assert(CurPressure == getRegPressure(*MRI, LiveRegs));
355 }
356 
357 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
358                                  const LiveRegSet *LiveRegsCopy) {
359   MRI = &MI.getParent()->getParent()->getRegInfo();
360   LastTrackedMI = nullptr;
361   MBBEnd = MI.getParent()->end();
362   NextMI = &MI;
363   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
364   if (NextMI == MBBEnd)
365     return false;
366   GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
367   return true;
368 }
369 
370 bool GCNDownwardRPTracker::advanceBeforeNext() {
371   assert(MRI && "call reset first");
372 
373   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
374   if (NextMI == MBBEnd)
375     return false;
376 
377   SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
378   assert(SI.isValid());
379 
380   // Remove dead registers or mask bits.
381   for (auto &It : LiveRegs) {
382     const LiveInterval &LI = LIS.getInterval(It.first);
383     if (LI.hasSubRanges()) {
384       for (const auto &S : LI.subranges()) {
385         if (!S.liveAt(SI)) {
386           auto PrevMask = It.second;
387           It.second &= ~S.LaneMask;
388           CurPressure.inc(It.first, PrevMask, It.second, *MRI);
389         }
390       }
391     } else if (!LI.liveAt(SI)) {
392       auto PrevMask = It.second;
393       It.second = LaneBitmask::getNone();
394       CurPressure.inc(It.first, PrevMask, It.second, *MRI);
395     }
396     if (It.second.none())
397       LiveRegs.erase(It.first);
398   }
399 
400   MaxPressure = max(MaxPressure, CurPressure);
401 
402   return true;
403 }
404 
405 void GCNDownwardRPTracker::advanceToNext() {
406   LastTrackedMI = &*NextMI++;
407 
408   // Add new registers or mask bits.
409   for (const auto &MO : LastTrackedMI->defs()) {
410     if (!MO.isReg())
411       continue;
412     unsigned Reg = MO.getReg();
413     if (!TargetRegisterInfo::isVirtualRegister(Reg))
414       continue;
415     auto &LiveMask = LiveRegs[Reg];
416     auto PrevMask = LiveMask;
417     LiveMask |= getDefRegMask(MO, *MRI);
418     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
419   }
420 
421   MaxPressure = max(MaxPressure, CurPressure);
422 }
423 
424 bool GCNDownwardRPTracker::advance() {
425   // If we have just called reset live set is actual.
426   if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
427     return false;
428   advanceToNext();
429   return true;
430 }
431 
432 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
433   while (NextMI != End)
434     if (!advance()) return false;
435   return true;
436 }
437 
438 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
439                                    MachineBasicBlock::const_iterator End,
440                                    const LiveRegSet *LiveRegsCopy) {
441   reset(*Begin, LiveRegsCopy);
442   return advance(End);
443 }
444 
445 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
446 LLVM_DUMP_METHOD
447 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
448                            const GCNRPTracker::LiveRegSet &TrackedLR,
449                            const TargetRegisterInfo *TRI) {
450   for (auto const &P : TrackedLR) {
451     auto I = LISLR.find(P.first);
452     if (I == LISLR.end()) {
453       dbgs() << "  " << printReg(P.first, TRI)
454              << ":L" << PrintLaneMask(P.second)
455              << " isn't found in LIS reported set\n";
456     }
457     else if (I->second != P.second) {
458       dbgs() << "  " << printReg(P.first, TRI)
459         << " masks doesn't match: LIS reported "
460         << PrintLaneMask(I->second)
461         << ", tracked "
462         << PrintLaneMask(P.second)
463         << '\n';
464     }
465   }
466   for (auto const &P : LISLR) {
467     auto I = TrackedLR.find(P.first);
468     if (I == TrackedLR.end()) {
469       dbgs() << "  " << printReg(P.first, TRI)
470              << ":L" << PrintLaneMask(P.second)
471              << " isn't found in tracked set\n";
472     }
473   }
474 }
475 
476 bool GCNUpwardRPTracker::isValid() const {
477   const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
478   const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
479   const auto &TrackedLR = LiveRegs;
480 
481   if (!isEqual(LISLR, TrackedLR)) {
482     dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
483               " LIS reported livesets mismatch:\n";
484     printLivesAt(SI, LIS, *MRI);
485     reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
486     return false;
487   }
488 
489   auto LISPressure = getRegPressure(*MRI, LISLR);
490   if (LISPressure != CurPressure) {
491     dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
492     CurPressure.print(dbgs());
493     dbgs() << "LIS rpt: ";
494     LISPressure.print(dbgs());
495     return false;
496   }
497   return true;
498 }
499 
500 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
501                                  const MachineRegisterInfo &MRI) {
502   const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
503   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
504     unsigned Reg = TargetRegisterInfo::index2VirtReg(I);
505     auto It = LiveRegs.find(Reg);
506     if (It != LiveRegs.end() && It->second.any())
507       OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
508          << PrintLaneMask(It->second);
509   }
510   OS << '\n';
511 }
512 #endif
513