1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the GCNRegPressure class.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "GCNRegPressure.h"
15 #include "AMDGPUSubtarget.h"
16 #include "SIRegisterInfo.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/CodeGen/LiveInterval.h"
19 #include "llvm/CodeGen/LiveIntervals.h"
20 #include "llvm/CodeGen/MachineInstr.h"
21 #include "llvm/CodeGen/MachineOperand.h"
22 #include "llvm/CodeGen/MachineRegisterInfo.h"
23 #include "llvm/CodeGen/RegisterPressure.h"
24 #include "llvm/CodeGen/SlotIndexes.h"
25 #include "llvm/CodeGen/TargetRegisterInfo.h"
26 #include "llvm/Config/llvm-config.h"
27 #include "llvm/MC/LaneBitmask.h"
28 #include "llvm/Support/Compiler.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include <algorithm>
33 #include <cassert>
34 
35 using namespace llvm;
36 
37 #define DEBUG_TYPE "machine-scheduler"
38 
39 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
40 LLVM_DUMP_METHOD
41 void llvm::printLivesAt(SlotIndex SI,
42                         const LiveIntervals &LIS,
43                         const MachineRegisterInfo &MRI) {
44   dbgs() << "Live regs at " << SI << ": "
45          << *LIS.getInstructionFromIndex(SI);
46   unsigned Num = 0;
47   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
48     const unsigned Reg = Register::index2VirtReg(I);
49     if (!LIS.hasInterval(Reg))
50       continue;
51     const auto &LI = LIS.getInterval(Reg);
52     if (LI.hasSubRanges()) {
53       bool firstTime = true;
54       for (const auto &S : LI.subranges()) {
55         if (!S.liveAt(SI)) continue;
56         if (firstTime) {
57           dbgs() << "  " << printReg(Reg, MRI.getTargetRegisterInfo())
58                  << '\n';
59           firstTime = false;
60         }
61         dbgs() << "  " << S << '\n';
62         ++Num;
63       }
64     } else if (LI.liveAt(SI)) {
65       dbgs() << "  " << LI << '\n';
66       ++Num;
67     }
68   }
69   if (!Num) dbgs() << "  <none>\n";
70 }
71 #endif
72 
73 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
74                    const GCNRPTracker::LiveRegSet &S2) {
75   if (S1.size() != S2.size())
76     return false;
77 
78   for (const auto &P : S1) {
79     auto I = S2.find(P.first);
80     if (I == S2.end() || I->second != P.second)
81       return false;
82   }
83   return true;
84 }
85 
86 
87 ///////////////////////////////////////////////////////////////////////////////
88 // GCNRegPressure
89 
90 unsigned GCNRegPressure::getRegKind(Register Reg,
91                                     const MachineRegisterInfo &MRI) {
92   assert(Reg.isVirtual());
93   const auto RC = MRI.getRegClass(Reg);
94   auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
95   return STI->isSGPRClass(RC) ?
96     (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
97     STI->hasAGPRs(RC) ?
98       (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) :
99       (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
100 }
101 
102 void GCNRegPressure::inc(unsigned Reg,
103                          LaneBitmask PrevMask,
104                          LaneBitmask NewMask,
105                          const MachineRegisterInfo &MRI) {
106   if (SIRegisterInfo::getNumCoveredRegs(NewMask) ==
107       SIRegisterInfo::getNumCoveredRegs(PrevMask))
108     return;
109 
110   int Sign = 1;
111   if (NewMask < PrevMask) {
112     std::swap(NewMask, PrevMask);
113     Sign = -1;
114   }
115 
116   switch (auto Kind = getRegKind(Reg, MRI)) {
117   case SGPR32:
118   case VGPR32:
119   case AGPR32:
120     Value[Kind] += Sign;
121     break;
122 
123   case SGPR_TUPLE:
124   case VGPR_TUPLE:
125   case AGPR_TUPLE:
126     assert(PrevMask < NewMask);
127 
128     Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
129       Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
130 
131     if (PrevMask.none()) {
132       assert(NewMask.any());
133       Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
134     }
135     break;
136 
137   default: llvm_unreachable("Unknown register kind");
138   }
139 }
140 
141 bool GCNRegPressure::less(const GCNSubtarget &ST,
142                           const GCNRegPressure& O,
143                           unsigned MaxOccupancy) const {
144   const auto SGPROcc = std::min(MaxOccupancy,
145                                 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
146   const auto VGPROcc = std::min(MaxOccupancy,
147                                 ST.getOccupancyWithNumVGPRs(getVGPRNum()));
148   const auto OtherSGPROcc = std::min(MaxOccupancy,
149                                 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
150   const auto OtherVGPROcc = std::min(MaxOccupancy,
151                                 ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
152 
153   const auto Occ = std::min(SGPROcc, VGPROcc);
154   const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
155   if (Occ != OtherOcc)
156     return Occ > OtherOcc;
157 
158   bool SGPRImportant = SGPROcc < VGPROcc;
159   const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
160 
161   // if both pressures disagree on what is more important compare vgprs
162   if (SGPRImportant != OtherSGPRImportant) {
163     SGPRImportant = false;
164   }
165 
166   // compare large regs pressure
167   bool SGPRFirst = SGPRImportant;
168   for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
169     if (SGPRFirst) {
170       auto SW = getSGPRTuplesWeight();
171       auto OtherSW = O.getSGPRTuplesWeight();
172       if (SW != OtherSW)
173         return SW < OtherSW;
174     } else {
175       auto VW = getVGPRTuplesWeight();
176       auto OtherVW = O.getVGPRTuplesWeight();
177       if (VW != OtherVW)
178         return VW < OtherVW;
179     }
180   }
181   return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
182                          (getVGPRNum() < O.getVGPRNum());
183 }
184 
185 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
186 LLVM_DUMP_METHOD
187 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
188   OS << "VGPRs: " << Value[VGPR32] << ' ';
189   OS << "AGPRs: " << Value[AGPR32];
190   if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
191   OS << ", SGPRs: " << getSGPRNum();
192   if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
193   OS << ", LVGPR WT: " << getVGPRTuplesWeight()
194      << ", LSGPR WT: " << getSGPRTuplesWeight();
195   if (ST) OS << " -> Occ: " << getOccupancy(*ST);
196   OS << '\n';
197 }
198 #endif
199 
200 static LaneBitmask getDefRegMask(const MachineOperand &MO,
201                                  const MachineRegisterInfo &MRI) {
202   assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());
203 
204   // We don't rely on read-undef flag because in case of tentative schedule
205   // tracking it isn't set correctly yet. This works correctly however since
206   // use mask has been tracked before using LIS.
207   return MO.getSubReg() == 0 ?
208     MRI.getMaxLaneMaskForVReg(MO.getReg()) :
209     MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
210 }
211 
212 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
213                                   const MachineRegisterInfo &MRI,
214                                   const LiveIntervals &LIS) {
215   assert(MO.isUse() && MO.isReg() && MO.getReg().isVirtual());
216 
217   if (auto SubReg = MO.getSubReg())
218     return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
219 
220   auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
221   if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs
222     return MaxMask;
223 
224   // For a tentative schedule LIS isn't updated yet but livemask should remain
225   // the same on any schedule. Subreg defs can be reordered but they all must
226   // dominate uses anyway.
227   auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
228   return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
229 }
230 
231 static SmallVector<RegisterMaskPair, 8>
232 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
233                       const MachineRegisterInfo &MRI) {
234   SmallVector<RegisterMaskPair, 8> Res;
235   for (const auto &MO : MI.operands()) {
236     if (!MO.isReg() || !MO.getReg().isVirtual())
237       continue;
238     if (!MO.isUse() || !MO.readsReg())
239       continue;
240 
241     auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
242 
243     auto Reg = MO.getReg();
244     auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) {
245       return RM.RegUnit == Reg;
246     });
247     if (I != Res.end())
248       I->LaneMask |= UsedMask;
249     else
250       Res.push_back(RegisterMaskPair(Reg, UsedMask));
251   }
252   return Res;
253 }
254 
255 ///////////////////////////////////////////////////////////////////////////////
256 // GCNRPTracker
257 
258 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
259                                   SlotIndex SI,
260                                   const LiveIntervals &LIS,
261                                   const MachineRegisterInfo &MRI) {
262   LaneBitmask LiveMask;
263   const auto &LI = LIS.getInterval(Reg);
264   if (LI.hasSubRanges()) {
265     for (const auto &S : LI.subranges())
266       if (S.liveAt(SI)) {
267         LiveMask |= S.LaneMask;
268         assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
269                LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
270       }
271   } else if (LI.liveAt(SI)) {
272     LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
273   }
274   return LiveMask;
275 }
276 
277 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
278                                            const LiveIntervals &LIS,
279                                            const MachineRegisterInfo &MRI) {
280   GCNRPTracker::LiveRegSet LiveRegs;
281   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
282     auto Reg = Register::index2VirtReg(I);
283     if (!LIS.hasInterval(Reg))
284       continue;
285     auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
286     if (LiveMask.any())
287       LiveRegs[Reg] = LiveMask;
288   }
289   return LiveRegs;
290 }
291 
292 void GCNRPTracker::reset(const MachineInstr &MI,
293                          const LiveRegSet *LiveRegsCopy,
294                          bool After) {
295   const MachineFunction &MF = *MI.getMF();
296   MRI = &MF.getRegInfo();
297   if (LiveRegsCopy) {
298     if (&LiveRegs != LiveRegsCopy)
299       LiveRegs = *LiveRegsCopy;
300   } else {
301     LiveRegs = After ? getLiveRegsAfter(MI, LIS)
302                      : getLiveRegsBefore(MI, LIS);
303   }
304 
305   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
306 }
307 
308 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
309                                const LiveRegSet *LiveRegsCopy) {
310   GCNRPTracker::reset(MI, LiveRegsCopy, true);
311 }
312 
313 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
314   assert(MRI && "call reset first");
315 
316   LastTrackedMI = &MI;
317 
318   if (MI.isDebugInstr())
319     return;
320 
321   auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
322 
323   // calc pressure at the MI (defs + uses)
324   auto AtMIPressure = CurPressure;
325   for (const auto &U : RegUses) {
326     auto LiveMask = LiveRegs[U.RegUnit];
327     AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
328   }
329   // update max pressure
330   MaxPressure = max(AtMIPressure, MaxPressure);
331 
332   for (const auto &MO : MI.operands()) {
333     if (!MO.isReg() || !MO.isDef() || !MO.getReg().isVirtual() || MO.isDead())
334       continue;
335 
336     auto Reg = MO.getReg();
337     auto I = LiveRegs.find(Reg);
338     if (I == LiveRegs.end())
339       continue;
340     auto &LiveMask = I->second;
341     auto PrevMask = LiveMask;
342     LiveMask &= ~getDefRegMask(MO, *MRI);
343     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
344     if (LiveMask.none())
345       LiveRegs.erase(I);
346   }
347   for (const auto &U : RegUses) {
348     auto &LiveMask = LiveRegs[U.RegUnit];
349     auto PrevMask = LiveMask;
350     LiveMask |= U.LaneMask;
351     CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
352   }
353   assert(CurPressure == getRegPressure(*MRI, LiveRegs));
354 }
355 
356 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
357                                  const LiveRegSet *LiveRegsCopy) {
358   MRI = &MI.getParent()->getParent()->getRegInfo();
359   LastTrackedMI = nullptr;
360   MBBEnd = MI.getParent()->end();
361   NextMI = &MI;
362   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
363   if (NextMI == MBBEnd)
364     return false;
365   GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
366   return true;
367 }
368 
369 bool GCNDownwardRPTracker::advanceBeforeNext() {
370   assert(MRI && "call reset first");
371 
372   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
373   if (NextMI == MBBEnd)
374     return false;
375 
376   SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
377   assert(SI.isValid());
378 
379   // Remove dead registers or mask bits.
380   for (auto &It : LiveRegs) {
381     const LiveInterval &LI = LIS.getInterval(It.first);
382     if (LI.hasSubRanges()) {
383       for (const auto &S : LI.subranges()) {
384         if (!S.liveAt(SI)) {
385           auto PrevMask = It.second;
386           It.second &= ~S.LaneMask;
387           CurPressure.inc(It.first, PrevMask, It.second, *MRI);
388         }
389       }
390     } else if (!LI.liveAt(SI)) {
391       auto PrevMask = It.second;
392       It.second = LaneBitmask::getNone();
393       CurPressure.inc(It.first, PrevMask, It.second, *MRI);
394     }
395     if (It.second.none())
396       LiveRegs.erase(It.first);
397   }
398 
399   MaxPressure = max(MaxPressure, CurPressure);
400 
401   return true;
402 }
403 
404 void GCNDownwardRPTracker::advanceToNext() {
405   LastTrackedMI = &*NextMI++;
406 
407   // Add new registers or mask bits.
408   for (const auto &MO : LastTrackedMI->operands()) {
409     if (!MO.isReg() || !MO.isDef())
410       continue;
411     Register Reg = MO.getReg();
412     if (!Reg.isVirtual())
413       continue;
414     auto &LiveMask = LiveRegs[Reg];
415     auto PrevMask = LiveMask;
416     LiveMask |= getDefRegMask(MO, *MRI);
417     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
418   }
419 
420   MaxPressure = max(MaxPressure, CurPressure);
421 }
422 
423 bool GCNDownwardRPTracker::advance() {
424   // If we have just called reset live set is actual.
425   if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
426     return false;
427   advanceToNext();
428   return true;
429 }
430 
431 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
432   while (NextMI != End)
433     if (!advance()) return false;
434   return true;
435 }
436 
437 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
438                                    MachineBasicBlock::const_iterator End,
439                                    const LiveRegSet *LiveRegsCopy) {
440   reset(*Begin, LiveRegsCopy);
441   return advance(End);
442 }
443 
444 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
445 LLVM_DUMP_METHOD
446 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
447                            const GCNRPTracker::LiveRegSet &TrackedLR,
448                            const TargetRegisterInfo *TRI) {
449   for (auto const &P : TrackedLR) {
450     auto I = LISLR.find(P.first);
451     if (I == LISLR.end()) {
452       dbgs() << "  " << printReg(P.first, TRI)
453              << ":L" << PrintLaneMask(P.second)
454              << " isn't found in LIS reported set\n";
455     }
456     else if (I->second != P.second) {
457       dbgs() << "  " << printReg(P.first, TRI)
458         << " masks doesn't match: LIS reported "
459         << PrintLaneMask(I->second)
460         << ", tracked "
461         << PrintLaneMask(P.second)
462         << '\n';
463     }
464   }
465   for (auto const &P : LISLR) {
466     auto I = TrackedLR.find(P.first);
467     if (I == TrackedLR.end()) {
468       dbgs() << "  " << printReg(P.first, TRI)
469              << ":L" << PrintLaneMask(P.second)
470              << " isn't found in tracked set\n";
471     }
472   }
473 }
474 
475 bool GCNUpwardRPTracker::isValid() const {
476   const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
477   const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
478   const auto &TrackedLR = LiveRegs;
479 
480   if (!isEqual(LISLR, TrackedLR)) {
481     dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
482               " LIS reported livesets mismatch:\n";
483     printLivesAt(SI, LIS, *MRI);
484     reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
485     return false;
486   }
487 
488   auto LISPressure = getRegPressure(*MRI, LISLR);
489   if (LISPressure != CurPressure) {
490     dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
491     CurPressure.print(dbgs());
492     dbgs() << "LIS rpt: ";
493     LISPressure.print(dbgs());
494     return false;
495   }
496   return true;
497 }
498 
499 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
500                                  const MachineRegisterInfo &MRI) {
501   const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
502   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
503     unsigned Reg = Register::index2VirtReg(I);
504     auto It = LiveRegs.find(Reg);
505     if (It != LiveRegs.end() && It->second.any())
506       OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
507          << PrintLaneMask(It->second);
508   }
509   OS << '\n';
510 }
511 #endif
512