//===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-config the shapes of AMX registers
/// AMX registers need to be configured before use. The shapes of an AMX
/// register are encoded in the 1st and 2nd machine operands of AMX pseudo
/// instructions.
///
/// The instruction ldtilecfg is used to config the shapes. All variable shape
/// defs must reach it. ldtilecfg will be inserted more than once if we cannot
/// find a single dominating point for all AMX instructions.
///
/// The config register is caller-saved according to the ABI. We need to insert
/// ldtilecfg again after a call instruction if the callee clobbers any AMX
/// registers.
///
/// This pass calculates all points where ldtilecfg needs to be inserted and
/// inserts them. It reports an error if the reachability conditions aren't met.
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/InitializePasses.h"

using namespace llvm;

#define DEBUG_TYPE "tile-pre-config"

// MIRef ordering (Pos) is block-local, so comparing refs from different basic
// blocks is meaningless; a null MBB (default-constructed MIRef) compares with
// anything.
#define ASSERT_VALID_COMPARE                                                   \
  assert((!MBB || !RHS.MBB || MBB == RHS.MBB) &&                               \
         "Cannot compare between different BBs");
// Fatal error issued when no legal ldtilecfg insert point exists, i.e. a
// shape def cannot be placed before all its AMX users.
#define REPORT_CONFIG_FAIL                                                     \
  report_fatal_error(                                                          \
      MF.getName() +                                                           \
      ": Failed to config tile register, please define the shape earlier");

namespace {

/// A lightweight reference to a position inside a machine basic block, used
/// both for existing instructions and for points where an instruction will be
/// inserted.
struct MIRef {
  MachineInstr *MI = nullptr;
  MachineBasicBlock *MBB = nullptr;
  // A virtual position for instruction that will be inserted after MI.
  // Pos only orders MIRefs within the same MBB (see ASSERT_VALID_COMPARE).
  size_t Pos = 0;
  MIRef() = default;
  // Refer to the point right behind the last leading PHI of MBB (or the block
  // start when there is no PHI); MI ends up as the last PHI, if any.
  MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
    for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
         ++I, ++Pos)
      MI = &*I;
  }
  // Refer to MI itself; Pos is the index of the iterator *after* MI, so an
  // insertion at this ref lands right behind MI.
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
      : MI(MI), MBB(MBB),
        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
      : MI(MI), MBB(MBB), Pos(Pos) {}
  // A default-constructed (empty) MIRef converts to false.
  operator bool() const { return MBB != nullptr; }
  // Note: equality intentionally ignores Pos; (MI, MBB) identifies the ref.
  bool operator==(const MIRef &RHS) const {
    return MI == RHS.MI && MBB == RHS.MBB;
  }
  bool operator<(const MIRef &RHS) const {
    ASSERT_VALID_COMPARE;
    return Pos < RHS.Pos;
  }
  bool operator>(const MIRef &RHS) const {
    ASSERT_VALID_COMPARE;
    return Pos > RHS.Pos;
  }
};

/// Per-basic-block facts collected while scanning the function.
struct BBInfo {
  MIRef FirstAMX;  // First AMX instruction in the block, if any.
  MIRef LastCall;  // Last AMX-clobbering call in the block, if any.
  MIRef LastShape; // Last shape-defining instruction in the block, if any.
  // Set on (transitive) predecessors of shape-def blocks: inserting
  // ldtilecfg there would execute before a shape def.
  bool TileCfgForbidden = false;
  // Tile config data must be valid on entry to this block.
  bool NeedTileCfgLiveIn = false;
};

class X86PreTileConfig : public MachineFunctionPass {
  MachineRegisterInfo *MRI;
  const MachineLoopInfo *MLI;
  SmallSet<MachineInstr *, 8> DefVisited;    // Shape defs already processed.
  SmallSet<MachineBasicBlock *, 8> ShapeBBs; // Blocks containing shape defs.
  DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;

  /// Check if the callee will clobber AMX registers.
  /// \p UsableRegs is deliberately taken by value (the set of AMX physical
  /// registers); it is mutated locally without affecting the caller's copy.
  bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
    auto Iter = llvm::find_if(
        MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
    if (Iter == MI.operands_end())
      return false;
    // Drop the registers preserved by the callee (regmask set bits are the
    // preserved ones); any bit still set is an AMX register the call clobbers.
    UsableRegs.clearBitsInMask(Iter->getRegMask());
    return !UsableRegs.none();
  }

