1 //===- GCNRegPressure.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file implements the GCNRegPressure class. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #include "GCNRegPressure.h" 15 #include "AMDGPUSubtarget.h" 16 #include "SIRegisterInfo.h" 17 #include "llvm/ADT/SmallVector.h" 18 #include "llvm/CodeGen/LiveInterval.h" 19 #include "llvm/CodeGen/LiveIntervals.h" 20 #include "llvm/CodeGen/MachineInstr.h" 21 #include "llvm/CodeGen/MachineOperand.h" 22 #include "llvm/CodeGen/MachineRegisterInfo.h" 23 #include "llvm/CodeGen/RegisterPressure.h" 24 #include "llvm/CodeGen/SlotIndexes.h" 25 #include "llvm/CodeGen/TargetRegisterInfo.h" 26 #include "llvm/Config/llvm-config.h" 27 #include "llvm/MC/LaneBitmask.h" 28 #include "llvm/Support/Compiler.h" 29 #include "llvm/Support/Debug.h" 30 #include "llvm/Support/ErrorHandling.h" 31 #include "llvm/Support/raw_ostream.h" 32 #include <algorithm> 33 #include <cassert> 34 35 using namespace llvm; 36 37 #define DEBUG_TYPE "machine-scheduler" 38 39 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 40 LLVM_DUMP_METHOD 41 void llvm::printLivesAt(SlotIndex SI, 42 const LiveIntervals &LIS, 43 const MachineRegisterInfo &MRI) { 44 dbgs() << "Live regs at " << SI << ": " 45 << *LIS.getInstructionFromIndex(SI); 46 unsigned Num = 0; 47 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 48 const unsigned Reg = Register::index2VirtReg(I); 49 if (!LIS.hasInterval(Reg)) 50 continue; 51 const auto &LI = LIS.getInterval(Reg); 52 if (LI.hasSubRanges()) { 53 bool firstTime = true; 54 for (const auto &S : LI.subranges()) { 55 if (!S.liveAt(SI)) continue; 56 if (firstTime) { 57 dbgs() << " " << printReg(Reg, MRI.getTargetRegisterInfo()) 58 << '\n'; 59 firstTime = false; 60 } 61 dbgs() << " " << S << '\n'; 62 ++Num; 63 } 64 } else if (LI.liveAt(SI)) { 65 dbgs() << " " << LI << '\n'; 66 ++Num; 67 } 68 } 69 if (!Num) dbgs() << " <none>\n"; 70 } 71 #endif 72 73 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, 74 const GCNRPTracker::LiveRegSet &S2) { 75 if (S1.size() != S2.size()) 76 return false; 77 78 for (const auto &P : S1) { 79 auto I = S2.find(P.first); 80 if (I == S2.end() || I->second != P.second) 81 return false; 82 } 83 return true; 84 } 85 86 87 /////////////////////////////////////////////////////////////////////////////// 88 // GCNRegPressure 89 90 unsigned GCNRegPressure::getRegKind(unsigned Reg, 91 const MachineRegisterInfo &MRI) { 92 assert(Register::isVirtualRegister(Reg)); 93 const auto RC = MRI.getRegClass(Reg); 94 auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo()); 95 return STI->isSGPRClass(RC) ? 96 (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) : 97 STI->hasAGPRs(RC) ? 98 (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) : 99 (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE); 100 } 101 102 void GCNRegPressure::inc(unsigned Reg, 103 LaneBitmask PrevMask, 104 LaneBitmask NewMask, 105 const MachineRegisterInfo &MRI) { 106 if (NewMask == PrevMask) 107 return; 108 109 int Sign = 1; 110 if (NewMask < PrevMask) { 111 std::swap(NewMask, PrevMask); 112 Sign = -1; 113 } 114 #ifndef NDEBUG 115 const auto MaxMask = MRI.getMaxLaneMaskForVReg(Reg); 116 #endif 117 switch (auto Kind = getRegKind(Reg, MRI)) { 118 case SGPR32: 119 case VGPR32: 120 case AGPR32: 121 assert(PrevMask.none() && NewMask == MaxMask); 122 Value[Kind] += Sign; 123 break; 124 125 case SGPR_TUPLE: 126 case VGPR_TUPLE: 127 case AGPR_TUPLE: 128 assert(NewMask < MaxMask || NewMask == MaxMask); 129 assert(PrevMask < NewMask); 130 131 Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] += 132 Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask); 133 134 if (PrevMask.none()) { 135 assert(NewMask.any()); 136 Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight(); 137 } 138 break; 139 140 default: llvm_unreachable("Unknown register kind"); 141 } 142 } 143 144 bool GCNRegPressure::less(const GCNSubtarget &ST, 145 const GCNRegPressure& O, 146 unsigned MaxOccupancy) const { 147 const auto SGPROcc = std::min(MaxOccupancy, 148 ST.getOccupancyWithNumSGPRs(getSGPRNum())); 149 const auto VGPROcc = std::min(MaxOccupancy, 150 ST.getOccupancyWithNumVGPRs(getVGPRNum())); 151 const auto OtherSGPROcc = std::min(MaxOccupancy, 152 ST.getOccupancyWithNumSGPRs(O.getSGPRNum())); 153 const auto OtherVGPROcc = std::min(MaxOccupancy, 154 ST.getOccupancyWithNumVGPRs(O.getVGPRNum())); 155 156 const auto Occ = std::min(SGPROcc, VGPROcc); 157 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc); 158 if (Occ != OtherOcc) 159 return Occ > OtherOcc; 160 161 bool SGPRImportant = SGPROcc < VGPROcc; 162 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; 163 164 // if both pressures disagree on what is more important compare vgprs 165 if (SGPRImportant != OtherSGPRImportant) { 166 SGPRImportant = false; 167 } 168 169 // compare large regs pressure 170 bool SGPRFirst = SGPRImportant; 171 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { 172 if (SGPRFirst) { 173 auto SW = getSGPRTuplesWeight(); 174 auto OtherSW = O.getSGPRTuplesWeight(); 175 if (SW != OtherSW) 176 return SW < OtherSW; 177 } else { 178 auto VW = getVGPRTuplesWeight(); 179 auto OtherVW = O.getVGPRTuplesWeight(); 180 if (VW != OtherVW) 181 return VW < OtherVW; 182 } 183 } 184 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): 185 (getVGPRNum() < O.getVGPRNum()); 186 } 187 188 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 189 LLVM_DUMP_METHOD 190 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const { 191 OS << "VGPRs: " << Value[VGPR32] << ' '; 192 OS << "AGPRs: " << Value[AGPR32]; 193 if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')'; 194 OS << ", SGPRs: " << getSGPRNum(); 195 if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')'; 196 OS << ", LVGPR WT: " << getVGPRTuplesWeight() 197 << ", LSGPR WT: " << getSGPRTuplesWeight(); 198 if (ST) OS << " -> Occ: " << getOccupancy(*ST); 199 OS << '\n'; 200 } 201 #endif 202 203 static LaneBitmask getDefRegMask(const MachineOperand &MO, 204 const MachineRegisterInfo &MRI) { 205 assert(MO.isDef() && MO.isReg() && Register::isVirtualRegister(MO.getReg())); 206 207 // We don't rely on read-undef flag because in case of tentative schedule 208 // tracking it isn't set correctly yet. This works correctly however since 209 // use mask has been tracked before using LIS. 210 return MO.getSubReg() == 0 ? 211 MRI.getMaxLaneMaskForVReg(MO.getReg()) : 212 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg()); 213 } 214 215 static LaneBitmask getUsedRegMask(const MachineOperand &MO, 216 const MachineRegisterInfo &MRI, 217 const LiveIntervals &LIS) { 218 assert(MO.isUse() && MO.isReg() && Register::isVirtualRegister(MO.getReg())); 219 220 if (auto SubReg = MO.getSubReg()) 221 return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg); 222 223 auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg()); 224 if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs 225 return MaxMask; 226 227 // For a tentative schedule LIS isn't updated yet but livemask should remain 228 // the same on any schedule. Subreg defs can be reordered but they all must 229 // dominate uses anyway. 230 auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex(); 231 return getLiveLaneMask(MO.getReg(), SI, LIS, MRI); 232 } 233 234 static SmallVector<RegisterMaskPair, 8> 235 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS, 236 const MachineRegisterInfo &MRI) { 237 SmallVector<RegisterMaskPair, 8> Res; 238 for (const auto &MO : MI.operands()) { 239 if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg())) 240 continue; 241 if (!MO.isUse() || !MO.readsReg()) 242 continue; 243 244 auto const UsedMask = getUsedRegMask(MO, MRI, LIS); 245 246 auto Reg = MO.getReg(); 247 auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) { 248 return RM.RegUnit == Reg; 249 }); 250 if (I != Res.end()) 251 I->LaneMask |= UsedMask; 252 else 253 Res.push_back(RegisterMaskPair(Reg, UsedMask)); 254 } 255 return Res; 256 } 257 258 /////////////////////////////////////////////////////////////////////////////// 259 // GCNRPTracker 260 261 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, 262 SlotIndex SI, 263 const LiveIntervals &LIS, 264 const MachineRegisterInfo &MRI) { 265 LaneBitmask LiveMask; 266 const auto &LI = LIS.getInterval(Reg); 267 if (LI.hasSubRanges()) { 268 for (const auto &S : LI.subranges()) 269 if (S.liveAt(SI)) { 270 LiveMask |= S.LaneMask; 271 assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) || 272 LiveMask == MRI.getMaxLaneMaskForVReg(Reg)); 273 } 274 } else if (LI.liveAt(SI)) { 275 LiveMask = MRI.getMaxLaneMaskForVReg(Reg); 276 } 277 return LiveMask; 278 } 279 280 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, 281 const LiveIntervals &LIS, 282 const MachineRegisterInfo &MRI) { 283 GCNRPTracker::LiveRegSet LiveRegs; 284 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 285 auto Reg = Register::index2VirtReg(I); 286 if (!LIS.hasInterval(Reg)) 287 continue; 288 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); 289 if (LiveMask.any()) 290 LiveRegs[Reg] = LiveMask; 291 } 292 return LiveRegs; 293 } 294 295 void GCNRPTracker::reset(const MachineInstr &MI, 296 const LiveRegSet *LiveRegsCopy, 297 bool After) { 298 const MachineFunction &MF = *MI.getMF(); 299 MRI = &MF.getRegInfo(); 300 if (LiveRegsCopy) { 301 if (&LiveRegs != LiveRegsCopy) 302 LiveRegs = *LiveRegsCopy; 303 } else { 304 LiveRegs = After ? getLiveRegsAfter(MI, LIS) 305 : getLiveRegsBefore(MI, LIS); 306 } 307 308 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs); 309 } 310 311 void GCNUpwardRPTracker::reset(const MachineInstr &MI, 312 const LiveRegSet *LiveRegsCopy) { 313 GCNRPTracker::reset(MI, LiveRegsCopy, true); 314 } 315 316 void GCNUpwardRPTracker::recede(const MachineInstr &MI) { 317 assert(MRI && "call reset first"); 318 319 LastTrackedMI = &MI; 320 321 if (MI.isDebugInstr()) 322 return; 323 324 auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI); 325 326 // calc pressure at the MI (defs + uses) 327 auto AtMIPressure = CurPressure; 328 for (const auto &U : RegUses) { 329 auto LiveMask = LiveRegs[U.RegUnit]; 330 AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI); 331 } 332 // update max pressure 333 MaxPressure = max(AtMIPressure, MaxPressure); 334 335 for (const auto &MO : MI.operands()) { 336 if (!MO.isReg() || !MO.isDef() || 337 !Register::isVirtualRegister(MO.getReg()) || MO.isDead()) 338 continue; 339 340 auto Reg = MO.getReg(); 341 auto I = LiveRegs.find(Reg); 342 if (I == LiveRegs.end()) 343 continue; 344 auto &LiveMask = I->second; 345 auto PrevMask = LiveMask; 346 LiveMask &= ~getDefRegMask(MO, *MRI); 347 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 348 if (LiveMask.none()) 349 LiveRegs.erase(I); 350 } 351 for (const auto &U : RegUses) { 352 auto &LiveMask = LiveRegs[U.RegUnit]; 353 auto PrevMask = LiveMask; 354 LiveMask |= U.LaneMask; 355 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI); 356 } 357 assert(CurPressure == getRegPressure(*MRI, LiveRegs)); 358 } 359 360 bool GCNDownwardRPTracker::reset(const MachineInstr &MI, 361 const LiveRegSet *LiveRegsCopy) { 362 MRI = &MI.getParent()->getParent()->getRegInfo(); 363 LastTrackedMI = nullptr; 364 MBBEnd = MI.getParent()->end(); 365 NextMI = &MI; 366 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 367 if (NextMI == MBBEnd) 368 return false; 369 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false); 370 return true; 371 } 372 373 bool GCNDownwardRPTracker::advanceBeforeNext() { 374 assert(MRI && "call reset first"); 375 376 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 377 if (NextMI == MBBEnd) 378 return false; 379 380 SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex(); 381 assert(SI.isValid()); 382 383 // Remove dead registers or mask bits. 384 for (auto &It : LiveRegs) { 385 const LiveInterval &LI = LIS.getInterval(It.first); 386 if (LI.hasSubRanges()) { 387 for (const auto &S : LI.subranges()) { 388 if (!S.liveAt(SI)) { 389 auto PrevMask = It.second; 390 It.second &= ~S.LaneMask; 391 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 392 } 393 } 394 } else if (!LI.liveAt(SI)) { 395 auto PrevMask = It.second; 396 It.second = LaneBitmask::getNone(); 397 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 398 } 399 if (It.second.none()) 400 LiveRegs.erase(It.first); 401 } 402 403 MaxPressure = max(MaxPressure, CurPressure); 404 405 return true; 406 } 407 408 void GCNDownwardRPTracker::advanceToNext() { 409 LastTrackedMI = &*NextMI++; 410 411 // Add new registers or mask bits. 412 for (const auto &MO : LastTrackedMI->operands()) { 413 if (!MO.isReg() || !MO.isDef()) 414 continue; 415 Register Reg = MO.getReg(); 416 if (!Register::isVirtualRegister(Reg)) 417 continue; 418 auto &LiveMask = LiveRegs[Reg]; 419 auto PrevMask = LiveMask; 420 LiveMask |= getDefRegMask(MO, *MRI); 421 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 422 } 423 424 MaxPressure = max(MaxPressure, CurPressure); 425 } 426 427 bool GCNDownwardRPTracker::advance() { 428 // If we have just called reset live set is actual. 429 if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext())) 430 return false; 431 advanceToNext(); 432 return true; 433 } 434 435 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { 436 while (NextMI != End) 437 if (!advance()) return false; 438 return true; 439 } 440 441 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, 442 MachineBasicBlock::const_iterator End, 443 const LiveRegSet *LiveRegsCopy) { 444 reset(*Begin, LiveRegsCopy); 445 return advance(End); 446 } 447 448 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 449 LLVM_DUMP_METHOD 450 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, 451 const GCNRPTracker::LiveRegSet &TrackedLR, 452 const TargetRegisterInfo *TRI) { 453 for (auto const &P : TrackedLR) { 454 auto I = LISLR.find(P.first); 455 if (I == LISLR.end()) { 456 dbgs() << " " << printReg(P.first, TRI) 457 << ":L" << PrintLaneMask(P.second) 458 << " isn't found in LIS reported set\n"; 459 } 460 else if (I->second != P.second) { 461 dbgs() << " " << printReg(P.first, TRI) 462 << " masks doesn't match: LIS reported " 463 << PrintLaneMask(I->second) 464 << ", tracked " 465 << PrintLaneMask(P.second) 466 << '\n'; 467 } 468 } 469 for (auto const &P : LISLR) { 470 auto I = TrackedLR.find(P.first); 471 if (I == TrackedLR.end()) { 472 dbgs() << " " << printReg(P.first, TRI) 473 << ":L" << PrintLaneMask(P.second) 474 << " isn't found in tracked set\n"; 475 } 476 } 477 } 478 479 bool GCNUpwardRPTracker::isValid() const { 480 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex(); 481 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI); 482 const auto &TrackedLR = LiveRegs; 483 484 if (!isEqual(LISLR, TrackedLR)) { 485 dbgs() << "\nGCNUpwardRPTracker error: Tracked and" 486 " LIS reported livesets mismatch:\n"; 487 printLivesAt(SI, LIS, *MRI); 488 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo()); 489 return false; 490 } 491 492 auto LISPressure = getRegPressure(*MRI, LISLR); 493 if (LISPressure != CurPressure) { 494 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "; 495 CurPressure.print(dbgs()); 496 dbgs() << "LIS rpt: "; 497 LISPressure.print(dbgs()); 498 return false; 499 } 500 return true; 501 } 502 503 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs, 504 const MachineRegisterInfo &MRI) { 505 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 506 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 507 unsigned Reg = Register::index2VirtReg(I); 508 auto It = LiveRegs.find(Reg); 509 if (It != LiveRegs.end() && It->second.any()) 510 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':' 511 << PrintLaneMask(It->second); 512 } 513 OS << '\n'; 514 } 515 #endif 516