1 //===- GCNRegPressure.cpp -------------------------------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 10 #include "GCNRegPressure.h" 11 #include "AMDGPUSubtarget.h" 12 #include "SIRegisterInfo.h" 13 #include "llvm/ADT/SmallVector.h" 14 #include "llvm/CodeGen/LiveInterval.h" 15 #include "llvm/CodeGen/LiveIntervals.h" 16 #include "llvm/CodeGen/MachineInstr.h" 17 #include "llvm/CodeGen/MachineOperand.h" 18 #include "llvm/CodeGen/MachineRegisterInfo.h" 19 #include "llvm/CodeGen/RegisterPressure.h" 20 #include "llvm/CodeGen/SlotIndexes.h" 21 #include "llvm/CodeGen/TargetRegisterInfo.h" 22 #include "llvm/Config/llvm-config.h" 23 #include "llvm/MC/LaneBitmask.h" 24 #include "llvm/Support/Compiler.h" 25 #include "llvm/Support/Debug.h" 26 #include "llvm/Support/ErrorHandling.h" 27 #include "llvm/Support/raw_ostream.h" 28 #include <algorithm> 29 #include <cassert> 30 31 using namespace llvm; 32 33 #define DEBUG_TYPE "machine-scheduler" 34 35 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 36 LLVM_DUMP_METHOD 37 void llvm::printLivesAt(SlotIndex SI, 38 const LiveIntervals &LIS, 39 const MachineRegisterInfo &MRI) { 40 dbgs() << "Live regs at " << SI << ": " 41 << *LIS.getInstructionFromIndex(SI); 42 unsigned Num = 0; 43 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 44 const unsigned Reg = TargetRegisterInfo::index2VirtReg(I); 45 if (!LIS.hasInterval(Reg)) 46 continue; 47 const auto &LI = LIS.getInterval(Reg); 48 if (LI.hasSubRanges()) { 49 bool firstTime = true; 50 for (const auto &S : LI.subranges()) { 51 if (!S.liveAt(SI)) continue; 52 if (firstTime) { 53 dbgs() << " " << printReg(Reg, MRI.getTargetRegisterInfo()) 54 << '\n'; 55 firstTime = false; 56 } 57 dbgs() << " " << S << '\n'; 58 ++Num; 59 } 60 } else if (LI.liveAt(SI)) { 61 dbgs() << " " << LI << '\n'; 62 ++Num; 63 } 64 } 65 if (!Num) dbgs() << " <none>\n"; 66 } 67 68 static bool isEqual(const GCNRPTracker::LiveRegSet &S1, 69 const GCNRPTracker::LiveRegSet &S2) { 70 if (S1.size() != S2.size()) 71 return false; 72 73 for (const auto &P : S1) { 74 auto I = S2.find(P.first); 75 if (I == S2.end() || I->second != P.second) 76 return false; 77 } 78 return true; 79 } 80 #endif 81 82 /////////////////////////////////////////////////////////////////////////////// 83 // GCNRegPressure 84 85 unsigned GCNRegPressure::getRegKind(unsigned Reg, 86 const MachineRegisterInfo &MRI) { 87 assert(TargetRegisterInfo::isVirtualRegister(Reg)); 88 const auto RC = MRI.getRegClass(Reg); 89 auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo()); 90 return STI->isSGPRClass(RC) ? 91 (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) : 92 (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE); 93 } 94 95 void GCNRegPressure::inc(unsigned Reg, 96 LaneBitmask PrevMask, 97 LaneBitmask NewMask, 98 const MachineRegisterInfo &MRI) { 99 if (NewMask == PrevMask) 100 return; 101 102 int Sign = 1; 103 if (NewMask < PrevMask) { 104 std::swap(NewMask, PrevMask); 105 Sign = -1; 106 } 107 #ifndef NDEBUG 108 const auto MaxMask = MRI.getMaxLaneMaskForVReg(Reg); 109 #endif 110 switch (auto Kind = getRegKind(Reg, MRI)) { 111 case SGPR32: 112 case VGPR32: 113 assert(PrevMask.none() && NewMask == MaxMask); 114 Value[Kind] += Sign; 115 break; 116 117 case SGPR_TUPLE: 118 case VGPR_TUPLE: 119 assert(NewMask < MaxMask || NewMask == MaxMask); 120 assert(PrevMask < NewMask); 121 122 Value[Kind == SGPR_TUPLE ? SGPR32 : VGPR32] += 123 Sign * (~PrevMask & NewMask).getNumLanes(); 124 125 if (PrevMask.none()) { 126 assert(NewMask.any()); 127 Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight(); 128 } 129 break; 130 131 default: llvm_unreachable("Unknown register kind"); 132 } 133 } 134 135 bool GCNRegPressure::less(const SISubtarget &ST, 136 const GCNRegPressure& O, 137 unsigned MaxOccupancy) const { 138 const auto SGPROcc = std::min(MaxOccupancy, 139 ST.getOccupancyWithNumSGPRs(getSGPRNum())); 140 const auto VGPROcc = std::min(MaxOccupancy, 141 ST.getOccupancyWithNumVGPRs(getVGPRNum())); 142 const auto OtherSGPROcc = std::min(MaxOccupancy, 143 ST.getOccupancyWithNumSGPRs(O.getSGPRNum())); 144 const auto OtherVGPROcc = std::min(MaxOccupancy, 145 ST.getOccupancyWithNumVGPRs(O.getVGPRNum())); 146 147 const auto Occ = std::min(SGPROcc, VGPROcc); 148 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc); 149 if (Occ != OtherOcc) 150 return Occ > OtherOcc; 151 152 bool SGPRImportant = SGPROcc < VGPROcc; 153 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; 154 155 // if both pressures disagree on what is more important compare vgprs 156 if (SGPRImportant != OtherSGPRImportant) { 157 SGPRImportant = false; 158 } 159 160 // compare large regs pressure 161 bool SGPRFirst = SGPRImportant; 162 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { 163 if (SGPRFirst) { 164 auto SW = getSGPRTuplesWeight(); 165 auto OtherSW = O.getSGPRTuplesWeight(); 166 if (SW != OtherSW) 167 return SW < OtherSW; 168 } else { 169 auto VW = getVGPRTuplesWeight(); 170 auto OtherVW = O.getVGPRTuplesWeight(); 171 if (VW != OtherVW) 172 return VW < OtherVW; 173 } 174 } 175 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): 176 (getVGPRNum() < O.getVGPRNum()); 177 } 178 179 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 180 LLVM_DUMP_METHOD 181 void GCNRegPressure::print(raw_ostream &OS, const SISubtarget *ST) const { 182 OS << "VGPRs: " << getVGPRNum(); 183 if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')'; 184 OS << ", SGPRs: " << getSGPRNum(); 185 if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')'; 186 OS << ", LVGPR WT: " << getVGPRTuplesWeight() 187 << ", LSGPR WT: " << getSGPRTuplesWeight(); 188 if (ST) OS << " -> Occ: " << getOccupancy(*ST); 189 OS << '\n'; 190 } 191 #endif 192 193 static LaneBitmask getDefRegMask(const MachineOperand &MO, 194 const MachineRegisterInfo &MRI) { 195 assert(MO.isDef() && MO.isReg() && 196 TargetRegisterInfo::isVirtualRegister(MO.getReg())); 197 198 // We don't rely on read-undef flag because in case of tentative schedule 199 // tracking it isn't set correctly yet. This works correctly however since 200 // use mask has been tracked before using LIS. 201 return MO.getSubReg() == 0 ? 202 MRI.getMaxLaneMaskForVReg(MO.getReg()) : 203 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg()); 204 } 205 206 static LaneBitmask getUsedRegMask(const MachineOperand &MO, 207 const MachineRegisterInfo &MRI, 208 const LiveIntervals &LIS) { 209 assert(MO.isUse() && MO.isReg() && 210 TargetRegisterInfo::isVirtualRegister(MO.getReg())); 211 212 if (auto SubReg = MO.getSubReg()) 213 return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg); 214 215 auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg()); 216 if (MaxMask == LaneBitmask::getLane(0)) // cannot have subregs 217 return MaxMask; 218 219 // For a tentative schedule LIS isn't updated yet but livemask should remain 220 // the same on any schedule. Subreg defs can be reordered but they all must 221 // dominate uses anyway. 222 auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex(); 223 return getLiveLaneMask(MO.getReg(), SI, LIS, MRI); 224 } 225 226 static SmallVector<RegisterMaskPair, 8> 227 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS, 228 const MachineRegisterInfo &MRI) { 229 SmallVector<RegisterMaskPair, 8> Res; 230 for (const auto &MO : MI.operands()) { 231 if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg())) 232 continue; 233 if (!MO.isUse() || !MO.readsReg()) 234 continue; 235 236 auto const UsedMask = getUsedRegMask(MO, MRI, LIS); 237 238 auto Reg = MO.getReg(); 239 auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) { 240 return RM.RegUnit == Reg; 241 }); 242 if (I != Res.end()) 243 I->LaneMask |= UsedMask; 244 else 245 Res.push_back(RegisterMaskPair(Reg, UsedMask)); 246 } 247 return Res; 248 } 249 250 /////////////////////////////////////////////////////////////////////////////// 251 // GCNRPTracker 252 253 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, 254 SlotIndex SI, 255 const LiveIntervals &LIS, 256 const MachineRegisterInfo &MRI) { 257 LaneBitmask LiveMask; 258 const auto &LI = LIS.getInterval(Reg); 259 if (LI.hasSubRanges()) { 260 for (const auto &S : LI.subranges()) 261 if (S.liveAt(SI)) { 262 LiveMask |= S.LaneMask; 263 assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) || 264 LiveMask == MRI.getMaxLaneMaskForVReg(Reg)); 265 } 266 } else if (LI.liveAt(SI)) { 267 LiveMask = MRI.getMaxLaneMaskForVReg(Reg); 268 } 269 return LiveMask; 270 } 271 272 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, 273 const LiveIntervals &LIS, 274 const MachineRegisterInfo &MRI) { 275 GCNRPTracker::LiveRegSet LiveRegs; 276 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 277 auto Reg = TargetRegisterInfo::index2VirtReg(I); 278 if (!LIS.hasInterval(Reg)) 279 continue; 280 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); 281 if (LiveMask.any()) 282 LiveRegs[Reg] = LiveMask; 283 } 284 return LiveRegs; 285 } 286 287 void GCNUpwardRPTracker::reset(const MachineInstr &MI, 288 const LiveRegSet *LiveRegsCopy) { 289 MRI = &MI.getParent()->getParent()->getRegInfo(); 290 if (LiveRegsCopy) { 291 if (&LiveRegs != LiveRegsCopy) 292 LiveRegs = *LiveRegsCopy; 293 } else { 294 LiveRegs = getLiveRegsAfter(MI, LIS); 295 } 296 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs); 297 } 298 299 void GCNUpwardRPTracker::recede(const MachineInstr &MI) { 300 assert(MRI && "call reset first"); 301 302 LastTrackedMI = &MI; 303 304 if (MI.isDebugInstr()) 305 return; 306 307 auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI); 308 309 // calc pressure at the MI (defs + uses) 310 auto AtMIPressure = CurPressure; 311 for (const auto &U : RegUses) { 312 auto LiveMask = LiveRegs[U.RegUnit]; 313 AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI); 314 } 315 // update max pressure 316 MaxPressure = max(AtMIPressure, MaxPressure); 317 318 for (const auto &MO : MI.defs()) { 319 if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg()) || 320 MO.isDead()) 321 continue; 322 323 auto Reg = MO.getReg(); 324 auto I = LiveRegs.find(Reg); 325 if (I == LiveRegs.end()) 326 continue; 327 auto &LiveMask = I->second; 328 auto PrevMask = LiveMask; 329 LiveMask &= ~getDefRegMask(MO, *MRI); 330 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 331 if (LiveMask.none()) 332 LiveRegs.erase(I); 333 } 334 for (const auto &U : RegUses) { 335 auto &LiveMask = LiveRegs[U.RegUnit]; 336 auto PrevMask = LiveMask; 337 LiveMask |= U.LaneMask; 338 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI); 339 } 340 assert(CurPressure == getRegPressure(*MRI, LiveRegs)); 341 } 342 343 bool GCNDownwardRPTracker::reset(const MachineInstr &MI, 344 const LiveRegSet *LiveRegsCopy) { 345 MRI = &MI.getParent()->getParent()->getRegInfo(); 346 LastTrackedMI = nullptr; 347 MBBEnd = MI.getParent()->end(); 348 NextMI = &MI; 349 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 350 if (NextMI == MBBEnd) 351 return false; 352 if (LiveRegsCopy) { 353 if (&LiveRegs != LiveRegsCopy) 354 LiveRegs = *LiveRegsCopy; 355 } else { 356 LiveRegs = getLiveRegsBefore(*NextMI, LIS); 357 } 358 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs); 359 return true; 360 } 361 362 bool GCNDownwardRPTracker::advanceBeforeNext() { 363 assert(MRI && "call reset first"); 364 365 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 366 if (NextMI == MBBEnd) 367 return false; 368 369 SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex(); 370 assert(SI.isValid()); 371 372 // Remove dead registers or mask bits. 373 for (auto &It : LiveRegs) { 374 const LiveInterval &LI = LIS.getInterval(It.first); 375 if (LI.hasSubRanges()) { 376 for (const auto &S : LI.subranges()) { 377 if (!S.liveAt(SI)) { 378 auto PrevMask = It.second; 379 It.second &= ~S.LaneMask; 380 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 381 } 382 } 383 } else if (!LI.liveAt(SI)) { 384 auto PrevMask = It.second; 385 It.second = LaneBitmask::getNone(); 386 CurPressure.inc(It.first, PrevMask, It.second, *MRI); 387 } 388 if (It.second.none()) 389 LiveRegs.erase(It.first); 390 } 391 392 MaxPressure = max(MaxPressure, CurPressure); 393 394 return true; 395 } 396 397 void GCNDownwardRPTracker::advanceToNext() { 398 LastTrackedMI = &*NextMI++; 399 400 // Add new registers or mask bits. 401 for (const auto &MO : LastTrackedMI->defs()) { 402 if (!MO.isReg()) 403 continue; 404 unsigned Reg = MO.getReg(); 405 if (!TargetRegisterInfo::isVirtualRegister(Reg)) 406 continue; 407 auto &LiveMask = LiveRegs[Reg]; 408 auto PrevMask = LiveMask; 409 LiveMask |= getDefRegMask(MO, *MRI); 410 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 411 } 412 413 MaxPressure = max(MaxPressure, CurPressure); 414 } 415 416 bool GCNDownwardRPTracker::advance() { 417 // If we have just called reset live set is actual. 418 if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext())) 419 return false; 420 advanceToNext(); 421 return true; 422 } 423 424 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { 425 while (NextMI != End) 426 if (!advance()) return false; 427 return true; 428 } 429 430 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, 431 MachineBasicBlock::const_iterator End, 432 const LiveRegSet *LiveRegsCopy) { 433 reset(*Begin, LiveRegsCopy); 434 return advance(End); 435 } 436 437 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) 438 LLVM_DUMP_METHOD 439 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, 440 const GCNRPTracker::LiveRegSet &TrackedLR, 441 const TargetRegisterInfo *TRI) { 442 for (auto const &P : TrackedLR) { 443 auto I = LISLR.find(P.first); 444 if (I == LISLR.end()) { 445 dbgs() << " " << printReg(P.first, TRI) 446 << ":L" << PrintLaneMask(P.second) 447 << " isn't found in LIS reported set\n"; 448 } 449 else if (I->second != P.second) { 450 dbgs() << " " << printReg(P.first, TRI) 451 << " masks doesn't match: LIS reported " 452 << PrintLaneMask(I->second) 453 << ", tracked " 454 << PrintLaneMask(P.second) 455 << '\n'; 456 } 457 } 458 for (auto const &P : LISLR) { 459 auto I = TrackedLR.find(P.first); 460 if (I == TrackedLR.end()) { 461 dbgs() << " " << printReg(P.first, TRI) 462 << ":L" << PrintLaneMask(P.second) 463 << " isn't found in tracked set\n"; 464 } 465 } 466 } 467 468 bool GCNUpwardRPTracker::isValid() const { 469 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex(); 470 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI); 471 const auto &TrackedLR = LiveRegs; 472 473 if (!isEqual(LISLR, TrackedLR)) { 474 dbgs() << "\nGCNUpwardRPTracker error: Tracked and" 475 " LIS reported livesets mismatch:\n"; 476 printLivesAt(SI, LIS, *MRI); 477 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo()); 478 return false; 479 } 480 481 auto LISPressure = getRegPressure(*MRI, LISLR); 482 if (LISPressure != CurPressure) { 483 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "; 484 CurPressure.print(dbgs()); 485 dbgs() << "LIS rpt: "; 486 LISPressure.print(dbgs()); 487 return false; 488 } 489 return true; 490 } 491 492 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs, 493 const MachineRegisterInfo &MRI) { 494 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 495 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 496 unsigned Reg = TargetRegisterInfo::index2VirtReg(I); 497 auto It = LiveRegs.find(Reg); 498 if (It != LiveRegs.end() && It->second.any()) 499 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':' 500 << PrintLaneMask(It->second); 501 } 502 OS << '\n'; 503 } 504 #endif 505