1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "GCNRegPressure.h"
10 #include "AMDGPUSubtarget.h"
11 #include "SIRegisterInfo.h"
12 #include "llvm/ADT/SmallVector.h"
13 #include "llvm/CodeGen/LiveInterval.h"
14 #include "llvm/CodeGen/LiveIntervals.h"
15 #include "llvm/CodeGen/MachineInstr.h"
16 #include "llvm/CodeGen/MachineOperand.h"
17 #include "llvm/CodeGen/MachineRegisterInfo.h"
18 #include "llvm/CodeGen/RegisterPressure.h"
19 #include "llvm/CodeGen/SlotIndexes.h"
20 #include "llvm/CodeGen/TargetRegisterInfo.h"
21 #include "llvm/Config/llvm-config.h"
22 #include "llvm/MC/LaneBitmask.h"
23 #include "llvm/Support/Compiler.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include <algorithm>
28 #include <cassert>
29 
30 using namespace llvm;
31 
32 #define DEBUG_TYPE "machine-scheduler"
33 
34 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
35 LLVM_DUMP_METHOD
36 void llvm::printLivesAt(SlotIndex SI,
37                         const LiveIntervals &LIS,
38                         const MachineRegisterInfo &MRI) {
39   dbgs() << "Live regs at " << SI << ": "
40          << *LIS.getInstructionFromIndex(SI);
41   unsigned Num = 0;
42   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
43     const unsigned Reg = TargetRegisterInfo::index2VirtReg(I);
44     if (!LIS.hasInterval(Reg))
45       continue;
46     const auto &LI = LIS.getInterval(Reg);
47     if (LI.hasSubRanges()) {
48       bool firstTime = true;
49       for (const auto &S : LI.subranges()) {
50         if (!S.liveAt(SI)) continue;
51         if (firstTime) {
52           dbgs() << "  " << printReg(Reg, MRI.getTargetRegisterInfo())
53                  << '\n';
54           firstTime = false;
55         }
56         dbgs() << "  " << S << '\n';
57         ++Num;
58       }
59     } else if (LI.liveAt(SI)) {
60       dbgs() << "  " << LI << '\n';
61       ++Num;
62     }
63   }
64   if (!Num) dbgs() << "  <none>\n";
65 }
66 
67 static bool isEqual(const GCNRPTracker::LiveRegSet &S1,
68                     const GCNRPTracker::LiveRegSet &S2) {
69   if (S1.size() != S2.size())
70     return false;
71 
72   for (const auto &P : S1) {
73     auto I = S2.find(P.first);
74     if (I == S2.end() || I->second != P.second)
75       return false;
76   }
77   return true;
78 }
79 #endif
80 
81 ///////////////////////////////////////////////////////////////////////////////
82 // GCNRegPressure
83 
84 unsigned GCNRegPressure::getRegKind(unsigned Reg,
85                                     const MachineRegisterInfo &MRI) {
86   assert(TargetRegisterInfo::isVirtualRegister(Reg));
87   const auto RC = MRI.getRegClass(Reg);
88   auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
89   return STI->isSGPRClass(RC) ?
90     (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
91     (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
92 }
93 
94 void GCNRegPressure::inc(unsigned Reg,
95                          LaneBitmask PrevMask,
96                          LaneBitmask NewMask,
97                          const MachineRegisterInfo &MRI) {
98   if (NewMask == PrevMask)
99     return;
100 
101   int Sign = 1;
102   if (NewMask < PrevMask) {
103     std::swap(NewMask, PrevMask);
104     Sign = -1;
105   }
106 #ifndef NDEBUG
107   const auto MaxMask = MRI.getMaxLaneMaskForVReg(Reg);
108 #endif
109   switch (auto Kind = getRegKind(Reg, MRI)) {
110   case SGPR32:
111   case VGPR32:
112     assert(PrevMask.none() && NewMask == MaxMask);
113     Value[Kind] += Sign;
114     break;
115 
116   case SGPR_TUPLE:
117   case VGPR_TUPLE:
118     assert(NewMask < MaxMask || NewMask == MaxMask);
119     assert(PrevMask < NewMask);
120 
121     Value[Kind == SGPR_TUPLE ? SGPR32 : VGPR32] +=
122       Sign * (~PrevMask & NewMask).getNumLanes();
123 
124     if (PrevMask.none()) {
125       assert(NewMask.any());
126       Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
127     }
128     break;
129 
130   default: llvm_unreachable("Unknown register kind");
131   }
132 }
133 
134 bool GCNRegPressure::less(const GCNSubtarget &ST,
135                           const GCNRegPressure& O,
136                           unsigned MaxOccupancy) const {
137   const auto SGPROcc = std::min(MaxOccupancy,
138                                 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
139   const auto VGPROcc = std::min(MaxOccupancy,
140                                 ST.getOccupancyWithNumVGPRs(getVGPRNum()));
141   const auto OtherSGPROcc = std::min(MaxOccupancy,
142                                 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
143   const auto OtherVGPROcc = std::min(MaxOccupancy,
144                                 ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
145 
146   const auto Occ = std::min(SGPROcc, VGPROcc);
147   const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
148   if (Occ != OtherOcc)
149     return Occ > OtherOcc;
150 
151   bool SGPRImportant = SGPROcc < VGPROcc;
152   const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
153 
154   // if both pressures disagree on what is more important compare vgprs
155   if (SGPRImportant != OtherSGPRImportant) {
156     SGPRImportant = false;
157   }
158 
159   // compare large regs pressure
160   bool SGPRFirst = SGPRImportant;
161   for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
162     if (SGPRFirst) {
163       auto SW = getSGPRTuplesWeight();
164       auto OtherSW = O.getSGPRTuplesWeight();
165       if (SW != OtherSW)
166         return SW < OtherSW;
167     } else {
168       auto VW = getVGPRTuplesWeight();
169       auto OtherVW = O.getVGPRTuplesWeight();
170       if (VW != OtherVW)
171         return VW < OtherVW;
172     }
173   }
174   return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
175                          (getVGPRNum() < O.getVGPRNum());
176 }
177 
178 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
179 LLVM_DUMP_METHOD
180 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
181   OS << "VGPRs: " << getVGPRNum();
182   if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
183   OS << ", SGPRs: " << getSGPRNum();
184   if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
185   OS << ", LVGPR WT: " << getVGPRTuplesWeight()
186      << ", LSGPR WT: " << getSGPRTuplesWeight();
187   if (ST) OS << " -> Occ: " << getOccupancy(*ST);
188   OS << '\n';
189 }
190 #endif
191 
192 static LaneBitmask getDefRegMask(const MachineOperand &MO,
193                                  const MachineRegisterInfo &MRI) {
194   assert(MO.isDef() && MO.isReg() &&
195     TargetRegisterInfo::isVirtualRegister(MO.getReg()));
196 
197   // We don't rely on read-undef flag because in case of tentative schedule
198   // tracking it isn't set correctly yet. This works correctly however since
199   // use mask has been tracked before using LIS.
200   return MO.getSubReg() == 0 ?
201     MRI.getMaxLaneMaskForVReg(MO.getReg()) :
202     MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
203 }
204 
205 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
206                                   const MachineRegisterInfo &MRI,
207                                   const LiveIntervals &LIS) {
208   assert(MO.isUse() && MO.isReg() &&
209          TargetRegisterInfo::isVirtualRegister(MO.getReg()));
210 
211   if (auto SubReg = MO.getSubReg())
212     return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
213 
214   auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
215   if (MaxMask == LaneBitmask::getLane(0)) // cannot have subregs
216     return MaxMask;
217 
218   // For a tentative schedule LIS isn't updated yet but livemask should remain
219   // the same on any schedule. Subreg defs can be reordered but they all must
220   // dominate uses anyway.
221   auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
222   return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
223 }
224 
225 static SmallVector<RegisterMaskPair, 8>
226 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
227                       const MachineRegisterInfo &MRI) {
228   SmallVector<RegisterMaskPair, 8> Res;
229   for (const auto &MO : MI.operands()) {
230     if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg()))
231       continue;
232     if (!MO.isUse() || !MO.readsReg())
233       continue;
234 
235     auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
236 
237     auto Reg = MO.getReg();
238     auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) {
239       return RM.RegUnit == Reg;
240     });
241     if (I != Res.end())
242       I->LaneMask |= UsedMask;
243     else
244       Res.push_back(RegisterMaskPair(Reg, UsedMask));
245   }
246   return Res;
247 }
248 
249 ///////////////////////////////////////////////////////////////////////////////
250 // GCNRPTracker
251 
252 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
253                                   SlotIndex SI,
254                                   const LiveIntervals &LIS,
255                                   const MachineRegisterInfo &MRI) {
256   LaneBitmask LiveMask;
257   const auto &LI = LIS.getInterval(Reg);
258   if (LI.hasSubRanges()) {
259     for (const auto &S : LI.subranges())
260       if (S.liveAt(SI)) {
261         LiveMask |= S.LaneMask;
262         assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
263                LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
264       }
265   } else if (LI.liveAt(SI)) {
266     LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
267   }
268   return LiveMask;
269 }
270 
271 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
272                                            const LiveIntervals &LIS,
273                                            const MachineRegisterInfo &MRI) {
274   GCNRPTracker::LiveRegSet LiveRegs;
275   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
276     auto Reg = TargetRegisterInfo::index2VirtReg(I);
277     if (!LIS.hasInterval(Reg))
278       continue;
279     auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
280     if (LiveMask.any())
281       LiveRegs[Reg] = LiveMask;
282   }
283   return LiveRegs;
284 }
285 
286 void GCNRPTracker::reset(const MachineInstr &MI,
287                          const LiveRegSet *LiveRegsCopy,
288                          bool After) {
289   const MachineFunction &MF = *MI.getMF();
290   MRI = &MF.getRegInfo();
291   if (LiveRegsCopy) {
292     if (&LiveRegs != LiveRegsCopy)
293       LiveRegs = *LiveRegsCopy;
294   } else {
295     LiveRegs = After ? getLiveRegsAfter(MI, LIS)
296                      : getLiveRegsBefore(MI, LIS);
297   }
298 
299   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
300 }
301 
302 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
303                                const LiveRegSet *LiveRegsCopy) {
304   GCNRPTracker::reset(MI, LiveRegsCopy, true);
305 }
306 
307 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
308   assert(MRI && "call reset first");
309 
310   LastTrackedMI = &MI;
311 
312   if (MI.isDebugInstr())
313     return;
314 
315   auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
316 
317   // calc pressure at the MI (defs + uses)
318   auto AtMIPressure = CurPressure;
319   for (const auto &U : RegUses) {
320     auto LiveMask = LiveRegs[U.RegUnit];
321     AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
322   }
323   // update max pressure
324   MaxPressure = max(AtMIPressure, MaxPressure);
325 
326   for (const auto &MO : MI.defs()) {
327     if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg()) ||
328          MO.isDead())
329       continue;
330 
331     auto Reg = MO.getReg();
332     auto I = LiveRegs.find(Reg);
333     if (I == LiveRegs.end())
334       continue;
335     auto &LiveMask = I->second;
336     auto PrevMask = LiveMask;
337     LiveMask &= ~getDefRegMask(MO, *MRI);
338     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
339     if (LiveMask.none())
340       LiveRegs.erase(I);
341   }
342   for (const auto &U : RegUses) {
343     auto &LiveMask = LiveRegs[U.RegUnit];
344     auto PrevMask = LiveMask;
345     LiveMask |= U.LaneMask;
346     CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
347   }
348   assert(CurPressure == getRegPressure(*MRI, LiveRegs));
349 }
350 
351 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
352                                  const LiveRegSet *LiveRegsCopy) {
353   MRI = &MI.getParent()->getParent()->getRegInfo();
354   LastTrackedMI = nullptr;
355   MBBEnd = MI.getParent()->end();
356   NextMI = &MI;
357   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
358   if (NextMI == MBBEnd)
359     return false;
360   GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
361   return true;
362 }
363 
364 bool GCNDownwardRPTracker::advanceBeforeNext() {
365   assert(MRI && "call reset first");
366 
367   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
368   if (NextMI == MBBEnd)
369     return false;
370 
371   SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
372   assert(SI.isValid());
373 
374   // Remove dead registers or mask bits.
375   for (auto &It : LiveRegs) {
376     const LiveInterval &LI = LIS.getInterval(It.first);
377     if (LI.hasSubRanges()) {
378       for (const auto &S : LI.subranges()) {
379         if (!S.liveAt(SI)) {
380           auto PrevMask = It.second;
381           It.second &= ~S.LaneMask;
382           CurPressure.inc(It.first, PrevMask, It.second, *MRI);
383         }
384       }
385     } else if (!LI.liveAt(SI)) {
386       auto PrevMask = It.second;
387       It.second = LaneBitmask::getNone();
388       CurPressure.inc(It.first, PrevMask, It.second, *MRI);
389     }
390     if (It.second.none())
391       LiveRegs.erase(It.first);
392   }
393 
394   MaxPressure = max(MaxPressure, CurPressure);
395 
396   return true;
397 }
398 
399 void GCNDownwardRPTracker::advanceToNext() {
400   LastTrackedMI = &*NextMI++;
401 
402   // Add new registers or mask bits.
403   for (const auto &MO : LastTrackedMI->defs()) {
404     if (!MO.isReg())
405       continue;
406     unsigned Reg = MO.getReg();
407     if (!TargetRegisterInfo::isVirtualRegister(Reg))
408       continue;
409     auto &LiveMask = LiveRegs[Reg];
410     auto PrevMask = LiveMask;
411     LiveMask |= getDefRegMask(MO, *MRI);
412     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
413   }
414 
415   MaxPressure = max(MaxPressure, CurPressure);
416 }
417 
418 bool GCNDownwardRPTracker::advance() {
419   // If we have just called reset live set is actual.
420   if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
421     return false;
422   advanceToNext();
423   return true;
424 }
425 
426 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
427   while (NextMI != End)
428     if (!advance()) return false;
429   return true;
430 }
431 
432 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
433                                    MachineBasicBlock::const_iterator End,
434                                    const LiveRegSet *LiveRegsCopy) {
435   reset(*Begin, LiveRegsCopy);
436   return advance(End);
437 }
438 
439 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
440 LLVM_DUMP_METHOD
441 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
442                            const GCNRPTracker::LiveRegSet &TrackedLR,
443                            const TargetRegisterInfo *TRI) {
444   for (auto const &P : TrackedLR) {
445     auto I = LISLR.find(P.first);
446     if (I == LISLR.end()) {
447       dbgs() << "  " << printReg(P.first, TRI)
448              << ":L" << PrintLaneMask(P.second)
449              << " isn't found in LIS reported set\n";
450     }
451     else if (I->second != P.second) {
452       dbgs() << "  " << printReg(P.first, TRI)
453         << " masks doesn't match: LIS reported "
454         << PrintLaneMask(I->second)
455         << ", tracked "
456         << PrintLaneMask(P.second)
457         << '\n';
458     }
459   }
460   for (auto const &P : LISLR) {
461     auto I = TrackedLR.find(P.first);
462     if (I == TrackedLR.end()) {
463       dbgs() << "  " << printReg(P.first, TRI)
464              << ":L" << PrintLaneMask(P.second)
465              << " isn't found in tracked set\n";
466     }
467   }
468 }
469 
470 bool GCNUpwardRPTracker::isValid() const {
471   const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
472   const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
473   const auto &TrackedLR = LiveRegs;
474 
475   if (!isEqual(LISLR, TrackedLR)) {
476     dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
477               " LIS reported livesets mismatch:\n";
478     printLivesAt(SI, LIS, *MRI);
479     reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
480     return false;
481   }
482 
483   auto LISPressure = getRegPressure(*MRI, LISLR);
484   if (LISPressure != CurPressure) {
485     dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
486     CurPressure.print(dbgs());
487     dbgs() << "LIS rpt: ";
488     LISPressure.print(dbgs());
489     return false;
490   }
491   return true;
492 }
493 
494 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
495                                  const MachineRegisterInfo &MRI) {
496   const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
497   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
498     unsigned Reg = TargetRegisterInfo::index2VirtReg(I);
499     auto It = LiveRegs.find(Reg);
500     if (It != LiveRegs.end() && It->second.any())
501       OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
502          << PrintLaneMask(It->second);
503   }
504   OS << '\n';
505 }
506 #endif
507