  /// Check if MI is AMX pseudo instruction.
  /// Side effect: for a matching instruction, records its shape defs via
  /// collectShapeInfo().
  bool isAMXInstruction(MachineInstr &MI) {
    if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
      return false;
    MachineOperand &MO = MI.getOperand(0);
    // We can simply check if it is AMX instruction by its def.
    // But we should exclude old API which uses physical registers.
    if (MO.isReg() && MO.getReg().isVirtual() &&
        MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) {
      collectShapeInfo(MI);
      return true;
    }
    // PTILESTOREDV is the only exception that doesn't def a AMX register.
    return MI.getOpcode() == X86::PTILESTOREDV;
  }

  /// Check if it is an edge from loop bottom to loop head.
  bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
    return MLI->isLoopHeader(Header) &&
           MLI->getLoopFor(Header)->getBottomBlock() == Bottom;
  }

  /// Collect the shape def information for later use.
  void collectShapeInfo(MachineInstr &MI);

public:
  X86PreTileConfig() : MachineFunctionPass(ID) {}

  /// Return the pass name.
  StringRef getPassName() const override {
    return "Tile Register Pre-configure";
  }

  /// X86PreTileConfig analysis usage.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
    AU.addRequired<MachineLoopInfo>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  /// Clear MF related structures.
  void releaseMemory() override {
    ShapeBBs.clear();
    DefVisited.clear();
    BBVisitedInfo.clear();
  }

  /// Perform ldtilecfg instructions inserting.
