1 //===- GCNRegPressure.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file implements the GCNRegPressure class. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #include "GCNRegPressure.h" 15 #include "AMDGPUSubtarget.h" 16 #include "llvm/CodeGen/RegisterPressure.h" 17 18 using namespace llvm; 19 20 #define DEBUG_TYPE "machine-scheduler" 21 22 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 23 LLVM_DUMP_METHOD 24 void llvm::printLivesAt(SlotIndex SI, 25 const LiveIntervals &LIS, 26 const MachineRegisterInfo &MRI) { 27 dbgs() << "Live regs at " << SI << ": " 28 << *LIS.getInstructionFromIndex(SI); 29 unsigned Num = 0; 30 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 31 const unsigned Reg = Register::index2VirtReg(I); 32 if (!LIS.hasInterval(Reg)) 33 continue; 34 const auto &LI = LIS.getInterval(Reg); 35 if (LI.hasSubRanges()) { 36 bool firstTime = true; 37 for (const auto &S : LI.subranges()) { 38 if (!S.liveAt(SI)) continue; 39 if (firstTime) { 40 dbgs() << " " << printReg(Reg, MRI.getTargetRegisterInfo()) 41 << '\n'; 42 firstTime = false; 43 } 44 dbgs() << " " << S << '\n'; 45 ++Num; 46 } 47 } else if (LI.liveAt(SI)) { 48 dbgs() << " " << LI << '\n'; 49 ++Num; 50 } 51 } 52 if (!Num) dbgs() << " <none>\n"; 53 } 54 #endif 55 56 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, 57 const GCNRPTracker::LiveRegSet &S2) { 58 if (S1.size() != S2.size()) 59 return false; 60 61 for (const auto &P : S1) { 62 auto I = S2.find(P.first); 63 if (I == S2.end() || I->second != P.second) 64 return false; 65 } 66 return true; 67 } 68 69 70 /////////////////////////////////////////////////////////////////////////////// 71 // GCNRegPressure 72 73 unsigned GCNRegPressure::getRegKind(Register Reg, 74 const MachineRegisterInfo &MRI) { 75 assert(Reg.isVirtual()); 76 const auto RC = MRI.getRegClass(Reg); 77 auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo()); 78 return STI->isSGPRClass(RC) ? 79 (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) : 80 STI->hasAGPRs(RC) ? 81 (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) : 82 (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE); 83 } 84 85 void GCNRegPressure::inc(unsigned Reg, 86 LaneBitmask PrevMask, 87 LaneBitmask NewMask, 88 const MachineRegisterInfo &MRI) { 89 if (SIRegisterInfo::getNumCoveredRegs(NewMask) == 90 SIRegisterInfo::getNumCoveredRegs(PrevMask)) 91 return; 92 93 int Sign = 1; 94 if (NewMask < PrevMask) { 95 std::swap(NewMask, PrevMask); 96 Sign = -1; 97 } 98 99 switch (auto Kind = getRegKind(Reg, MRI)) { 100 case SGPR32: 101 case VGPR32: 102 case AGPR32: 103 Value[Kind] += Sign; 104 break; 105 106 case SGPR_TUPLE: 107 case VGPR_TUPLE: 108 case AGPR_TUPLE: 109 assert(PrevMask < NewMask); 110 111 Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] += 112 Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask); 113 114 if (PrevMask.none()) { 115 assert(NewMask.any()); 116 Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight(); 117 } 118 break; 119 120 default: llvm_unreachable("Unknown register kind"); 121 } 122 } 123 124 bool GCNRegPressure::less(const GCNSubtarget &ST, 125 const GCNRegPressure& O, 126 unsigned MaxOccupancy) const { 127 const auto SGPROcc = std::min(MaxOccupancy, 128 ST.getOccupancyWithNumSGPRs(getSGPRNum())); 129 const auto VGPROcc = std::min(MaxOccupancy, 130 ST.getOccupancyWithNumVGPRs(getVGPRNum())); 131 const auto OtherSGPROcc = std::min(MaxOccupancy, 132 ST.getOccupancyWithNumSGPRs(O.getSGPRNum())); 133 const auto OtherVGPROcc = std::min(MaxOccupancy, 134 ST.getOccupancyWithNumVGPRs(O.getVGPRNum())); 135 136 const auto Occ = std::min(SGPROcc, VGPROcc); 137 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc); 138 if (Occ != OtherOcc) 139 return Occ > OtherOcc; 140 141 bool SGPRImportant = SGPROcc < VGPROcc; 142 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; 143 144 // if both pressures disagree on what is more important compare vgprs 145 if (SGPRImportant != OtherSGPRImportant) { 146 SGPRImportant = false; 147 } 148 149 // compare large regs pressure 150 bool SGPRFirst = SGPRImportant; 151 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { 152 if (SGPRFirst) { 153 auto SW = getSGPRTuplesWeight(); 154 auto OtherSW = O.getSGPRTuplesWeight(); 155 if (SW != OtherSW) 156 return SW < OtherSW; 157 } else { 158 auto VW = getVGPRTuplesWeight(); 159 auto OtherVW = O.getVGPRTuplesWeight(); 160 if (VW != OtherVW) 161 return VW < OtherVW; 162 } 163 } 164 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): 165 (getVGPRNum() < O.getVGPRNum()); 166 } 167 168 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 169 LLVM_DUMP_METHOD 170 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const { 171 OS << "VGPRs: " << Value[VGPR32] << ' '; 172 OS << "AGPRs: " << Value[AGPR32]; 173 if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')'; 174 OS << ", SGPRs: " << getSGPRNum(); 175 if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')'; 176 OS << ", LVGPR WT: " << getVGPRTuplesWeight() 177 << ", LSGPR WT: " << getSGPRTuplesWeight(); 178 if (ST) OS << " -> Occ: " << getOccupancy(*ST); 179 OS << '\n'; 180 } 181 #endif 182 183 static LaneBitmask getDefRegMask(const MachineOperand &MO, 184 const MachineRegisterInfo &MRI) { 185 assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual()); 186 187 // We don't rely on read-undef flag because in case of tentative schedule 188 // tracking it isn't set correctly yet. This works correctly however since 189 // use mask has been tracked before using LIS. 190 return MO.getSubReg() == 0 ? 191 MRI.getMaxLaneMaskForVReg(MO.getReg()) : 192 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg()); 193 } 194 195 static LaneBitmask getUsedRegMask(const MachineOperand &MO, 196 const MachineRegisterInfo &MRI, 197 const LiveIntervals &LIS) { 198 assert(MO.isUse() && MO.isReg() && MO.getReg().isVirtual()); 199 200 if (auto SubReg = MO.getSubReg()) 201 return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg); 202 203 auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg()); 204 if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs 205 return MaxMask; 206 207 // For a tentative schedule LIS isn't updated yet but livemask should remain 208 // the same on any schedule. Subreg defs can be reordered but they all must 209 // dominate uses anyway. 210 auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex(); 211 return getLiveLaneMask(MO.getReg(), SI, LIS, MRI); 212 } 213 214 static SmallVector<RegisterMaskPair, 8> 215 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS, 216 const MachineRegisterInfo &MRI) { 217 SmallVector<RegisterMaskPair, 8> Res; 218 for (const auto &MO : MI.operands()) { 219 if (!MO.isReg() || !MO.getReg().isVirtual()) 220 continue; 221 if (!MO.isUse() || !MO.readsReg()) 222 continue; 223 224 auto const UsedMask = getUsedRegMask(MO, MRI, LIS); 225 226 auto Reg = MO.getReg(); 227 auto I = llvm::find_if( 228 Res, [Reg](const RegisterMaskPair &RM) { return RM.RegUnit == Reg; }); 229 if (I != Res.end()) 230 I->LaneMask |= UsedMask; 231 else 232 Res.push_back(RegisterMaskPair(Reg, UsedMask)); 233 } 234 return Res; 235 } 236 237 /////////////////////////////////////////////////////////////////////////////// 238 // GCNRPTracker 239 240 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, 241 SlotIndex SI, 242 const LiveIntervals &LIS, 243 const MachineRegisterInfo &MRI) { 244 LaneBitmask LiveMask; 245 const auto &LI = LIS.getInterval(Reg); 246 if (LI.hasSubRanges()) { 247 for (const auto &S : LI.subranges()) 248 if (S.liveAt(SI)) { 249 LiveMask |= S.LaneMask; 250 assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) || 251 LiveMask == MRI.getMaxLaneMaskForVReg(Reg)); 252 } 253 } else if (LI.liveAt(SI)) { 254 LiveMask = MRI.getMaxLaneMaskForVReg(Reg); 255 } 256 return LiveMask; 257 } 258 259 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, 260 const LiveIntervals &LIS, 261 const MachineRegisterInfo &MRI) { 262 GCNRPTracker::LiveRegSet LiveRegs; 263 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 264 auto Reg = Register::index2VirtReg(I); 265 if (!LIS.hasInterval(Reg)) 266 continue; 267 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); 268 if (LiveMask.any()) 269 LiveRegs[Reg] = LiveMask; 270 } 271 return LiveRegs; 272 } 273 274 void GCNRPTracker::reset(const MachineInstr &MI, 275 const LiveRegSet *LiveRegsCopy, 276 bool After) { 277 const MachineFunction &MF = *MI.getMF(); 278 MRI = &MF.getRegInfo(); 279 if (LiveRegsCopy) { 280 if (&LiveRegs != LiveRegsCopy) 281 LiveRegs = *LiveRegsCopy; 282 } else { 283 LiveRegs = After ? getLiveRegsAfter(MI, LIS) 284 : getLiveRegsBefore(MI, LIS); 285 } 286 287 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs); 288 } 289 290 void GCNUpwardRPTracker::reset(const MachineInstr &MI, 291 const LiveRegSet *LiveRegsCopy) { 292 GCNRPTracker::reset(MI, LiveRegsCopy, true); 293 } 294 295 void GCNUpwardRPTracker::recede(const MachineInstr &MI) { 296 assert(MRI && "call reset first"); 297 298 LastTrackedMI = &MI; 299 300 if (MI.isDebugInstr()) 301 return; 302 303 auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI); 304 305 // calc pressure at the MI (defs + uses) 306 auto AtMIPressure = CurPressure; 307 for (const auto &U : RegUses) { 308 auto LiveMask = LiveRegs[U.RegUnit]; 309 AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI); 310 } 311 // update max pressure 312 MaxPressure = max(AtMIPressure, MaxPressure); 313 314 for (const auto &MO : MI.operands()) { 315 if (!MO.isReg() || !MO.isDef() || !MO.getReg().isVirtual() || MO.isDead()) 316 continue; 317 318 auto Reg = MO.getReg(); 319 auto I = LiveRegs.find(Reg); 320 if (I == LiveRegs.end()) 321 continue; 322 auto &LiveMask = I->second; 323 auto PrevMask = LiveMask; 324 LiveMask &= ~getDefRegMask(MO, *MRI); 325 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 326 if (LiveMask.none()) 327 LiveRegs.erase(I); 328 } 329 for (const auto &U : RegUses) { 330 auto &LiveMask = LiveRegs[U.RegUnit]; 331 auto PrevMask = LiveMask; 332 LiveMask |= U.LaneMask; 333 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI); 334 } 335 assert(CurPressure == getRegPressure(*MRI, LiveRegs)); 336 } 337 338 bool GCNDownwardRPTracker::reset(const MachineInstr &MI, 339 const LiveRegSet *LiveRegsCopy) { 340 MRI = &MI.getParent()->getParent()->getRegInfo(); 341 LastTrackedMI = nullptr; 342 MBBEnd = MI.getParent()->end(); 343 NextMI = &MI; 344 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 345 if (NextMI == MBBEnd) 346 return false; 347 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false); 348 return true; 349 } 350 351 bool GCNDownwardRPTracker::advanceBeforeNext() { 352 assert(MRI && "call reset first"); 353 354 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 355 if (NextMI == MBBEnd) 356 return false; 357 358 SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex(); 359 assert(SI.isValid()); 360 361 // Remove dead registers or mask bits. 362 for (auto &It : LiveRegs) { 363 const LiveInterval &LI = LIS.getInterval(It.first); 364 if (LI.hasSubRanges()) { 365 for (const auto &S : LI.subranges()) { 366 if (!S.liveAt(SI)) { 367 auto PrevMask = It.second; 368 It.second &= ~S.LaneMask; 369 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 370 } 371 } 372 } else if (!LI.liveAt(SI)) { 373 auto PrevMask = It.second; 374 It.second = LaneBitmask::getNone(); 375 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 376 } 377 if (It.second.none()) 378 LiveRegs.erase(It.first); 379 } 380 381 MaxPressure = max(MaxPressure, CurPressure); 382 383 return true; 384 } 385 386 void GCNDownwardRPTracker::advanceToNext() { 387 LastTrackedMI = &*NextMI++; 388 389 // Add new registers or mask bits. 390 for (const auto &MO : LastTrackedMI->operands()) { 391 if (!MO.isReg() || !MO.isDef()) 392 continue; 393 Register Reg = MO.getReg(); 394 if (!Reg.isVirtual()) 395 continue; 396 auto &LiveMask = LiveRegs[Reg]; 397 auto PrevMask = LiveMask; 398 LiveMask |= getDefRegMask(MO, *MRI); 399 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 400 } 401 402 MaxPressure = max(MaxPressure, CurPressure); 403 } 404 405 bool GCNDownwardRPTracker::advance() { 406 // If we have just called reset live set is actual. 407 if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext())) 408 return false; 409 advanceToNext(); 410 return true; 411 } 412 413 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { 414 while (NextMI != End) 415 if (!advance()) return false; 416 return true; 417 } 418 419 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, 420 MachineBasicBlock::const_iterator End, 421 const LiveRegSet *LiveRegsCopy) { 422 reset(*Begin, LiveRegsCopy); 423 return advance(End); 424 } 425 426 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 427 LLVM_DUMP_METHOD 428 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, 429 const GCNRPTracker::LiveRegSet &TrackedLR, 430 const TargetRegisterInfo *TRI) { 431 for (auto const &P : TrackedLR) { 432 auto I = LISLR.find(P.first); 433 if (I == LISLR.end()) { 434 dbgs() << " " << printReg(P.first, TRI) 435 << ":L" << PrintLaneMask(P.second) 436 << " isn't found in LIS reported set\n"; 437 } 438 else if (I->second != P.second) { 439 dbgs() << " " << printReg(P.first, TRI) 440 << " masks doesn't match: LIS reported " 441 << PrintLaneMask(I->second) 442 << ", tracked " 443 << PrintLaneMask(P.second) 444 << '\n'; 445 } 446 } 447 for (auto const &P : LISLR) { 448 auto I = TrackedLR.find(P.first); 449 if (I == TrackedLR.end()) { 450 dbgs() << " " << printReg(P.first, TRI) 451 << ":L" << PrintLaneMask(P.second) 452 << " isn't found in tracked set\n"; 453 } 454 } 455 } 456 457 bool GCNUpwardRPTracker::isValid() const { 458 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex(); 459 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI); 460 const auto &TrackedLR = LiveRegs; 461 462 if (!isEqual(LISLR, TrackedLR)) { 463 dbgs() << "\nGCNUpwardRPTracker error: Tracked and" 464 " LIS reported livesets mismatch:\n"; 465 printLivesAt(SI, LIS, *MRI); 466 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo()); 467 return false; 468 } 469 470 auto LISPressure = getRegPressure(*MRI, LISLR); 471 if (LISPressure != CurPressure) { 472 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "; 473 CurPressure.print(dbgs()); 474 dbgs() << "LIS rpt: "; 475 LISPressure.print(dbgs()); 476 return false; 477 } 478 return true; 479 } 480 481 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs, 482 const MachineRegisterInfo &MRI) { 483 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 484 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 485 unsigned Reg = Register::index2VirtReg(I); 486 auto It = LiveRegs.find(Reg); 487 if (It != LiveRegs.end() && It->second.any()) 488 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':' 489 << PrintLaneMask(It->second); 490 } 491 OS << '\n'; 492 } 493 #endif 494