//===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-configure the shapes of AMX registers
/// AMX registers need to be configured before use. The shape of an AMX
/// register is encoded in the 1st and 2nd machine operands of the AMX pseudo
/// instructions. The pldtilecfg instruction configures the tile registers and
/// should dominate all AMX instructions. It produces a virtual cfg register
/// that is used by all AMX instructions.
/// This pass finds the common dominator of all AMX instructions and inserts
/// the pldtilecfg instruction there. In addition, the cfg register that
/// pldtilecfg produces is appended as the last operand of each AMX
/// instruction. This scheme models the def-use relationship between the AMX
/// config instruction and the other AMX instructions. Below is an example.
///
///                        ----B1----
///                       /          \
///                      /            \
///                     B2             B3
///     %1:tile = PTILELOADDV    %2:tile = PTILELOADDV
///
/// is transformed to
///
///                           B1
///                %25:tilecfg = PLDTILECFG
///                      /          \
///                     /            \
///   %1:tile = PTILELOADDV %25   %2:tile = PTILELOADDV %25
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TileShapeInfo.h"
#include "llvm/InitializePasses.h"

using namespace llvm;

#define DEBUG_TYPE "tile-pre-config"

namespace {

class X86PreTileConfig : public MachineFunctionPass {
  // Context.
  MachineFunction *MF = nullptr;
  const X86Subtarget *ST = nullptr;
  const TargetRegisterInfo *TRI = nullptr;
  const TargetInstrInfo *TII = nullptr;
  MachineDominatorTree *DomTree = nullptr;
  MachineRegisterInfo *MRI = nullptr;

  MachineInstr *getTileConfigPoint();

public:
  X86PreTileConfig() : MachineFunctionPass(ID) {}

  /// Return the pass name.
  StringRef getPassName() const override {
    return "Tile Register Pre-configure";
  }

  /// X86PreTileConfig analysis usage.
  void getAnalysisUsage(AnalysisUsage &AU) const override;

  /// Insert the tile config instruction and any reloads it needs.
  bool runOnMachineFunction(MachineFunction &mf) override;

  static char ID;
};

} // end anonymous namespace

char X86PreTileConfig::ID = 0;

INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)

void X86PreTileConfig::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<MachineDominatorTree>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
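
// Zero the tile-config stack slot with the widest vector stores the subtarget
// offers (ZMM, YMM, or XMM), then load it with the pseudo LDTILECFG.
// Zero-filled memory is a valid (empty) configuration; the actual palette and
// shape bytes are expected to be written into the slot by a later tile-config
// pass once the shapes are fixed.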
static void buildConfigMI(MachineBasicBlock::iterator MI, int FrameIdx,
                          const TargetInstrInfo *TII, MachineRegisterInfo *MRI,
                          const X86Subtarget *ST) {
  auto *MBB = MI->getParent();

  // Zero stack slot.
  if (ST->hasAVX512()) {
    Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
    BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VPXORDZrr), Zmm)
        .addReg(Zmm, RegState::Undef)
        .addReg(Zmm, RegState::Undef);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSZmr)),
                      FrameIdx)
        .addReg(Zmm);
  } else if (ST->hasAVX2()) {
    Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
    BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VPXORYrr), Ymm)
        .addReg(Ymm, RegState::Undef)
        .addReg(Ymm, RegState::Undef);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSYmr)),
                      FrameIdx)
        .addReg(Ymm);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSYmr)),
                      FrameIdx, 32)
        .addReg(Ymm);
  } else {
    assert(ST->hasSSE2() && "AMX should assume SSE2 enabled");
    Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
    BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::PXORrr), Xmm)
        .addReg(Xmm, RegState::Undef)
        .addReg(Xmm, RegState::Undef);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
                      FrameIdx)
        .addReg(Xmm);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
                      FrameIdx, 16)
        .addReg(Xmm);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
                      FrameIdx, 32)
        .addReg(Xmm);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
                      FrameIdx, 48)
        .addReg(Xmm);
  }

  // Build pseudo ldtilecfg.
  addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::LDTILECFG)),
                    FrameIdx);
}

static ShapeT getShape(const MachineInstr &MI, MachineRegisterInfo *MRI) {
  unsigned Opcode = MI.getOpcode();
  switch (Opcode) {
  default:
    llvm_unreachable("Unexpected machine instruction on tile");
  case X86::PTILELOADDV:
  case X86::PTDPBSSDV:
  case X86::PTDPBSUDV:
  case X86::PTDPBUSDV:
  case X86::PTDPBUUDV:
  case X86::PTILEZEROV:
  case X86::PTDPBF16PSV:
    // Operand 1 is the row, operand 2 is the column.
    MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1));
    MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2));
    ShapeT Shape(&MO1, &MO2, MRI);
    return Shape;
  }
}
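
// Compute where the tile configuration should be inserted: the nearest common
// dominator of all instructions that define a tile register. The registers
// holding the shapes must be defined before that point, so the insertion
// point is placed just past the last shape-defining instruction in the chosen
// block. Returns the instruction before which the config is inserted, or
// nullptr if the function defines no tile registers.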
MachineInstr *X86PreTileConfig::getTileConfigPoint() {
  DenseMap<Register, ShapeT> PhysShapeInfo;
  MachineBasicBlock *MBB = nullptr;
  DenseSet<const MachineInstr *> MIs;
  for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) {
    Register VirtReg = Register::index2VirtReg(i);
    if (MRI->reg_nodbg_empty(VirtReg))
      continue;
    const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
    if (RC.getID() != X86::TILERegClassID)
      continue;

    // Find the common dominator for all MIs that define a tile register.
    for (const MachineOperand &MO : MRI->def_operands(VirtReg)) {
      if (MO.isUndef())
        continue;
      const auto *MI = MO.getParent();
      // Skip PHI or IMPLICIT_DEF instructions.
      // There must be an input tile before a PHI instruction.
      if (MI->isTransient())
        continue;
      if (!MBB)
        MBB = const_cast<MachineBasicBlock *>(MI->getParent());
      MBB = DomTree->findNearestCommonDominator(
          MBB, const_cast<MachineBasicBlock *>(MI->getParent()));

      // Collect the instructions that define the shape.
      ShapeT Shape = getShape(*MI, MRI);
      std::array<MachineOperand *, 2> ShapeMOs = {Shape.getRow(),
                                                  Shape.getCol()};
      for (auto *ShapeMO : ShapeMOs) {
        Register ShapeReg = ShapeMO->getReg();
        for (const MachineOperand &MO : MRI->def_operands(ShapeReg)) {
          const auto *ShapeMI = MO.getParent();
          MIs.insert(ShapeMI);
        }
      }
    }
  }
  if (!MBB)
    return nullptr;
  // This pass runs before PHI elimination, so the function is still in
  // SSA form.
  assert(MRI->isSSA() && "Not SSA form in pre-tile config");
  // Shape defs should dominate the tile config MBB.
  //    def s           s1    s2
  //     / \             \    /
  //    /   \             \  /
  //  conf             s3=phi(s1,s2)
  //                       |
  //                       c
  //
  for (const auto *MI : MIs) {
    const MachineBasicBlock *ShapeMBB = MI->getParent();
    if (DomTree->dominates(ShapeMBB, MBB))
      continue;
    if (MI->isMoveImmediate())
      continue;
    report_fatal_error(MF->getName() + ": Failed to config tile register, "
                                       "please define the shape earlier");
  }

  // ldtilecfg should be inserted after the MI that defines the shape.
  MachineBasicBlock::reverse_instr_iterator I, E;
  for (I = MBB->instr_rbegin(), E = MBB->instr_rend(); I != E; ++I) {
    auto *MI = &*I;
    if (MIs.count(MI) && (!MI->isMoveImmediate()))
      break;
  }
  MachineBasicBlock::iterator MII;
  if (I == E)
    MII = MBB->getFirstNonPHI();
  else {
    MII = MachineBasicBlock::iterator(&*I);
    MII++;
  }
  return &*MII;
}

static bool isAMXInstruction(MachineBasicBlock::iterator MII) {
  switch (MII->getOpcode()) {
  default:
    return false;
  case X86::PTILELOADDV:
  case X86::PTILESTOREDV:
  case X86::PTDPBSSDV:
  case X86::PTDPBSUDV:
  case X86::PTDPBUSDV:
  case X86::PTDPBUUDV:
  case X86::PTILEZEROV:
  case X86::PTDPBF16PSV:
    return true;
  }
}

// Per-block summary used to decide where the tile config must be reloaded:
// whether the block contains AMX instructions, whether a call precedes the
// first AMX instruction, and the last call in the block.
struct BBInfo {
  bool HasAMX = false;
  bool HasCallBeforeAMX = false;
  bool HasAMXBeforeCallInSuccs = false;
  MachineInstr *LastCall = nullptr;

  BBInfo() = default;
  BBInfo(SmallSet<MachineInstr *, 8> &CfgNeedInsert, MachineBasicBlock *MBB,
         MachineInstr *MI = nullptr) {
    MachineBasicBlock::iterator MII = MI ? MI->getIterator() : MBB->begin();
    for (auto E = MBB->end(); MII != E; ++MII) {
      if (isAMXInstruction(MII)) {
        HasAMX = true;
        // A call prior to this AMX instruction may have clobbered the
        // config, so it needs a reload.
        if (LastCall)
          CfgNeedInsert.insert(LastCall);
      } else if (MII->isCall()) {
        LastCall = &*MII;
        if (!HasAMX)
          HasCallBeforeAMX = true;
      }
    }
  }
};
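
// Calls may clobber the tile configuration: a call whose register mask kills
// the TMM registers invalidates the previously loaded config. Walk every
// block reachable from the initial config point to find calls that are
// followed by an AMX instruction in the same block, then propagate the
// "an AMX instruction follows" property backwards through predecessors so
// that calls in earlier blocks are covered as well. An LDTILECFG reload is
// emitted after each collected call that actually clobbers a tile register.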
static void reloadTileConfig(MachineInstr *MI, int FI,
                             const TargetInstrInfo *TII,
                             const TargetRegisterInfo *TRI) {
  SmallSet<MachineInstr *, 8> CfgNeedInsert;
  SmallVector<MachineBasicBlock *, 8> WorkList;
  DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;

  MachineBasicBlock *MBB = MI->getParent();
  BBVisitedInfo[MBB] = BBInfo(CfgNeedInsert, MBB, MI);

  WorkList.push_back(MBB);
  while (!WorkList.empty()) {
    MBB = WorkList.pop_back_val();
    for (auto I = MBB->succ_begin(), E = MBB->succ_end(); I != E; ++I) {
      if (!BBVisitedInfo.count(*I)) {
        BBVisitedInfo[*I] = BBInfo(CfgNeedInsert, *I);
        WorkList.push_back(*I);
      }
    }
  }

  WorkList.clear();
  for (auto I : BBVisitedInfo) {
    WorkList.push_back(I.first);
    while (!WorkList.empty()) {
      MBB = WorkList.pop_back_val();
      if (BBVisitedInfo[MBB].HasCallBeforeAMX ||
          (!BBVisitedInfo[MBB].HasAMX &&
           !BBVisitedInfo[MBB].HasAMXBeforeCallInSuccs))
        continue;
      for (auto I = MBB->pred_begin(), E = MBB->pred_end(); I != E; ++I) {
        if (!BBVisitedInfo.count(*I) ||
            BBVisitedInfo[*I].HasAMXBeforeCallInSuccs)
          continue;
        if (BBVisitedInfo[*I].LastCall)
          CfgNeedInsert.insert(BBVisitedInfo[*I].LastCall);
        BBVisitedInfo[*I].HasAMXBeforeCallInSuccs = true;
        WorkList.push_back(*I);
      }
    }
  }

  for (auto *I : CfgNeedInsert) {
    BitVector UsableRegs(TRI->getNumRegs());
    const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
    for (unsigned J = 0; J < RC->getNumRegs(); J++)
      UsableRegs.set(X86::TMM0 + J);
    for (MachineOperand &CallMO : I->operands()) {
      if (CallMO.isRegMask())
        UsableRegs.clearBitsInMask(CallMO.getRegMask());
    }
    // Only reload the config if the call actually clobbers a tile register.
    if (!UsableRegs.none())
      addFrameReference(BuildMI(*I->getParent(), ++I->getIterator(),
                                DebugLoc(), TII->get(X86::LDTILECFG)),
                        FI);
  }
}

bool X86PreTileConfig::runOnMachineFunction(MachineFunction &mf) {
  MF = &mf;
  MRI = &mf.getRegInfo();
  ST = &mf.getSubtarget<X86Subtarget>();
  TRI = ST->getRegisterInfo();
  TII = mf.getSubtarget().getInstrInfo();
  DomTree = &getAnalysis<MachineDominatorTree>();

  MachineInstr *MI = getTileConfigPoint();
  if (!MI)
    return false;
  unsigned Size = ST->getTileConfigSize();
  Align Alignment = ST->getTileConfigAlignment();
  int SS = mf.getFrameInfo().CreateStackObject(Size, Alignment, false);
  buildConfigMI(MI, SS, TII, MRI, ST);
  reloadTileConfig(MI, SS, TII, TRI);
  return true;
}

FunctionPass *llvm::createX86PreTileConfigPass() {
  return new X86PreTileConfig();
}