1 //===- GCNRegPressure.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "GCNRegPressure.h" 10 #include "AMDGPUSubtarget.h" 11 #include "SIRegisterInfo.h" 12 #include "llvm/ADT/SmallVector.h" 13 #include "llvm/CodeGen/LiveInterval.h" 14 #include "llvm/CodeGen/LiveIntervals.h" 15 #include "llvm/CodeGen/MachineInstr.h" 16 #include "llvm/CodeGen/MachineOperand.h" 17 #include "llvm/CodeGen/MachineRegisterInfo.h" 18 #include "llvm/CodeGen/RegisterPressure.h" 19 #include "llvm/CodeGen/SlotIndexes.h" 20 #include "llvm/CodeGen/TargetRegisterInfo.h" 21 #include "llvm/Config/llvm-config.h" 22 #include "llvm/MC/LaneBitmask.h" 23 #include "llvm/Support/Compiler.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/ErrorHandling.h" 26 #include "llvm/Support/raw_ostream.h" 27 #include <algorithm> 28 #include <cassert> 29 30 using namespace llvm; 31 32 #define DEBUG_TYPE "machine-scheduler" 33 34 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 35 LLVM_DUMP_METHOD 36 void llvm::printLivesAt(SlotIndex SI, 37 const LiveIntervals &LIS, 38 const MachineRegisterInfo &MRI) { 39 dbgs() << "Live regs at " << SI << ": " 40 << *LIS.getInstructionFromIndex(SI); 41 unsigned Num = 0; 42 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 43 const unsigned Reg = TargetRegisterInfo::index2VirtReg(I); 44 if (!LIS.hasInterval(Reg)) 45 continue; 46 const auto &LI = LIS.getInterval(Reg); 47 if (LI.hasSubRanges()) { 48 bool firstTime = true; 49 for (const auto &S : LI.subranges()) { 50 if (!S.liveAt(SI)) continue; 51 if (firstTime) { 52 dbgs() << " " << printReg(Reg, MRI.getTargetRegisterInfo()) 53 << '\n'; 54 firstTime = false; 55 } 56 dbgs() << " " << S << '\n'; 57 ++Num; 58 } 59 } else if (LI.liveAt(SI)) { 60 dbgs() << " " << LI << '\n'; 61 ++Num; 62 } 63 } 64 if (!Num) dbgs() << " <none>\n"; 65 } 66 67 static bool isEqual(const GCNRPTracker::LiveRegSet &S1, 68 const GCNRPTracker::LiveRegSet &S2) { 69 if (S1.size() != S2.size()) 70 return false; 71 72 for (const auto &P : S1) { 73 auto I = S2.find(P.first); 74 if (I == S2.end() || I->second != P.second) 75 return false; 76 } 77 return true; 78 } 79 #endif 80 81 /////////////////////////////////////////////////////////////////////////////// 82 // GCNRegPressure 83 84 unsigned GCNRegPressure::getRegKind(unsigned Reg, 85 const MachineRegisterInfo &MRI) { 86 assert(TargetRegisterInfo::isVirtualRegister(Reg)); 87 const auto RC = MRI.getRegClass(Reg); 88 auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo()); 89 return STI->isSGPRClass(RC) ? 90 (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) : 91 (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE); 92 } 93 94 void GCNRegPressure::inc(unsigned Reg, 95 LaneBitmask PrevMask, 96 LaneBitmask NewMask, 97 const MachineRegisterInfo &MRI) { 98 if (NewMask == PrevMask) 99 return; 100 101 int Sign = 1; 102 if (NewMask < PrevMask) { 103 std::swap(NewMask, PrevMask); 104 Sign = -1; 105 } 106 #ifndef NDEBUG 107 const auto MaxMask = MRI.getMaxLaneMaskForVReg(Reg); 108 #endif 109 switch (auto Kind = getRegKind(Reg, MRI)) { 110 case SGPR32: 111 case VGPR32: 112 assert(PrevMask.none() && NewMask == MaxMask); 113 Value[Kind] += Sign; 114 break; 115 116 case SGPR_TUPLE: 117 case VGPR_TUPLE: 118 assert(NewMask < MaxMask || NewMask == MaxMask); 119 assert(PrevMask < NewMask); 120 121 Value[Kind == SGPR_TUPLE ? SGPR32 : VGPR32] += 122 Sign * (~PrevMask & NewMask).getNumLanes(); 123 124 if (PrevMask.none()) { 125 assert(NewMask.any()); 126 Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight(); 127 } 128 break; 129 130 default: llvm_unreachable("Unknown register kind"); 131 } 132 } 133 134 bool GCNRegPressure::less(const GCNSubtarget &ST, 135 const GCNRegPressure& O, 136 unsigned MaxOccupancy) const { 137 const auto SGPROcc = std::min(MaxOccupancy, 138 ST.getOccupancyWithNumSGPRs(getSGPRNum())); 139 const auto VGPROcc = std::min(MaxOccupancy, 140 ST.getOccupancyWithNumVGPRs(getVGPRNum())); 141 const auto OtherSGPROcc = std::min(MaxOccupancy, 142 ST.getOccupancyWithNumSGPRs(O.getSGPRNum())); 143 const auto OtherVGPROcc = std::min(MaxOccupancy, 144 ST.getOccupancyWithNumVGPRs(O.getVGPRNum())); 145 146 const auto Occ = std::min(SGPROcc, VGPROcc); 147 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc); 148 if (Occ != OtherOcc) 149 return Occ > OtherOcc; 150 151 bool SGPRImportant = SGPROcc < VGPROcc; 152 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; 153 154 // if both pressures disagree on what is more important compare vgprs 155 if (SGPRImportant != OtherSGPRImportant) { 156 SGPRImportant = false; 157 } 158 159 // compare large regs pressure 160 bool SGPRFirst = SGPRImportant; 161 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { 162 if (SGPRFirst) { 163 auto SW = getSGPRTuplesWeight(); 164 auto OtherSW = O.getSGPRTuplesWeight(); 165 if (SW != OtherSW) 166 return SW < OtherSW; 167 } else { 168 auto VW = getVGPRTuplesWeight(); 169 auto OtherVW = O.getVGPRTuplesWeight(); 170 if (VW != OtherVW) 171 return VW < OtherVW; 172 } 173 } 174 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): 175 (getVGPRNum() < O.getVGPRNum()); 176 } 177 178 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 179 LLVM_DUMP_METHOD 180 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const { 181 OS << "VGPRs: " << getVGPRNum(); 182 if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')'; 183 OS << ", SGPRs: " << getSGPRNum(); 184 if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')'; 185 OS << ", LVGPR WT: " << getVGPRTuplesWeight() 186 << ", LSGPR WT: " << getSGPRTuplesWeight(); 187 if (ST) OS << " -> Occ: " << getOccupancy(*ST); 188 OS << '\n'; 189 } 190 #endif 191 192 static LaneBitmask getDefRegMask(const MachineOperand &MO, 193 const MachineRegisterInfo &MRI) { 194 assert(MO.isDef() && MO.isReg() && 195 TargetRegisterInfo::isVirtualRegister(MO.getReg())); 196 197 // We don't rely on read-undef flag because in case of tentative schedule 198 // tracking it isn't set correctly yet. This works correctly however since 199 // use mask has been tracked before using LIS. 200 return MO.getSubReg() == 0 ? 201 MRI.getMaxLaneMaskForVReg(MO.getReg()) : 202 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg()); 203 } 204 205 static LaneBitmask getUsedRegMask(const MachineOperand &MO, 206 const MachineRegisterInfo &MRI, 207 const LiveIntervals &LIS) { 208 assert(MO.isUse() && MO.isReg() && 209 TargetRegisterInfo::isVirtualRegister(MO.getReg())); 210 211 if (auto SubReg = MO.getSubReg()) 212 return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg); 213 214 auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg()); 215 if (MaxMask == LaneBitmask::getLane(0)) // cannot have subregs 216 return MaxMask; 217 218 // For a tentative schedule LIS isn't updated yet but livemask should remain 219 // the same on any schedule. Subreg defs can be reordered but they all must 220 // dominate uses anyway. 221 auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex(); 222 return getLiveLaneMask(MO.getReg(), SI, LIS, MRI); 223 } 224 225 static SmallVector<RegisterMaskPair, 8> 226 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS, 227 const MachineRegisterInfo &MRI) { 228 SmallVector<RegisterMaskPair, 8> Res; 229 for (const auto &MO : MI.operands()) { 230 if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg())) 231 continue; 232 if (!MO.isUse() || !MO.readsReg()) 233 continue; 234 235 auto const UsedMask = getUsedRegMask(MO, MRI, LIS); 236 237 auto Reg = MO.getReg(); 238 auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) { 239 return RM.RegUnit == Reg; 240 }); 241 if (I != Res.end()) 242 I->LaneMask |= UsedMask; 243 else 244 Res.push_back(RegisterMaskPair(Reg, UsedMask)); 245 } 246 return Res; 247 } 248 249 /////////////////////////////////////////////////////////////////////////////// 250 // GCNRPTracker 251 252 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, 253 SlotIndex SI, 254 const LiveIntervals &LIS, 255 const MachineRegisterInfo &MRI) { 256 LaneBitmask LiveMask; 257 const auto &LI = LIS.getInterval(Reg); 258 if (LI.hasSubRanges()) { 259 for (const auto &S : LI.subranges()) 260 if (S.liveAt(SI)) { 261 LiveMask |= S.LaneMask; 262 assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) || 263 LiveMask == MRI.getMaxLaneMaskForVReg(Reg)); 264 } 265 } else if (LI.liveAt(SI)) { 266 LiveMask = MRI.getMaxLaneMaskForVReg(Reg); 267 } 268 return LiveMask; 269 } 270 271 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, 272 const LiveIntervals &LIS, 273 const MachineRegisterInfo &MRI) { 274 GCNRPTracker::LiveRegSet LiveRegs; 275 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 276 auto Reg = TargetRegisterInfo::index2VirtReg(I); 277 if (!LIS.hasInterval(Reg)) 278 continue; 279 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); 280 if (LiveMask.any()) 281 LiveRegs[Reg] = LiveMask; 282 } 283 return LiveRegs; 284 } 285 286 void GCNRPTracker::reset(const MachineInstr &MI, 287 const LiveRegSet *LiveRegsCopy, 288 bool After) { 289 const MachineFunction &MF = *MI.getMF(); 290 MRI = &MF.getRegInfo(); 291 if (LiveRegsCopy) { 292 if (&LiveRegs != LiveRegsCopy) 293 LiveRegs = *LiveRegsCopy; 294 } else { 295 LiveRegs = After ? getLiveRegsAfter(MI, LIS) 296 : getLiveRegsBefore(MI, LIS); 297 } 298 299 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs); 300 } 301 302 void GCNUpwardRPTracker::reset(const MachineInstr &MI, 303 const LiveRegSet *LiveRegsCopy) { 304 GCNRPTracker::reset(MI, LiveRegsCopy, true); 305 } 306 307 void GCNUpwardRPTracker::recede(const MachineInstr &MI) { 308 assert(MRI && "call reset first"); 309 310 LastTrackedMI = &MI; 311 312 if (MI.isDebugInstr()) 313 return; 314 315 auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI); 316 317 // calc pressure at the MI (defs + uses) 318 auto AtMIPressure = CurPressure; 319 for (const auto &U : RegUses) { 320 auto LiveMask = LiveRegs[U.RegUnit]; 321 AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI); 322 } 323 // update max pressure 324 MaxPressure = max(AtMIPressure, MaxPressure); 325 326 for (const auto &MO : MI.defs()) { 327 if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg()) || 328 MO.isDead()) 329 continue; 330 331 auto Reg = MO.getReg(); 332 auto I = LiveRegs.find(Reg); 333 if (I == LiveRegs.end()) 334 continue; 335 auto &LiveMask = I->second; 336 auto PrevMask = LiveMask; 337 LiveMask &= ~getDefRegMask(MO, *MRI); 338 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 339 if (LiveMask.none()) 340 LiveRegs.erase(I); 341 } 342 for (const auto &U : RegUses) { 343 auto &LiveMask = LiveRegs[U.RegUnit]; 344 auto PrevMask = LiveMask; 345 LiveMask |= U.LaneMask; 346 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI); 347 } 348 assert(CurPressure == getRegPressure(*MRI, LiveRegs)); 349 } 350 351 bool GCNDownwardRPTracker::reset(const MachineInstr &MI, 352 const LiveRegSet *LiveRegsCopy) { 353 MRI = &MI.getParent()->getParent()->getRegInfo(); 354 LastTrackedMI = nullptr; 355 MBBEnd = MI.getParent()->end(); 356 NextMI = &MI; 357 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 358 if (NextMI == MBBEnd) 359 return false; 360 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false); 361 return true; 362 } 363 364 bool GCNDownwardRPTracker::advanceBeforeNext() { 365 assert(MRI && "call reset first"); 366 367 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 368 if (NextMI == MBBEnd) 369 return false; 370 371 SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex(); 372 assert(SI.isValid()); 373 374 // Remove dead registers or mask bits. 375 for (auto &It : LiveRegs) { 376 const LiveInterval &LI = LIS.getInterval(It.first); 377 if (LI.hasSubRanges()) { 378 for (const auto &S : LI.subranges()) { 379 if (!S.liveAt(SI)) { 380 auto PrevMask = It.second; 381 It.second &= ~S.LaneMask; 382 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 383 } 384 } 385 } else if (!LI.liveAt(SI)) { 386 auto PrevMask = It.second; 387 It.second = LaneBitmask::getNone(); 388 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 389 } 390 if (It.second.none()) 391 LiveRegs.erase(It.first); 392 } 393 394 MaxPressure = max(MaxPressure, CurPressure); 395 396 return true; 397 } 398 399 void GCNDownwardRPTracker::advanceToNext() { 400 LastTrackedMI = &*NextMI++; 401 402 // Add new registers or mask bits. 403 for (const auto &MO : LastTrackedMI->defs()) { 404 if (!MO.isReg()) 405 continue; 406 unsigned Reg = MO.getReg(); 407 if (!TargetRegisterInfo::isVirtualRegister(Reg)) 408 continue; 409 auto &LiveMask = LiveRegs[Reg]; 410 auto PrevMask = LiveMask; 411 LiveMask |= getDefRegMask(MO, *MRI); 412 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 413 } 414 415 MaxPressure = max(MaxPressure, CurPressure); 416 } 417 418 bool GCNDownwardRPTracker::advance() { 419 // If we have just called reset live set is actual. 420 if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext())) 421 return false; 422 advanceToNext(); 423 return true; 424 } 425 426 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { 427 while (NextMI != End) 428 if (!advance()) return false; 429 return true; 430 } 431 432 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, 433 MachineBasicBlock::const_iterator End, 434 const LiveRegSet *LiveRegsCopy) { 435 reset(*Begin, LiveRegsCopy); 436 return advance(End); 437 } 438 439 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 440 LLVM_DUMP_METHOD 441 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, 442 const GCNRPTracker::LiveRegSet &TrackedLR, 443 const TargetRegisterInfo *TRI) { 444 for (auto const &P : TrackedLR) { 445 auto I = LISLR.find(P.first); 446 if (I == LISLR.end()) { 447 dbgs() << " " << printReg(P.first, TRI) 448 << ":L" << PrintLaneMask(P.second) 449 << " isn't found in LIS reported set\n"; 450 } 451 else if (I->second != P.second) { 452 dbgs() << " " << printReg(P.first, TRI) 453 << " masks doesn't match: LIS reported " 454 << PrintLaneMask(I->second) 455 << ", tracked " 456 << PrintLaneMask(P.second) 457 << '\n'; 458 } 459 } 460 for (auto const &P : LISLR) { 461 auto I = TrackedLR.find(P.first); 462 if (I == TrackedLR.end()) { 463 dbgs() << " " << printReg(P.first, TRI) 464 << ":L" << PrintLaneMask(P.second) 465 << " isn't found in tracked set\n"; 466 } 467 } 468 } 469 470 bool GCNUpwardRPTracker::isValid() const { 471 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex(); 472 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI); 473 const auto &TrackedLR = LiveRegs; 474 475 if (!isEqual(LISLR, TrackedLR)) { 476 dbgs() << "\nGCNUpwardRPTracker error: Tracked and" 477 " LIS reported livesets mismatch:\n"; 478 printLivesAt(SI, LIS, *MRI); 479 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo()); 480 return false; 481 } 482 483 auto LISPressure = getRegPressure(*MRI, LISLR); 484 if (LISPressure != CurPressure) { 485 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "; 486 CurPressure.print(dbgs()); 487 dbgs() << "LIS rpt: "; 488 LISPressure.print(dbgs()); 489 return false; 490 } 491 return true; 492 } 493 494 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs, 495 const MachineRegisterInfo &MRI) { 496 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 497 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 498 unsigned Reg = TargetRegisterInfo::index2VirtReg(I); 499 auto It = LiveRegs.find(Reg); 500 if (It != LiveRegs.end() && It->second.any()) 501 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':' 502 << PrintLaneMask(It->second); 503 } 504 OS << '\n'; 505 } 506 #endif 507