1 //===- GCNRegPressure.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "GCNRegPressure.h" 10 #include "AMDGPUSubtarget.h" 11 #include "SIRegisterInfo.h" 12 #include "llvm/ADT/SmallVector.h" 13 #include "llvm/CodeGen/LiveInterval.h" 14 #include "llvm/CodeGen/LiveIntervals.h" 15 #include "llvm/CodeGen/MachineInstr.h" 16 #include "llvm/CodeGen/MachineOperand.h" 17 #include "llvm/CodeGen/MachineRegisterInfo.h" 18 #include "llvm/CodeGen/RegisterPressure.h" 19 #include "llvm/CodeGen/SlotIndexes.h" 20 #include "llvm/CodeGen/TargetRegisterInfo.h" 21 #include "llvm/Config/llvm-config.h" 22 #include "llvm/MC/LaneBitmask.h" 23 #include "llvm/Support/Compiler.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/ErrorHandling.h" 26 #include "llvm/Support/raw_ostream.h" 27 #include <algorithm> 28 #include <cassert> 29 30 using namespace llvm; 31 32 #define DEBUG_TYPE "machine-scheduler" 33 34 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 35 LLVM_DUMP_METHOD 36 void llvm::printLivesAt(SlotIndex SI, 37 const LiveIntervals &LIS, 38 const MachineRegisterInfo &MRI) { 39 dbgs() << "Live regs at " << SI << ": " 40 << *LIS.getInstructionFromIndex(SI); 41 unsigned Num = 0; 42 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 43 const unsigned Reg = TargetRegisterInfo::index2VirtReg(I); 44 if (!LIS.hasInterval(Reg)) 45 continue; 46 const auto &LI = LIS.getInterval(Reg); 47 if (LI.hasSubRanges()) { 48 bool firstTime = true; 49 for (const auto &S : LI.subranges()) { 50 if (!S.liveAt(SI)) continue; 51 if (firstTime) { 52 dbgs() << " " << printReg(Reg, MRI.getTargetRegisterInfo()) 53 << '\n'; 54 firstTime = false; 55 } 56 dbgs() << " " << S << '\n'; 57 ++Num; 58 } 59 } else if (LI.liveAt(SI)) { 60 dbgs() << " " << LI << '\n'; 61 ++Num; 62 } 63 } 64 if (!Num) dbgs() << " <none>\n"; 65 } 66 #endif 67 68 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, 69 const GCNRPTracker::LiveRegSet &S2) { 70 if (S1.size() != S2.size()) 71 return false; 72 73 for (const auto &P : S1) { 74 auto I = S2.find(P.first); 75 if (I == S2.end() || I->second != P.second) 76 return false; 77 } 78 return true; 79 } 80 81 82 /////////////////////////////////////////////////////////////////////////////// 83 // GCNRegPressure 84 85 unsigned GCNRegPressure::getRegKind(unsigned Reg, 86 const MachineRegisterInfo &MRI) { 87 assert(TargetRegisterInfo::isVirtualRegister(Reg)); 88 const auto RC = MRI.getRegClass(Reg); 89 auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo()); 90 return STI->isSGPRClass(RC) ? 91 (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) : 92 (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE); 93 } 94 95 void GCNRegPressure::inc(unsigned Reg, 96 LaneBitmask PrevMask, 97 LaneBitmask NewMask, 98 const MachineRegisterInfo &MRI) { 99 if (NewMask == PrevMask) 100 return; 101 102 int Sign = 1; 103 if (NewMask < PrevMask) { 104 std::swap(NewMask, PrevMask); 105 Sign = -1; 106 } 107 #ifndef NDEBUG 108 const auto MaxMask = MRI.getMaxLaneMaskForVReg(Reg); 109 #endif 110 switch (auto Kind = getRegKind(Reg, MRI)) { 111 case SGPR32: 112 case VGPR32: 113 assert(PrevMask.none() && NewMask == MaxMask); 114 Value[Kind] += Sign; 115 break; 116 117 case SGPR_TUPLE: 118 case VGPR_TUPLE: 119 assert(NewMask < MaxMask || NewMask == MaxMask); 120 assert(PrevMask < NewMask); 121 122 Value[Kind == SGPR_TUPLE ? SGPR32 : VGPR32] += 123 Sign * (~PrevMask & NewMask).getNumLanes(); 124 125 if (PrevMask.none()) { 126 assert(NewMask.any()); 127 Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight(); 128 } 129 break; 130 131 default: llvm_unreachable("Unknown register kind"); 132 } 133 } 134 135 bool GCNRegPressure::less(const GCNSubtarget &ST, 136 const GCNRegPressure& O, 137 unsigned MaxOccupancy) const { 138 const auto SGPROcc = std::min(MaxOccupancy, 139 ST.getOccupancyWithNumSGPRs(getSGPRNum())); 140 const auto VGPROcc = std::min(MaxOccupancy, 141 ST.getOccupancyWithNumVGPRs(getVGPRNum())); 142 const auto OtherSGPROcc = std::min(MaxOccupancy, 143 ST.getOccupancyWithNumSGPRs(O.getSGPRNum())); 144 const auto OtherVGPROcc = std::min(MaxOccupancy, 145 ST.getOccupancyWithNumVGPRs(O.getVGPRNum())); 146 147 const auto Occ = std::min(SGPROcc, VGPROcc); 148 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc); 149 if (Occ != OtherOcc) 150 return Occ > OtherOcc; 151 152 bool SGPRImportant = SGPROcc < VGPROcc; 153 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; 154 155 // if both pressures disagree on what is more important compare vgprs 156 if (SGPRImportant != OtherSGPRImportant) { 157 SGPRImportant = false; 158 } 159 160 // compare large regs pressure 161 bool SGPRFirst = SGPRImportant; 162 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { 163 if (SGPRFirst) { 164 auto SW = getSGPRTuplesWeight(); 165 auto OtherSW = O.getSGPRTuplesWeight(); 166 if (SW != OtherSW) 167 return SW < OtherSW; 168 } else { 169 auto VW = getVGPRTuplesWeight(); 170 auto OtherVW = O.getVGPRTuplesWeight(); 171 if (VW != OtherVW) 172 return VW < OtherVW; 173 } 174 } 175 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): 176 (getVGPRNum() < O.getVGPRNum()); 177 } 178 179 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 180 LLVM_DUMP_METHOD 181 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const { 182 OS << "VGPRs: " << getVGPRNum(); 183 if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')'; 184 OS << ", SGPRs: " << getSGPRNum(); 185 if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')'; 186 OS << ", LVGPR WT: " << getVGPRTuplesWeight() 187 << ", LSGPR WT: " << getSGPRTuplesWeight(); 188 if (ST) OS << " -> Occ: " << getOccupancy(*ST); 189 OS << '\n'; 190 } 191 #endif 192 193 static LaneBitmask getDefRegMask(const MachineOperand &MO, 194 const MachineRegisterInfo &MRI) { 195 assert(MO.isDef() && MO.isReg() && 196 TargetRegisterInfo::isVirtualRegister(MO.getReg())); 197 198 // We don't rely on read-undef flag because in case of tentative schedule 199 // tracking it isn't set correctly yet. This works correctly however since 200 // use mask has been tracked before using LIS. 201 return MO.getSubReg() == 0 ? 202 MRI.getMaxLaneMaskForVReg(MO.getReg()) : 203 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg()); 204 } 205 206 static LaneBitmask getUsedRegMask(const MachineOperand &MO, 207 const MachineRegisterInfo &MRI, 208 const LiveIntervals &LIS) { 209 assert(MO.isUse() && MO.isReg() && 210 TargetRegisterInfo::isVirtualRegister(MO.getReg())); 211 212 if (auto SubReg = MO.getSubReg()) 213 return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg); 214 215 auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg()); 216 if (MaxMask == LaneBitmask::getLane(0)) // cannot have subregs 217 return MaxMask; 218 219 // For a tentative schedule LIS isn't updated yet but livemask should remain 220 // the same on any schedule. Subreg defs can be reordered but they all must 221 // dominate uses anyway. 222 auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex(); 223 return getLiveLaneMask(MO.getReg(), SI, LIS, MRI); 224 } 225 226 static SmallVector<RegisterMaskPair, 8> 227 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS, 228 const MachineRegisterInfo &MRI) { 229 SmallVector<RegisterMaskPair, 8> Res; 230 for (const auto &MO : MI.operands()) { 231 if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg())) 232 continue; 233 if (!MO.isUse() || !MO.readsReg()) 234 continue; 235 236 auto const UsedMask = getUsedRegMask(MO, MRI, LIS); 237 238 auto Reg = MO.getReg(); 239 auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) { 240 return RM.RegUnit == Reg; 241 }); 242 if (I != Res.end()) 243 I->LaneMask |= UsedMask; 244 else 245 Res.push_back(RegisterMaskPair(Reg, UsedMask)); 246 } 247 return Res; 248 } 249 250 /////////////////////////////////////////////////////////////////////////////// 251 // GCNRPTracker 252 253 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, 254 SlotIndex SI, 255 const LiveIntervals &LIS, 256 const MachineRegisterInfo &MRI) { 257 LaneBitmask LiveMask; 258 const auto &LI = LIS.getInterval(Reg); 259 if (LI.hasSubRanges()) { 260 for (const auto &S : LI.subranges()) 261 if (S.liveAt(SI)) { 262 LiveMask |= S.LaneMask; 263 assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) || 264 LiveMask == MRI.getMaxLaneMaskForVReg(Reg)); 265 } 266 } else if (LI.liveAt(SI)) { 267 LiveMask = MRI.getMaxLaneMaskForVReg(Reg); 268 } 269 return LiveMask; 270 } 271 272 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, 273 const LiveIntervals &LIS, 274 const MachineRegisterInfo &MRI) { 275 GCNRPTracker::LiveRegSet LiveRegs; 276 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 277 auto Reg = TargetRegisterInfo::index2VirtReg(I); 278 if (!LIS.hasInterval(Reg)) 279 continue; 280 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); 281 if (LiveMask.any()) 282 LiveRegs[Reg] = LiveMask; 283 } 284 return LiveRegs; 285 } 286 287 void GCNRPTracker::reset(const MachineInstr &MI, 288 const LiveRegSet *LiveRegsCopy, 289 bool After) { 290 const MachineFunction &MF = *MI.getMF(); 291 MRI = &MF.getRegInfo(); 292 if (LiveRegsCopy) { 293 if (&LiveRegs != LiveRegsCopy) 294 LiveRegs = *LiveRegsCopy; 295 } else { 296 LiveRegs = After ? getLiveRegsAfter(MI, LIS) 297 : getLiveRegsBefore(MI, LIS); 298 } 299 300 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs); 301 } 302 303 void GCNUpwardRPTracker::reset(const MachineInstr &MI, 304 const LiveRegSet *LiveRegsCopy) { 305 GCNRPTracker::reset(MI, LiveRegsCopy, true); 306 } 307 308 void GCNUpwardRPTracker::recede(const MachineInstr &MI) { 309 assert(MRI && "call reset first"); 310 311 LastTrackedMI = &MI; 312 313 if (MI.isDebugInstr()) 314 return; 315 316 auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI); 317 318 // calc pressure at the MI (defs + uses) 319 auto AtMIPressure = CurPressure; 320 for (const auto &U : RegUses) { 321 auto LiveMask = LiveRegs[U.RegUnit]; 322 AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI); 323 } 324 // update max pressure 325 MaxPressure = max(AtMIPressure, MaxPressure); 326 327 for (const auto &MO : MI.defs()) { 328 if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg()) || 329 MO.isDead()) 330 continue; 331 332 auto Reg = MO.getReg(); 333 auto I = LiveRegs.find(Reg); 334 if (I == LiveRegs.end()) 335 continue; 336 auto &LiveMask = I->second; 337 auto PrevMask = LiveMask; 338 LiveMask &= ~getDefRegMask(MO, *MRI); 339 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 340 if (LiveMask.none()) 341 LiveRegs.erase(I); 342 } 343 for (const auto &U : RegUses) { 344 auto &LiveMask = LiveRegs[U.RegUnit]; 345 auto PrevMask = LiveMask; 346 LiveMask |= U.LaneMask; 347 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI); 348 } 349 assert(CurPressure == getRegPressure(*MRI, LiveRegs)); 350 } 351 352 bool GCNDownwardRPTracker::reset(const MachineInstr &MI, 353 const LiveRegSet *LiveRegsCopy) { 354 MRI = &MI.getParent()->getParent()->getRegInfo(); 355 LastTrackedMI = nullptr; 356 MBBEnd = MI.getParent()->end(); 357 NextMI = &MI; 358 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 359 if (NextMI == MBBEnd) 360 return false; 361 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false); 362 return true; 363 } 364 365 bool GCNDownwardRPTracker::advanceBeforeNext() { 366 assert(MRI && "call reset first"); 367 368 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 369 if (NextMI == MBBEnd) 370 return false; 371 372 SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex(); 373 assert(SI.isValid()); 374 375 // Remove dead registers or mask bits. 376 for (auto &It : LiveRegs) { 377 const LiveInterval &LI = LIS.getInterval(It.first); 378 if (LI.hasSubRanges()) { 379 for (const auto &S : LI.subranges()) { 380 if (!S.liveAt(SI)) { 381 auto PrevMask = It.second; 382 It.second &= ~S.LaneMask; 383 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 384 } 385 } 386 } else if (!LI.liveAt(SI)) { 387 auto PrevMask = It.second; 388 It.second = LaneBitmask::getNone(); 389 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 390 } 391 if (It.second.none()) 392 LiveRegs.erase(It.first); 393 } 394 395 MaxPressure = max(MaxPressure, CurPressure); 396 397 return true; 398 } 399 400 void GCNDownwardRPTracker::advanceToNext() { 401 LastTrackedMI = &*NextMI++; 402 403 // Add new registers or mask bits. 404 for (const auto &MO : LastTrackedMI->defs()) { 405 if (!MO.isReg()) 406 continue; 407 unsigned Reg = MO.getReg(); 408 if (!TargetRegisterInfo::isVirtualRegister(Reg)) 409 continue; 410 auto &LiveMask = LiveRegs[Reg]; 411 auto PrevMask = LiveMask; 412 LiveMask |= getDefRegMask(MO, *MRI); 413 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 414 } 415 416 MaxPressure = max(MaxPressure, CurPressure); 417 } 418 419 bool GCNDownwardRPTracker::advance() { 420 // If we have just called reset live set is actual. 421 if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext())) 422 return false; 423 advanceToNext(); 424 return true; 425 } 426 427 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { 428 while (NextMI != End) 429 if (!advance()) return false; 430 return true; 431 } 432 433 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, 434 MachineBasicBlock::const_iterator End, 435 const LiveRegSet *LiveRegsCopy) { 436 reset(*Begin, LiveRegsCopy); 437 return advance(End); 438 } 439 440 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 441 LLVM_DUMP_METHOD 442 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, 443 const GCNRPTracker::LiveRegSet &TrackedLR, 444 const TargetRegisterInfo *TRI) { 445 for (auto const &P : TrackedLR) { 446 auto I = LISLR.find(P.first); 447 if (I == LISLR.end()) { 448 dbgs() << " " << printReg(P.first, TRI) 449 << ":L" << PrintLaneMask(P.second) 450 << " isn't found in LIS reported set\n"; 451 } 452 else if (I->second != P.second) { 453 dbgs() << " " << printReg(P.first, TRI) 454 << " masks doesn't match: LIS reported " 455 << PrintLaneMask(I->second) 456 << ", tracked " 457 << PrintLaneMask(P.second) 458 << '\n'; 459 } 460 } 461 for (auto const &P : LISLR) { 462 auto I = TrackedLR.find(P.first); 463 if (I == TrackedLR.end()) { 464 dbgs() << " " << printReg(P.first, TRI) 465 << ":L" << PrintLaneMask(P.second) 466 << " isn't found in tracked set\n"; 467 } 468 } 469 } 470 471 bool GCNUpwardRPTracker::isValid() const { 472 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex(); 473 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI); 474 const auto &TrackedLR = LiveRegs; 475 476 if (!isEqual(LISLR, TrackedLR)) { 477 dbgs() << "\nGCNUpwardRPTracker error: Tracked and" 478 " LIS reported livesets mismatch:\n"; 479 printLivesAt(SI, LIS, *MRI); 480 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo()); 481 return false; 482 } 483 484 auto LISPressure = getRegPressure(*MRI, LISLR); 485 if (LISPressure != CurPressure) { 486 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "; 487 CurPressure.print(dbgs()); 488 dbgs() << "LIS rpt: "; 489 LISPressure.print(dbgs()); 490 return false; 491 } 492 return true; 493 } 494 495 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs, 496 const MachineRegisterInfo &MRI) { 497 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 498 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 499 unsigned Reg = TargetRegisterInfo::index2VirtReg(I); 500 auto It = LiveRegs.find(Reg); 501 if (It != LiveRegs.end() && It->second.any()) 502 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':' 503 << PrintLaneMask(It->second); 504 } 505 OS << '\n'; 506 } 507 #endif 508