1 //===- GCNRegPressure.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "GCNRegPressure.h" 10 #include "AMDGPUSubtarget.h" 11 #include "SIRegisterInfo.h" 12 #include "llvm/ADT/SmallVector.h" 13 #include "llvm/CodeGen/LiveInterval.h" 14 #include "llvm/CodeGen/LiveIntervals.h" 15 #include "llvm/CodeGen/MachineInstr.h" 16 #include "llvm/CodeGen/MachineOperand.h" 17 #include "llvm/CodeGen/MachineRegisterInfo.h" 18 #include "llvm/CodeGen/RegisterPressure.h" 19 #include "llvm/CodeGen/SlotIndexes.h" 20 #include "llvm/CodeGen/TargetRegisterInfo.h" 21 #include "llvm/Config/llvm-config.h" 22 #include "llvm/MC/LaneBitmask.h" 23 #include "llvm/Support/Compiler.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/ErrorHandling.h" 26 #include "llvm/Support/raw_ostream.h" 27 #include <algorithm> 28 #include <cassert> 29 30 using namespace llvm; 31 32 #define DEBUG_TYPE "machine-scheduler" 33 34 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 35 LLVM_DUMP_METHOD 36 void llvm::printLivesAt(SlotIndex SI, 37 const LiveIntervals &LIS, 38 const MachineRegisterInfo &MRI) { 39 dbgs() << "Live regs at " << SI << ": " 40 << *LIS.getInstructionFromIndex(SI); 41 unsigned Num = 0; 42 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 43 const unsigned Reg = Register::index2VirtReg(I); 44 if (!LIS.hasInterval(Reg)) 45 continue; 46 const auto &LI = LIS.getInterval(Reg); 47 if (LI.hasSubRanges()) { 48 bool firstTime = true; 49 for (const auto &S : LI.subranges()) { 50 if (!S.liveAt(SI)) continue; 51 if (firstTime) { 52 dbgs() << " " << printReg(Reg, MRI.getTargetRegisterInfo()) 53 << '\n'; 54 firstTime = false; 55 } 56 dbgs() << " " << S << '\n'; 57 ++Num; 58 } 59 } else if (LI.liveAt(SI)) { 60 dbgs() << " " << LI << '\n'; 61 ++Num; 62 } 63 } 64 if (!Num) dbgs() << " <none>\n"; 65 } 66 #endif 67 68 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, 69 const GCNRPTracker::LiveRegSet &S2) { 70 if (S1.size() != S2.size()) 71 return false; 72 73 for (const auto &P : S1) { 74 auto I = S2.find(P.first); 75 if (I == S2.end() || I->second != P.second) 76 return false; 77 } 78 return true; 79 } 80 81 82 /////////////////////////////////////////////////////////////////////////////// 83 // GCNRegPressure 84 85 unsigned GCNRegPressure::getRegKind(unsigned Reg, 86 const MachineRegisterInfo &MRI) { 87 assert(Register::isVirtualRegister(Reg)); 88 const auto RC = MRI.getRegClass(Reg); 89 auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo()); 90 return STI->isSGPRClass(RC) ? 91 (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) : 92 STI->hasAGPRs(RC) ? 93 (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) : 94 (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE); 95 } 96 97 void GCNRegPressure::inc(unsigned Reg, 98 LaneBitmask PrevMask, 99 LaneBitmask NewMask, 100 const MachineRegisterInfo &MRI) { 101 if (NewMask == PrevMask) 102 return; 103 104 int Sign = 1; 105 if (NewMask < PrevMask) { 106 std::swap(NewMask, PrevMask); 107 Sign = -1; 108 } 109 #ifndef NDEBUG 110 const auto MaxMask = MRI.getMaxLaneMaskForVReg(Reg); 111 #endif 112 switch (auto Kind = getRegKind(Reg, MRI)) { 113 case SGPR32: 114 case VGPR32: 115 case AGPR32: 116 assert(PrevMask.none() && NewMask == MaxMask); 117 Value[Kind] += Sign; 118 break; 119 120 case SGPR_TUPLE: 121 case VGPR_TUPLE: 122 case AGPR_TUPLE: 123 assert(NewMask < MaxMask || NewMask == MaxMask); 124 assert(PrevMask < NewMask); 125 126 Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] += 127 Sign * (~PrevMask & NewMask).getNumLanes(); 128 129 if (PrevMask.none()) { 130 assert(NewMask.any()); 131 Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight(); 132 } 133 break; 134 135 default: llvm_unreachable("Unknown register kind"); 136 } 137 } 138 139 bool GCNRegPressure::less(const GCNSubtarget &ST, 140 const GCNRegPressure& O, 141 unsigned MaxOccupancy) const { 142 const auto SGPROcc = std::min(MaxOccupancy, 143 ST.getOccupancyWithNumSGPRs(getSGPRNum())); 144 const auto VGPROcc = std::min(MaxOccupancy, 145 ST.getOccupancyWithNumVGPRs(getVGPRNum())); 146 const auto OtherSGPROcc = std::min(MaxOccupancy, 147 ST.getOccupancyWithNumSGPRs(O.getSGPRNum())); 148 const auto OtherVGPROcc = std::min(MaxOccupancy, 149 ST.getOccupancyWithNumVGPRs(O.getVGPRNum())); 150 151 const auto Occ = std::min(SGPROcc, VGPROcc); 152 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc); 153 if (Occ != OtherOcc) 154 return Occ > OtherOcc; 155 156 bool SGPRImportant = SGPROcc < VGPROcc; 157 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; 158 159 // if both pressures disagree on what is more important compare vgprs 160 if (SGPRImportant != OtherSGPRImportant) { 161 SGPRImportant = false; 162 } 163 164 // compare large regs pressure 165 bool SGPRFirst = SGPRImportant; 166 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { 167 if (SGPRFirst) { 168 auto SW = getSGPRTuplesWeight(); 169 auto OtherSW = O.getSGPRTuplesWeight(); 170 if (SW != OtherSW) 171 return SW < OtherSW; 172 } else { 173 auto VW = getVGPRTuplesWeight(); 174 auto OtherVW = O.getVGPRTuplesWeight(); 175 if (VW != OtherVW) 176 return VW < OtherVW; 177 } 178 } 179 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): 180 (getVGPRNum() < O.getVGPRNum()); 181 } 182 183 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 184 LLVM_DUMP_METHOD 185 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const { 186 OS << "VGPRs: " << Value[VGPR32] << ' '; 187 OS << "AGPRs: " << Value[AGPR32]; 188 if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')'; 189 OS << ", SGPRs: " << getSGPRNum(); 190 if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')'; 191 OS << ", LVGPR WT: " << getVGPRTuplesWeight() 192 << ", LSGPR WT: " << getSGPRTuplesWeight(); 193 if (ST) OS << " -> Occ: " << getOccupancy(*ST); 194 OS << '\n'; 195 } 196 #endif 197 198 static LaneBitmask getDefRegMask(const MachineOperand &MO, 199 const MachineRegisterInfo &MRI) { 200 assert(MO.isDef() && MO.isReg() && Register::isVirtualRegister(MO.getReg())); 201 202 // We don't rely on read-undef flag because in case of tentative schedule 203 // tracking it isn't set correctly yet. This works correctly however since 204 // use mask has been tracked before using LIS. 205 return MO.getSubReg() == 0 ? 206 MRI.getMaxLaneMaskForVReg(MO.getReg()) : 207 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg()); 208 } 209 210 static LaneBitmask getUsedRegMask(const MachineOperand &MO, 211 const MachineRegisterInfo &MRI, 212 const LiveIntervals &LIS) { 213 assert(MO.isUse() && MO.isReg() && Register::isVirtualRegister(MO.getReg())); 214 215 if (auto SubReg = MO.getSubReg()) 216 return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg); 217 218 auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg()); 219 if (MaxMask == LaneBitmask::getLane(0)) // cannot have subregs 220 return MaxMask; 221 222 // For a tentative schedule LIS isn't updated yet but livemask should remain 223 // the same on any schedule. Subreg defs can be reordered but they all must 224 // dominate uses anyway. 225 auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex(); 226 return getLiveLaneMask(MO.getReg(), SI, LIS, MRI); 227 } 228 229 static SmallVector<RegisterMaskPair, 8> 230 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS, 231 const MachineRegisterInfo &MRI) { 232 SmallVector<RegisterMaskPair, 8> Res; 233 for (const auto &MO : MI.operands()) { 234 if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg())) 235 continue; 236 if (!MO.isUse() || !MO.readsReg()) 237 continue; 238 239 auto const UsedMask = getUsedRegMask(MO, MRI, LIS); 240 241 auto Reg = MO.getReg(); 242 auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) { 243 return RM.RegUnit == Reg; 244 }); 245 if (I != Res.end()) 246 I->LaneMask |= UsedMask; 247 else 248 Res.push_back(RegisterMaskPair(Reg, UsedMask)); 249 } 250 return Res; 251 } 252 253 /////////////////////////////////////////////////////////////////////////////// 254 // GCNRPTracker 255 256 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, 257 SlotIndex SI, 258 const LiveIntervals &LIS, 259 const MachineRegisterInfo &MRI) { 260 LaneBitmask LiveMask; 261 const auto &LI = LIS.getInterval(Reg); 262 if (LI.hasSubRanges()) { 263 for (const auto &S : LI.subranges()) 264 if (S.liveAt(SI)) { 265 LiveMask |= S.LaneMask; 266 assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) || 267 LiveMask == MRI.getMaxLaneMaskForVReg(Reg)); 268 } 269 } else if (LI.liveAt(SI)) { 270 LiveMask = MRI.getMaxLaneMaskForVReg(Reg); 271 } 272 return LiveMask; 273 } 274 275 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, 276 const LiveIntervals &LIS, 277 const MachineRegisterInfo &MRI) { 278 GCNRPTracker::LiveRegSet LiveRegs; 279 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 280 auto Reg = Register::index2VirtReg(I); 281 if (!LIS.hasInterval(Reg)) 282 continue; 283 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); 284 if (LiveMask.any()) 285 LiveRegs[Reg] = LiveMask; 286 } 287 return LiveRegs; 288 } 289 290 void GCNRPTracker::reset(const MachineInstr &MI, 291 const LiveRegSet *LiveRegsCopy, 292 bool After) { 293 const MachineFunction &MF = *MI.getMF(); 294 MRI = &MF.getRegInfo(); 295 if (LiveRegsCopy) { 296 if (&LiveRegs != LiveRegsCopy) 297 LiveRegs = *LiveRegsCopy; 298 } else { 299 LiveRegs = After ? getLiveRegsAfter(MI, LIS) 300 : getLiveRegsBefore(MI, LIS); 301 } 302 303 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs); 304 } 305 306 void GCNUpwardRPTracker::reset(const MachineInstr &MI, 307 const LiveRegSet *LiveRegsCopy) { 308 GCNRPTracker::reset(MI, LiveRegsCopy, true); 309 } 310 311 void GCNUpwardRPTracker::recede(const MachineInstr &MI) { 312 assert(MRI && "call reset first"); 313 314 LastTrackedMI = &MI; 315 316 if (MI.isDebugInstr()) 317 return; 318 319 auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI); 320 321 // calc pressure at the MI (defs + uses) 322 auto AtMIPressure = CurPressure; 323 for (const auto &U : RegUses) { 324 auto LiveMask = LiveRegs[U.RegUnit]; 325 AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI); 326 } 327 // update max pressure 328 MaxPressure = max(AtMIPressure, MaxPressure); 329 330 for (const auto &MO : MI.operands()) { 331 if (!MO.isReg() || !MO.isDef() || 332 !Register::isVirtualRegister(MO.getReg()) || MO.isDead()) 333 continue; 334 335 auto Reg = MO.getReg(); 336 auto I = LiveRegs.find(Reg); 337 if (I == LiveRegs.end()) 338 continue; 339 auto &LiveMask = I->second; 340 auto PrevMask = LiveMask; 341 LiveMask &= ~getDefRegMask(MO, *MRI); 342 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 343 if (LiveMask.none()) 344 LiveRegs.erase(I); 345 } 346 for (const auto &U : RegUses) { 347 auto &LiveMask = LiveRegs[U.RegUnit]; 348 auto PrevMask = LiveMask; 349 LiveMask |= U.LaneMask; 350 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI); 351 } 352 assert(CurPressure == getRegPressure(*MRI, LiveRegs)); 353 } 354 355 bool GCNDownwardRPTracker::reset(const MachineInstr &MI, 356 const LiveRegSet *LiveRegsCopy) { 357 MRI = &MI.getParent()->getParent()->getRegInfo(); 358 LastTrackedMI = nullptr; 359 MBBEnd = MI.getParent()->end(); 360 NextMI = &MI; 361 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 362 if (NextMI == MBBEnd) 363 return false; 364 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false); 365 return true; 366 } 367 368 bool GCNDownwardRPTracker::advanceBeforeNext() { 369 assert(MRI && "call reset first"); 370 371 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 372 if (NextMI == MBBEnd) 373 return false; 374 375 SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex(); 376 assert(SI.isValid()); 377 378 // Remove dead registers or mask bits. 379 for (auto &It : LiveRegs) { 380 const LiveInterval &LI = LIS.getInterval(It.first); 381 if (LI.hasSubRanges()) { 382 for (const auto &S : LI.subranges()) { 383 if (!S.liveAt(SI)) { 384 auto PrevMask = It.second; 385 It.second &= ~S.LaneMask; 386 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 387 } 388 } 389 } else if (!LI.liveAt(SI)) { 390 auto PrevMask = It.second; 391 It.second = LaneBitmask::getNone(); 392 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 393 } 394 if (It.second.none()) 395 LiveRegs.erase(It.first); 396 } 397 398 MaxPressure = max(MaxPressure, CurPressure); 399 400 return true; 401 } 402 403 void GCNDownwardRPTracker::advanceToNext() { 404 LastTrackedMI = &*NextMI++; 405 406 // Add new registers or mask bits. 407 for (const auto &MO : LastTrackedMI->operands()) { 408 if (!MO.isReg() || !MO.isDef()) 409 continue; 410 Register Reg = MO.getReg(); 411 if (!Register::isVirtualRegister(Reg)) 412 continue; 413 auto &LiveMask = LiveRegs[Reg]; 414 auto PrevMask = LiveMask; 415 LiveMask |= getDefRegMask(MO, *MRI); 416 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 417 } 418 419 MaxPressure = max(MaxPressure, CurPressure); 420 } 421 422 bool GCNDownwardRPTracker::advance() { 423 // If we have just called reset live set is actual. 424 if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext())) 425 return false; 426 advanceToNext(); 427 return true; 428 } 429 430 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { 431 while (NextMI != End) 432 if (!advance()) return false; 433 return true; 434 } 435 436 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, 437 MachineBasicBlock::const_iterator End, 438 const LiveRegSet *LiveRegsCopy) { 439 reset(*Begin, LiveRegsCopy); 440 return advance(End); 441 } 442 443 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 444 LLVM_DUMP_METHOD 445 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, 446 const GCNRPTracker::LiveRegSet &TrackedLR, 447 const TargetRegisterInfo *TRI) { 448 for (auto const &P : TrackedLR) { 449 auto I = LISLR.find(P.first); 450 if (I == LISLR.end()) { 451 dbgs() << " " << printReg(P.first, TRI) 452 << ":L" << PrintLaneMask(P.second) 453 << " isn't found in LIS reported set\n"; 454 } 455 else if (I->second != P.second) { 456 dbgs() << " " << printReg(P.first, TRI) 457 << " masks doesn't match: LIS reported " 458 << PrintLaneMask(I->second) 459 << ", tracked " 460 << PrintLaneMask(P.second) 461 << '\n'; 462 } 463 } 464 for (auto const &P : LISLR) { 465 auto I = TrackedLR.find(P.first); 466 if (I == TrackedLR.end()) { 467 dbgs() << " " << printReg(P.first, TRI) 468 << ":L" << PrintLaneMask(P.second) 469 << " isn't found in tracked set\n"; 470 } 471 } 472 } 473 474 bool GCNUpwardRPTracker::isValid() const { 475 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex(); 476 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI); 477 const auto &TrackedLR = LiveRegs; 478 479 if (!isEqual(LISLR, TrackedLR)) { 480 dbgs() << "\nGCNUpwardRPTracker error: Tracked and" 481 " LIS reported livesets mismatch:\n"; 482 printLivesAt(SI, LIS, *MRI); 483 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo()); 484 return false; 485 } 486 487 auto LISPressure = getRegPressure(*MRI, LISLR); 488 if (LISPressure != CurPressure) { 489 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "; 490 CurPressure.print(dbgs()); 491 dbgs() << "LIS rpt: "; 492 LISPressure.print(dbgs()); 493 return false; 494 } 495 return true; 496 } 497 498 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs, 499 const MachineRegisterInfo &MRI) { 500 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 501 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 502 unsigned Reg = Register::index2VirtReg(I); 503 auto It = LiveRegs.find(Reg); 504 if (It != LiveRegs.end() && It->second.any()) 505 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':' 506 << PrintLaneMask(It->second); 507 } 508 OS << '\n'; 509 } 510 #endif 511