155 bool runOnMachineFunction(MachineFunction &MF) override; 156 157 static char ID; 158 }; 159 160 } // end anonymous namespace 161 162 char X86PreTileConfig::ID = 0; 163 164 INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig", 165 "Tile Register Pre-configure", false, false) 166 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo) 167 INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig", 168 "Tile Register Pre-configure", false, false) 169 170 void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) { 171 auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) { 172 MIRef MIR(MI, MBB); 173 if (BBVisitedInfo[MBB].LastShape < MIR) 174 BBVisitedInfo[MBB].LastShape = MIR; 175 ShapeBBs.insert(MBB); 176 }; 177 178 SmallVector<Register, 8> WorkList( 179 {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()}); 180 while (!WorkList.empty()) { 181 Register R = WorkList.pop_back_val(); 182 MachineInstr *DefMI = MRI->getVRegDef(R); 183 MachineBasicBlock *DefMBB = DefMI->getParent(); 184 if (!DefMI || DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second) 185 continue; 186 if (DefMI->isPHI()) { 187 for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2) 188 if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB())) 189 RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def. 190 else 191 WorkList.push_back(DefMI->getOperand(I).getReg()); 192 } else { 193 RecordShape(DefMI, DefMBB); 194 } 195 } 196 } 197 198 bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) { 199 const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>(); 200 const TargetInstrInfo *TII = ST.getInstrInfo(); 201 const TargetRegisterInfo *TRI = ST.getRegisterInfo(); 202 const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID); 203 204 BitVector AMXRegs(TRI->getNumRegs()); 205 for (unsigned I = 0; I < RC->getNumRegs(); I++) 206 AMXRegs.set(X86::TMM0 + I); 207 208 // Iterate MF to collect information. 
  MRI = &MF.getRegInfo();
  MLI = &getAnalysis<MachineLoopInfo>();
  SmallSet<MIRef, 8> CfgNeedInsert;
  SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
  for (auto &MBB : MF) {
    size_t Pos = 0;
    for (auto &MI : MBB) {
      ++Pos;
      if (isAMXInstruction(MI)) {
        // If there's call before the AMX, we need to reload tile config.
        if (BBVisitedInfo[&MBB].LastCall)
          CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
        else // Otherwise, we need tile config to live in this BB.
          BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
        // Always record the first AMX in case there's shape def after it.
        if (!BBVisitedInfo[&MBB].FirstAMX)
          BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
      } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
        // Record the call if the callee clobbers any AMX register.
        BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
      }
    }
    if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
      // The entry block has no predecessors, so the config must be inserted
      // in the block itself (right after its PHIs).
      if (&MBB == &MF.front())
        CfgNeedInsert.insert(MIRef(&MBB));
      else
        CfgLiveInBBs.push_back(&MBB);
    }
  }

  // Propagate NeedTileCfgLiveIn backwards through predecessors. Propagation
  // stops at a block with an AMX-clobbering call: the reload point is right
  // after that call instead.
  while (!CfgLiveInBBs.empty()) {
    MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (BBVisitedInfo[Pred].LastCall) {
        CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
      } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
        BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
        if (Pred == &MF.front())
          CfgNeedInsert.insert(MIRef(Pred));
        else
          CfgLiveInBBs.push_back(Pred);
      }
    }
  }

  // There's no AMX instruction if we didn't find a tile config live in point.
  if (CfgNeedInsert.empty())
    return false;

  // Avoid inserting ldtilecfg before any shape def: mark every (transitive)
  // predecessor of a shape-def block as forbidden. Loop back edges are
  // skipped so a shape def at the loop bottom doesn't poison its own header.
  SmallVector<MachineBasicBlock *, 8> WorkList(
      make_range(ShapeBBs.begin(), ShapeBBs.end()));
  while (!WorkList.empty()) {
    MachineBasicBlock *MBB = WorkList.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) {
        BBVisitedInfo[Pred].TileCfgForbidden = true;
        WorkList.push_back(Pred);
      }
    }
  }

  DebugLoc DL;
  SmallSet<MIRef, 8> VisitedOrInserted;
  // Stack slot holding the tile configure data that ldtilecfg reads.
  int SS = MF.getFrameInfo().CreateStackObject(
      ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);

  // Try to insert for the tile config live in points.
  for (auto I : CfgNeedInsert) {
    SmallSet<MIRef, 8> InsertPoints;
    SmallVector<MIRef, 8> WorkList({I});
    while (!WorkList.empty()) {
      MIRef I = WorkList.pop_back_val();
      if (!VisitedOrInserted.count(I)) {
        if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
          // If the BB is all shapes reachable, stop sink and try to insert.
          InsertPoints.insert(I);
        } else {
          // Avoid the BB to be multi visited.
          VisitedOrInserted.insert(I);
          // We cannot sink it across any AMX instruction.
          if (BBVisitedInfo[I.MBB].FirstAMX)
            REPORT_CONFIG_FAIL;
          // Sink the inserting point along the chain with NeedTileCfgLiveIn =
          // true when MBB isn't all shapes reachable.
          for (auto *Succ : I.MBB->successors())
            if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
              WorkList.push_back(MIRef(Succ));
        }
      }
    }

    // A given point might be forked (into several insert points) when the
    // shape conditions are not met at the original point.
    for (MIRef I : InsertPoints) {
      // Even if MBB is all shapes reachable, we still need to check if
      // there's an AMX instruction that intersects with shape defs in the
      // same MBB.
      if (BBVisitedInfo[I.MBB].FirstAMX &&
          BBVisitedInfo[I.MBB].FirstAMX < BBVisitedInfo[I.MBB].LastShape)
        REPORT_CONFIG_FAIL;
      // Make sure we insert ldtilecfg after the last shape def in MBB.
      if (I < BBVisitedInfo[I.MBB].LastShape)
        I = BBVisitedInfo[I.MBB].LastShape;
      // There're chances the MBB is sunk more than once. Record it to avoid
      // multi insert.
      if (VisitedOrInserted.insert(I).second) {
        // Insert right behind I.MI (or at block start for an empty ref).
        auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin();
        addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::LDTILECFG)),
                          SS);
      }
    }
  }

  // Zero the config stack slot in the entry block with the widest available
  // vector stores (64 bytes total in each branch: 1x64 / 2x32 / 4x16).
  MachineBasicBlock &MBB = MF.front();
  MachineInstr *MI = &*MBB.begin();
  if (ST.hasAVX512()) {
    Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::VPXORDZrr), Zmm)
        .addReg(Zmm, RegState::Undef)
        .addReg(Zmm, RegState::Undef);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS)
        .addReg(Zmm);
  } else if (ST.hasAVX2()) {
    Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::VPXORYrr), Ymm)
        .addReg(Ymm, RegState::Undef)
        .addReg(Ymm, RegState::Undef);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS)
        .addReg(Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32)
        .addReg(Ymm);
  } else {
    assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
    Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::PXORrr), Xmm)
        .addReg(Xmm, RegState::Undef)
        .addReg(Xmm, RegState::Undef);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 16)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 32)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48)
        .addReg(Xmm);
  }
  // Fill in the palette first.
  addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);

  return true;
}

FunctionPass *llvm::createX86PreTileConfigPass() {
  return new X86PreTileConfig();
}