//===-- X86PreTileConfig.cpp - Tile Register Configure --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-configure the shapes of AMX registers.
/// AMX registers need to be configured before use. The shape of an AMX
/// register is encoded in the 1st and 2nd machine operands of the AMX pseudo
/// instructions. The pldtilecfg instruction configures the tile registers and
/// therefore should dominate all AMX instructions. It produces a virtual cfg
/// register that is used by all AMX instructions.
/// This pass finds the common dominator of all AMX instructions and inserts
/// the pldtilecfg instruction there. In addition, the cfg register that
/// pldtilecfg produces is appended as the last operand of each AMX
/// instruction. We use this scheme to model the def-use relationship between
/// the AMX config instruction and the other AMX instructions. Below is an
/// example.
///
///                   ----B1----
///                  /          \
///                 /            \
///                B2             B3
///   %1:tile = PTILELOADDV       %2:tile = PTILELOADDV
///
/// is transformed to
///
///                       B1
///          %25:tilecfg = PLDTILECFG
///                      /          \
///                     /            \
///   %1:tile = PTILELOADDV %25      %2:tile = PTILELOADDV %25
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TileShapeInfo.h"
#include "llvm/InitializePasses.h"

using namespace llvm;

#define DEBUG_TYPE "tile-pre-config"

namespace {

class X86PreTileConfig : public MachineFunctionPass {
  // Context.
  MachineFunction *MF = nullptr;
  const X86Subtarget *ST = nullptr;
  const TargetRegisterInfo *TRI = nullptr;
  const TargetInstrInfo *TII = nullptr;
  MachineDominatorTree *DomTree = nullptr;
  MachineRegisterInfo *MRI = nullptr;

  MachineInstr *getTileConfigPoint();

public:
  X86PreTileConfig() : MachineFunctionPass(ID) {}

  /// Return the pass name.
  StringRef getPassName() const override {
    return "Tile Register Pre-configure";
  }

  /// X86PreTileConfig analysis usage.
  void getAnalysisUsage(AnalysisUsage &AU) const override;

  /// Insert the tile config instruction and any reloads it requires.
  bool runOnMachineFunction(MachineFunction &mf) override;

  static char ID;
};

} // end anonymous namespace

char X86PreTileConfig::ID = 0;

INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Configure", false, false)

void X86PreTileConfig::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<MachineDominatorTree>();
  MachineFunctionPass::getAnalysisUsage(AU);
}

static void buildConfigMI(MachineBasicBlock::iterator MI, int FrameIdx,
                          const TargetInstrInfo *TII, MachineRegisterInfo *MRI,
                          const X86Subtarget *ST) {
  auto *MBB = MI->getParent();

  // Zero the 64-byte stack slot that holds the tile configuration, using the
  // widest available vector stores.
  if (ST->hasAVX512()) {
    // One 64-byte store covers the whole slot.
    Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
    BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VPXORDZrr), Zmm)
        .addReg(Zmm, RegState::Undef)
        .addReg(Zmm, RegState::Undef);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSZmr)),
                      FrameIdx)
        .addReg(Zmm);
  } else if (ST->hasAVX2()) {
    // Two 32-byte stores.
    Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
    BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VPXORYrr), Ymm)
        .addReg(Ymm, RegState::Undef)
        .addReg(Ymm, RegState::Undef);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSYmr)),
                      FrameIdx)
        .addReg(Ymm);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSYmr)),
                      FrameIdx, 32)
        .addReg(Ymm);
  } else {
    assert(ST->hasSSE2() && "AMX should assume SSE2 enabled");
    // Four 16-byte stores.
    Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
    BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::PXORrr), Xmm)
        .addReg(Xmm, RegState::Undef)
        .addReg(Xmm, RegState::Undef);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
                      FrameIdx)
        .addReg(Xmm);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
                      FrameIdx, 16)
        .addReg(Xmm);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
                      FrameIdx, 32)
        .addReg(Xmm);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
                      FrameIdx, 48)
        .addReg(Xmm);
  }

  // Build the pseudo ldtilecfg that loads the configuration from the slot.
  addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::LDTILECFG)),
                    FrameIdx);
}
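// For illustration, on an AVX-512 target the sequence emitted at the config
// point looks roughly like the following MIR (register and stack-slot numbers
// are hypothetical):
//
//   %0:vr512 = VPXORDZrr undef %0, undef %0        ; zero a zmm register
//   VMOVUPSZmr %stack.0, 1, $noreg, 0, $noreg, %0  ; zero the 64-byte slot
//   LDTILECFG %stack.0, 1, $noreg, 0, $noreg       ; load the configuration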
static ShapeT getShape(const MachineInstr &MI, MachineRegisterInfo *MRI) {
  unsigned Opcode = MI.getOpcode();
  switch (Opcode) {
  default:
    llvm_unreachable("Unexpected machine instruction on tile");
  case X86::PTILELOADDV:
  case X86::PTDPBSSDV:
  case X86::PTILEZEROV: {
    MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1));
    MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2));
    ShapeT Shape(&MO1, &MO2, MRI);
    return Shape;
  }
  }
}

MachineInstr *X86PreTileConfig::getTileConfigPoint() {
  DenseMap<Register, ShapeT> PhysShapeInfo;
  MachineBasicBlock *MBB = nullptr;
  DenseSet<const MachineInstr *> MIs;
  for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) {
    Register VirtReg = Register::index2VirtReg(i);
    if (MRI->reg_nodbg_empty(VirtReg))
      continue;
    const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
    if (RC.getID() != X86::TILERegClassID)
      continue;

    // Find the common dominator of all MIs that define tile registers.
    for (const MachineOperand &MO : MRI->def_operands(VirtReg)) {
      if (MO.isUndef())
        continue;
      const auto *MI = MO.getParent();
      // Skip PHI or IMPLICIT_DEF instructions.
      // There must be an input tile before a PHI instruction.
      if (MI->isTransient())
        continue;
      if (!MBB)
        MBB = const_cast<MachineBasicBlock *>(MI->getParent());
      MBB = DomTree->findNearestCommonDominator(
          MBB, const_cast<MachineBasicBlock *>(MI->getParent()));

      // Collect the instructions that define the shape.
      ShapeT Shape = getShape(*MI, MRI);
      std::array<MachineOperand *, 2> ShapeMOs = {Shape.getRow(),
                                                  Shape.getCol()};
      for (auto *ShapeMO : ShapeMOs) {
        Register ShapeReg = ShapeMO->getReg();
        for (const MachineOperand &MO : MRI->def_operands(ShapeReg)) {
          const auto *ShapeMI = MO.getParent();
          MIs.insert(ShapeMI);
        }
      }
    }
  }
  if (!MBB)
    return nullptr;
  // This pass runs before the PHI-elimination pass, so the function is still
  // in SSA form.
  assert(MRI->isSSA() && "Not SSA form in pre-tile config");
  // Shape defs should dominate the tile config MBB.
  //    def   s            s1    s2
  //     / \   \             \   /
  //    /   \   \             \ /
  //  conf             s3 = phi(s1, s2)
  //                          |
  //                          c
  //
  for (const auto *MI : MIs) {
    const MachineBasicBlock *ShapeMBB = MI->getParent();
    if (DomTree->dominates(ShapeMBB, MBB))
      continue;
    if (MI->isMoveImmediate())
      continue;
    report_fatal_error(MF->getName() + ": Failed to config tile register, "
                                       "please define the shape earlier");
  }

  // ldtilecfg should be inserted after the last MI that defines a shape.
  MachineBasicBlock::reverse_instr_iterator I, E;
  for (I = MBB->instr_rbegin(), E = MBB->instr_rend(); I != E; ++I) {
    auto *MI = &*I;
    if (MIs.count(MI) && (!MI->isMoveImmediate()))
      break;
  }
  MachineBasicBlock::iterator MII;
  if (I == E)
    MII = MBB->getFirstNonPHI();
  else {
    MII = MachineBasicBlock::iterator(&*I);
    MII++;
  }
  return &*MII;
}
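// To illustrate the insertion-point scan above with a hypothetical MIR block:
//
//   %r:gr16 = COPY %0              ; non-immediate shape def
//   %c:gr16 = MOV16ri 64           ; move-immediate shape def, skipped
//   %t:tile = PTILELOADDV %r, %c, ...
//
// the backward walk stops at the COPY (the last shape def that is not a
// move-immediate), so ldtilecfg is inserted between the COPY and the MOV16ri.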
static bool isAMXInstruction(MachineBasicBlock::iterator MII) {
  switch (MII->getOpcode()) {
  default:
    return false;
  case X86::PTILELOADDV:
  case X86::PTILESTOREDV:
  case X86::PTDPBSSDV:
  case X86::PTILEZEROV:
    return true;
  }
}

/// Per-basic-block info used to decide where the tile config must be reloaded.
struct BBInfo {
  bool HasAMX = false;
  bool HasCallBeforeAMX = false;
  bool HasAMXBeforeCallInSuccs = false;
  MachineInstr *LastCall = nullptr;

  BBInfo() = default;
  BBInfo(SmallSet<MachineInstr *, 8> &CfgNeedInsert, MachineBasicBlock *MBB,
         MachineInstr *MI = nullptr) {
    // Scan the block (from MI if given, otherwise from the beginning) and
    // collect every call that is followed by an AMX instruction in this block.
    MachineBasicBlock::iterator MII = MI ? MI->getIterator() : MBB->begin();
    for (auto E = MBB->end(); MII != E; ++MII) {
      if (isAMXInstruction(MII)) {
        HasAMX = true;
        if (LastCall)
          CfgNeedInsert.insert(LastCall);
      } else if (MII->isCall()) {
        LastCall = &*MII;
        if (!HasAMX)
          HasCallBeforeAMX = true;
      }
    }
  }
};
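/// A call may clobber the tile configuration, so it must be reloaded before
/// the configuration is used again. The first worklist walk below visits
/// every block reachable from the config point; the BBInfo constructor
/// already records calls that precede an AMX instruction within the same
/// block. The second walk propagates HasAMXBeforeCallInSuccs backwards so
/// that a block's last call is also collected when an AMX instruction follows
/// it only on some successor path. Finally, for each collected call whose
/// register mask clobbers at least one tile register, an ldtilecfg reload
/// from the same stack slot is inserted right after the call.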
static void reloadTileConfig(MachineInstr *MI, int FI,
                             const TargetInstrInfo *TII,
                             const TargetRegisterInfo *TRI) {
  SmallSet<MachineInstr *, 8> CfgNeedInsert;
  SmallVector<MachineBasicBlock *, 8> WorkList;
  DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;

  MachineBasicBlock *MBB = MI->getParent();
  BBVisitedInfo[MBB] = BBInfo(CfgNeedInsert, MBB, MI);

  // Visit all blocks reachable from the config point.
  WorkList.push_back(MBB);
  while (!WorkList.empty()) {
    MBB = WorkList.pop_back_val();
    for (auto I = MBB->succ_begin(), E = MBB->succ_end(); I != E; ++I) {
      if (!BBVisitedInfo.count(*I)) {
        BBVisitedInfo[*I] = BBInfo(CfgNeedInsert, *I);
        WorkList.push_back(*I);
      }
    }
  }

  // Propagate backwards: a block's last call needs a reload if an AMX
  // instruction follows it on some path through its successors.
  WorkList.clear();
  for (auto I : BBVisitedInfo) {
    WorkList.push_back(I.first);
    while (!WorkList.empty()) {
      MBB = WorkList.pop_back_val();
      if (BBVisitedInfo[MBB].HasCallBeforeAMX ||
          (!BBVisitedInfo[MBB].HasAMX &&
           !BBVisitedInfo[MBB].HasAMXBeforeCallInSuccs))
        continue;
      for (auto I = MBB->pred_begin(), E = MBB->pred_end(); I != E; ++I) {
        if (!BBVisitedInfo.count(*I) ||
            BBVisitedInfo[*I].HasAMXBeforeCallInSuccs)
          continue;
        if (BBVisitedInfo[*I].LastCall)
          CfgNeedInsert.insert(BBVisitedInfo[*I].LastCall);
        BBVisitedInfo[*I].HasAMXBeforeCallInSuccs = true;
        WorkList.push_back(*I);
      }
    }
  }

  // Insert a reload after each collected call that clobbers a tile register.
  for (auto *I : CfgNeedInsert) {
    BitVector UsableRegs(TRI->getNumRegs());
    const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
    for (unsigned J = 0; J < RC->getNumRegs(); J++)
      UsableRegs.set(X86::TMM0 + J);
    for (MachineOperand &CallMO : I->operands()) {
      if (CallMO.isRegMask())
        UsableRegs.clearBitsInMask(CallMO.getRegMask());
    }
    // UsableRegs now holds the tile registers the call clobbers; reload the
    // config only if at least one of them is clobbered.
    if (!UsableRegs.none())
      addFrameReference(BuildMI(*I->getParent(), ++I->getIterator(),
                                DebugLoc(), TII->get(X86::LDTILECFG)),
                        FI);
  }
}

bool X86PreTileConfig::runOnMachineFunction(MachineFunction &mf) {
  MF = &mf;
  MRI = &mf.getRegInfo();
  ST = &mf.getSubtarget<X86Subtarget>();
  TRI = ST->getRegisterInfo();
  TII = mf.getSubtarget().getInstrInfo();
  DomTree = &getAnalysis<MachineDominatorTree>();

  MachineInstr *MI = getTileConfigPoint();
  if (!MI)
    return false;
  unsigned Size = ST->getTileConfigSize();
  Align Alignment = ST->getTileConfigAlignment();
  int SS = mf.getFrameInfo().CreateStackObject(Size, Alignment, false);
  buildConfigMI(MI, SS, TII, MRI, ST);
  reloadTileConfig(MI, SS, TII, TRI);
  return true;
}

FunctionPass *llvm::createX86PreTileConfigPass() {
  return new X86PreTileConfig();
}
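// Note: this pass is intended to run before register allocation; the factory
// function above is typically invoked from the X86 pass pipeline (e.g. an
// addPass(createX86PreTileConfigPass()) call in X86TargetMachine), though the
// exact hook depends on the surrounding pipeline code.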