1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the GCNRegPressure class.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "GCNRegPressure.h"
15 #include "AMDGPUSubtarget.h"
16 #include "llvm/CodeGen/RegisterPressure.h"
17 
18 using namespace llvm;
19 
20 #define DEBUG_TYPE "machine-scheduler"
21 
22 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
23 LLVM_DUMP_METHOD
24 void llvm::printLivesAt(SlotIndex SI,
25                         const LiveIntervals &LIS,
26                         const MachineRegisterInfo &MRI) {
27   dbgs() << "Live regs at " << SI << ": "
28          << *LIS.getInstructionFromIndex(SI);
29   unsigned Num = 0;
30   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
31     const unsigned Reg = Register::index2VirtReg(I);
32     if (!LIS.hasInterval(Reg))
33       continue;
34     const auto &LI = LIS.getInterval(Reg);
35     if (LI.hasSubRanges()) {
36       bool firstTime = true;
37       for (const auto &S : LI.subranges()) {
38         if (!S.liveAt(SI)) continue;
39         if (firstTime) {
40           dbgs() << "  " << printReg(Reg, MRI.getTargetRegisterInfo())
41                  << '\n';
42           firstTime = false;
43         }
44         dbgs() << "  " << S << '\n';
45         ++Num;
46       }
47     } else if (LI.liveAt(SI)) {
48       dbgs() << "  " << LI << '\n';
49       ++Num;
50     }
51   }
52   if (!Num) dbgs() << "  <none>\n";
53 }
54 #endif
55 
56 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
57                    const GCNRPTracker::LiveRegSet &S2) {
58   if (S1.size() != S2.size())
59     return false;
60 
61   for (const auto &P : S1) {
62     auto I = S2.find(P.first);
63     if (I == S2.end() || I->second != P.second)
64       return false;
65   }
66   return true;
67 }
68 
69 
70 ///////////////////////////////////////////////////////////////////////////////
71 // GCNRegPressure
72 
73 unsigned GCNRegPressure::getRegKind(Register Reg,
74                                     const MachineRegisterInfo &MRI) {
75   assert(Reg.isVirtual());
76   const auto RC = MRI.getRegClass(Reg);
77   auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
78   return STI->isSGPRClass(RC) ?
79     (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
80     STI->hasAGPRs(RC) ?
81       (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) :
82       (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
83 }
84 
85 void GCNRegPressure::inc(unsigned Reg,
86                          LaneBitmask PrevMask,
87                          LaneBitmask NewMask,
88                          const MachineRegisterInfo &MRI) {
89   if (SIRegisterInfo::getNumCoveredRegs(NewMask) ==
90       SIRegisterInfo::getNumCoveredRegs(PrevMask))
91     return;
92 
93   int Sign = 1;
94   if (NewMask < PrevMask) {
95     std::swap(NewMask, PrevMask);
96     Sign = -1;
97   }
98 
99   switch (auto Kind = getRegKind(Reg, MRI)) {
100   case SGPR32:
101   case VGPR32:
102   case AGPR32:
103     Value[Kind] += Sign;
104     break;
105 
106   case SGPR_TUPLE:
107   case VGPR_TUPLE:
108   case AGPR_TUPLE:
109     assert(PrevMask < NewMask);
110 
111     Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
112       Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
113 
114     if (PrevMask.none()) {
115       assert(NewMask.any());
116       Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
117     }
118     break;
119 
120   default: llvm_unreachable("Unknown register kind");
121   }
122 }
123 
124 bool GCNRegPressure::less(const GCNSubtarget &ST,
125                           const GCNRegPressure& O,
126                           unsigned MaxOccupancy) const {
127   const auto SGPROcc = std::min(MaxOccupancy,
128                                 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
129   const auto VGPROcc = std::min(MaxOccupancy,
130                                 ST.getOccupancyWithNumVGPRs(getVGPRNum()));
131   const auto OtherSGPROcc = std::min(MaxOccupancy,
132                                 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
133   const auto OtherVGPROcc = std::min(MaxOccupancy,
134                                 ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
135 
136   const auto Occ = std::min(SGPROcc, VGPROcc);
137   const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
138   if (Occ != OtherOcc)
139     return Occ > OtherOcc;
140 
141   bool SGPRImportant = SGPROcc < VGPROcc;
142   const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
143 
144   // if both pressures disagree on what is more important compare vgprs
145   if (SGPRImportant != OtherSGPRImportant) {
146     SGPRImportant = false;
147   }
148 
149   // compare large regs pressure
150   bool SGPRFirst = SGPRImportant;
151   for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
152     if (SGPRFirst) {
153       auto SW = getSGPRTuplesWeight();
154       auto OtherSW = O.getSGPRTuplesWeight();
155       if (SW != OtherSW)
156         return SW < OtherSW;
157     } else {
158       auto VW = getVGPRTuplesWeight();
159       auto OtherVW = O.getVGPRTuplesWeight();
160       if (VW != OtherVW)
161         return VW < OtherVW;
162     }
163   }
164   return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
165                          (getVGPRNum() < O.getVGPRNum());
166 }
167 
168 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
169 LLVM_DUMP_METHOD
170 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
171   OS << "VGPRs: " << Value[VGPR32] << ' ';
172   OS << "AGPRs: " << Value[AGPR32];
173   if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
174   OS << ", SGPRs: " << getSGPRNum();
175   if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
176   OS << ", LVGPR WT: " << getVGPRTuplesWeight()
177      << ", LSGPR WT: " << getSGPRTuplesWeight();
178   if (ST) OS << " -> Occ: " << getOccupancy(*ST);
179   OS << '\n';
180 }
181 #endif
182 
183 static LaneBitmask getDefRegMask(const MachineOperand &MO,
184                                  const MachineRegisterInfo &MRI) {
185   assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());
186 
187   // We don't rely on read-undef flag because in case of tentative schedule
188   // tracking it isn't set correctly yet. This works correctly however since
189   // use mask has been tracked before using LIS.
190   return MO.getSubReg() == 0 ?
191     MRI.getMaxLaneMaskForVReg(MO.getReg()) :
192     MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
193 }
194 
195 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
196                                   const MachineRegisterInfo &MRI,
197                                   const LiveIntervals &LIS) {
198   assert(MO.isUse() && MO.isReg() && MO.getReg().isVirtual());
199 
200   if (auto SubReg = MO.getSubReg())
201     return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
202 
203   auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
204   if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs
205     return MaxMask;
206 
207   // For a tentative schedule LIS isn't updated yet but livemask should remain
208   // the same on any schedule. Subreg defs can be reordered but they all must
209   // dominate uses anyway.
210   auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
211   return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
212 }
213 
214 static SmallVector<RegisterMaskPair, 8>
215 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
216                       const MachineRegisterInfo &MRI) {
217   SmallVector<RegisterMaskPair, 8> Res;
218   for (const auto &MO : MI.operands()) {
219     if (!MO.isReg() || !MO.getReg().isVirtual())
220       continue;
221     if (!MO.isUse() || !MO.readsReg())
222       continue;
223 
224     auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
225 
226     auto Reg = MO.getReg();
227     auto I = llvm::find_if(
228         Res, [Reg](const RegisterMaskPair &RM) { return RM.RegUnit == Reg; });
229     if (I != Res.end())
230       I->LaneMask |= UsedMask;
231     else
232       Res.push_back(RegisterMaskPair(Reg, UsedMask));
233   }
234   return Res;
235 }
236 
237 ///////////////////////////////////////////////////////////////////////////////
238 // GCNRPTracker
239 
240 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
241                                   SlotIndex SI,
242                                   const LiveIntervals &LIS,
243                                   const MachineRegisterInfo &MRI) {
244   LaneBitmask LiveMask;
245   const auto &LI = LIS.getInterval(Reg);
246   if (LI.hasSubRanges()) {
247     for (const auto &S : LI.subranges())
248       if (S.liveAt(SI)) {
249         LiveMask |= S.LaneMask;
250         assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
251                LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
252       }
253   } else if (LI.liveAt(SI)) {
254     LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
255   }
256   return LiveMask;
257 }
258 
259 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
260                                            const LiveIntervals &LIS,
261                                            const MachineRegisterInfo &MRI) {
262   GCNRPTracker::LiveRegSet LiveRegs;
263   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
264     auto Reg = Register::index2VirtReg(I);
265     if (!LIS.hasInterval(Reg))
266       continue;
267     auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
268     if (LiveMask.any())
269       LiveRegs[Reg] = LiveMask;
270   }
271   return LiveRegs;
272 }
273 
274 void GCNRPTracker::reset(const MachineInstr &MI,
275                          const LiveRegSet *LiveRegsCopy,
276                          bool After) {
277   const MachineFunction &MF = *MI.getMF();
278   MRI = &MF.getRegInfo();
279   if (LiveRegsCopy) {
280     if (&LiveRegs != LiveRegsCopy)
281       LiveRegs = *LiveRegsCopy;
282   } else {
283     LiveRegs = After ? getLiveRegsAfter(MI, LIS)
284                      : getLiveRegsBefore(MI, LIS);
285   }
286 
287   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
288 }
289 
290 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
291                                const LiveRegSet *LiveRegsCopy) {
292   GCNRPTracker::reset(MI, LiveRegsCopy, true);
293 }
294 
295 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
296   assert(MRI && "call reset first");
297 
298   LastTrackedMI = &MI;
299 
300   if (MI.isDebugInstr())
301     return;
302 
303   auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
304 
305   // calc pressure at the MI (defs + uses)
306   auto AtMIPressure = CurPressure;
307   for (const auto &U : RegUses) {
308     auto LiveMask = LiveRegs[U.RegUnit];
309     AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
310   }
311   // update max pressure
312   MaxPressure = max(AtMIPressure, MaxPressure);
313 
314   for (const auto &MO : MI.operands()) {
315     if (!MO.isReg() || !MO.isDef() || !MO.getReg().isVirtual() || MO.isDead())
316       continue;
317 
318     auto Reg = MO.getReg();
319     auto I = LiveRegs.find(Reg);
320     if (I == LiveRegs.end())
321       continue;
322     auto &LiveMask = I->second;
323     auto PrevMask = LiveMask;
324     LiveMask &= ~getDefRegMask(MO, *MRI);
325     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
326     if (LiveMask.none())
327       LiveRegs.erase(I);
328   }
329   for (const auto &U : RegUses) {
330     auto &LiveMask = LiveRegs[U.RegUnit];
331     auto PrevMask = LiveMask;
332     LiveMask |= U.LaneMask;
333     CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
334   }
335   assert(CurPressure == getRegPressure(*MRI, LiveRegs));
336 }
337 
338 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
339                                  const LiveRegSet *LiveRegsCopy) {
340   MRI = &MI.getParent()->getParent()->getRegInfo();
341   LastTrackedMI = nullptr;
342   MBBEnd = MI.getParent()->end();
343   NextMI = &MI;
344   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
345   if (NextMI == MBBEnd)
346     return false;
347   GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
348   return true;
349 }
350 
351 bool GCNDownwardRPTracker::advanceBeforeNext() {
352   assert(MRI && "call reset first");
353 
354   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
355   if (NextMI == MBBEnd)
356     return false;
357 
358   SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
359   assert(SI.isValid());
360 
361   // Remove dead registers or mask bits.
362   for (auto &It : LiveRegs) {
363     const LiveInterval &LI = LIS.getInterval(It.first);
364     if (LI.hasSubRanges()) {
365       for (const auto &S : LI.subranges()) {
366         if (!S.liveAt(SI)) {
367           auto PrevMask = It.second;
368           It.second &= ~S.LaneMask;
369           CurPressure.inc(It.first, PrevMask, It.second, *MRI);
370         }
371       }
372     } else if (!LI.liveAt(SI)) {
373       auto PrevMask = It.second;
374       It.second = LaneBitmask::getNone();
375       CurPressure.inc(It.first, PrevMask, It.second, *MRI);
376     }
377     if (It.second.none())
378       LiveRegs.erase(It.first);
379   }
380 
381   MaxPressure = max(MaxPressure, CurPressure);
382 
383   return true;
384 }
385 
386 void GCNDownwardRPTracker::advanceToNext() {
387   LastTrackedMI = &*NextMI++;
388 
389   // Add new registers or mask bits.
390   for (const auto &MO : LastTrackedMI->operands()) {
391     if (!MO.isReg() || !MO.isDef())
392       continue;
393     Register Reg = MO.getReg();
394     if (!Reg.isVirtual())
395       continue;
396     auto &LiveMask = LiveRegs[Reg];
397     auto PrevMask = LiveMask;
398     LiveMask |= getDefRegMask(MO, *MRI);
399     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
400   }
401 
402   MaxPressure = max(MaxPressure, CurPressure);
403 }
404 
405 bool GCNDownwardRPTracker::advance() {
406   // If we have just called reset live set is actual.
407   if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
408     return false;
409   advanceToNext();
410   return true;
411 }
412 
413 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
414   while (NextMI != End)
415     if (!advance()) return false;
416   return true;
417 }
418 
419 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
420                                    MachineBasicBlock::const_iterator End,
421                                    const LiveRegSet *LiveRegsCopy) {
422   reset(*Begin, LiveRegsCopy);
423   return advance(End);
424 }
425 
426 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
427 LLVM_DUMP_METHOD
428 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
429                            const GCNRPTracker::LiveRegSet &TrackedLR,
430                            const TargetRegisterInfo *TRI) {
431   for (auto const &P : TrackedLR) {
432     auto I = LISLR.find(P.first);
433     if (I == LISLR.end()) {
434       dbgs() << "  " << printReg(P.first, TRI)
435              << ":L" << PrintLaneMask(P.second)
436              << " isn't found in LIS reported set\n";
437     }
438     else if (I->second != P.second) {
439       dbgs() << "  " << printReg(P.first, TRI)
440         << " masks doesn't match: LIS reported "
441         << PrintLaneMask(I->second)
442         << ", tracked "
443         << PrintLaneMask(P.second)
444         << '\n';
445     }
446   }
447   for (auto const &P : LISLR) {
448     auto I = TrackedLR.find(P.first);
449     if (I == TrackedLR.end()) {
450       dbgs() << "  " << printReg(P.first, TRI)
451              << ":L" << PrintLaneMask(P.second)
452              << " isn't found in tracked set\n";
453     }
454   }
455 }
456 
457 bool GCNUpwardRPTracker::isValid() const {
458   const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
459   const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
460   const auto &TrackedLR = LiveRegs;
461 
462   if (!isEqual(LISLR, TrackedLR)) {
463     dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
464               " LIS reported livesets mismatch:\n";
465     printLivesAt(SI, LIS, *MRI);
466     reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
467     return false;
468   }
469 
470   auto LISPressure = getRegPressure(*MRI, LISLR);
471   if (LISPressure != CurPressure) {
472     dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
473     CurPressure.print(dbgs());
474     dbgs() << "LIS rpt: ";
475     LISPressure.print(dbgs());
476     return false;
477   }
478   return true;
479 }
480 
481 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
482                                  const MachineRegisterInfo &MRI) {
483   const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
484   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
485     unsigned Reg = Register::index2VirtReg(I);
486     auto It = LiveRegs.find(Reg);
487     if (It != LiveRegs.end() && It->second.any())
488       OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
489          << PrintLaneMask(It->second);
490   }
491   OS << '\n';
492 }
493 #endif
494