//===-- X86PreTileConfig.cpp - Tile Register Configure---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-configure the shape of AMX registers.
/// AMX registers need to be configured before use. The shape of an AMX
/// register is encoded in the 1st and 2nd machine operands of AMX pseudo
/// instructions. The pldtilecfg instruction configures the tile registers and
/// must dominate all AMX instructions. It produces a virtual cfg register
/// that is used by all AMX instructions.
/// This pass finds the common dominator of all AMX instructions and inserts
/// the pldtilecfg instruction there. In addition, the cfg register that
/// pldtilecfg produces is appended as the last operand of each AMX
/// instruction. We use this scheme to model the def-use relationship between
/// the AMX config instruction and other AMX instructions. Below is an example.
20 /// 21 /// ----B1---- 22 /// / \ 23 /// / \ 24 /// B2 B3 25 /// %1:tile = PTILELOADDV %2:tile = PTILELOADDV 26 /// 27 /// is transformed to 28 /// 29 /// B1 30 /// %25:tilecfg = PLDTILECFG 31 /// / \ 32 /// / \ 33 /// %1:tile = PTILELOADDV %25 %2:tile = PTILELOADDV %25 34 // 35 //===----------------------------------------------------------------------===// 36 37 #include "X86.h" 38 #include "X86InstrBuilder.h" 39 #include "X86RegisterInfo.h" 40 #include "X86Subtarget.h" 41 #include "llvm/CodeGen/MachineDominators.h" 42 #include "llvm/CodeGen/MachineFunctionPass.h" 43 #include "llvm/CodeGen/MachineInstr.h" 44 #include "llvm/CodeGen/MachineRegisterInfo.h" 45 #include "llvm/CodeGen/Passes.h" 46 #include "llvm/CodeGen/TargetInstrInfo.h" 47 #include "llvm/CodeGen/TargetRegisterInfo.h" 48 #include "llvm/CodeGen/TileShapeInfo.h" 49 #include "llvm/InitializePasses.h" 50 51 using namespace llvm; 52 53 #define DEBUG_TYPE "tile-pre-config" 54 55 namespace { 56 57 class X86PreTileConfig : public MachineFunctionPass { 58 // context 59 MachineFunction *MF = nullptr; 60 const X86Subtarget *ST = nullptr; 61 const TargetRegisterInfo *TRI; 62 const TargetInstrInfo *TII; 63 MachineDominatorTree *DomTree = nullptr; 64 MachineRegisterInfo *MRI = nullptr; 65 66 MachineInstr *getTileConfigPoint(); 67 68 public: 69 X86PreTileConfig() : MachineFunctionPass(ID) {} 70 71 /// Return the pass name. 72 StringRef getPassName() const override { 73 return "Tile Register Pre-configure"; 74 } 75 76 /// X86PreTileConfig analysis usage. 77 void getAnalysisUsage(AnalysisUsage &AU) const override; 78 79 /// Perform register allocation. 
80 bool runOnMachineFunction(MachineFunction &mf) override; 81 82 static char ID; 83 }; 84 85 } // end anonymous namespace 86 87 char X86PreTileConfig::ID = 0; 88 89 INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig", 90 "Tile Register Configure", false, false) 91 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree) 92 INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig", 93 "Tile Register Configure", false, false) 94 95 void X86PreTileConfig::getAnalysisUsage(AnalysisUsage &AU) const { 96 AU.setPreservesAll(); 97 AU.addRequired<MachineDominatorTree>(); 98 MachineFunctionPass::getAnalysisUsage(AU); 99 } 100 101 static void buildConfigMI(MachineBasicBlock::iterator MI, int FrameIdx, 102 const TargetInstrInfo *TII, MachineRegisterInfo *MRI, 103 const X86Subtarget *ST) { 104 auto *MBB = MI->getParent(); 105 106 // FIXME: AMX should assume AVX512 enabled. 107 if (ST->hasAVX512()) { 108 // Zero stack slot. 109 Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass); 110 BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VPXORDZrr), Zmm) 111 .addReg(Zmm, RegState::Undef) 112 .addReg(Zmm, RegState::Undef); 113 addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSZmr)), 114 FrameIdx) 115 .addReg(Zmm); 116 } 117 118 // build psuedo ldtilecfg 119 addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::LDTILECFG)), 120 FrameIdx); 121 } 122 123 static ShapeT getShape(const MachineInstr &MI, MachineRegisterInfo *MRI) { 124 unsigned Opcode = MI.getOpcode(); 125 switch (Opcode) { 126 default: 127 llvm_unreachable("Unexpected machine instruction on tile"); 128 case X86::PTILELOADDV: 129 case X86::PTDPBSSDV: 130 case X86::PTILEZEROV: 131 MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1)); 132 MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2)); 133 ShapeT Shape(&MO1, &MO2, MRI); 134 return Shape; 135 } 136 } 137 138 MachineInstr *X86PreTileConfig::getTileConfigPoint() { 139 DenseMap<Register, ShapeT> PhysShapeInfo; 140 
MachineBasicBlock *MBB = nullptr; 141 DenseSet<const MachineInstr *> MIs; 142 for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) { 143 Register VirtReg = Register::index2VirtReg(i); 144 if (MRI->reg_nodbg_empty(VirtReg)) 145 continue; 146 const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg); 147 if (RC.getID() != X86::TILERegClassID) 148 continue; 149 150 // Find the common dominator for all MI that define tile register. 151 for (const MachineOperand &MO : MRI->def_operands(VirtReg)) { 152 if (MO.isUndef()) 153 continue; 154 const auto *MI = MO.getParent(); 155 // PHI or IMPLICIT_DEF instructiion. 156 // There must be a input tile before PHI instruction. 157 if (MI->isTransient()) 158 continue; 159 if (!MBB) 160 MBB = const_cast<MachineBasicBlock *>(MI->getParent()); 161 MBB = DomTree->findNearestCommonDominator( 162 MBB, const_cast<MachineBasicBlock *>(MI->getParent())); 163 164 // Collect the instructions that define shape. 165 ShapeT Shape = getShape(*MI, MRI); 166 std::array<MachineOperand *, 2> ShapeMOs = {Shape.getRow(), 167 Shape.getCol()}; 168 for (auto *ShapeMO : ShapeMOs) { 169 Register ShapeReg = ShapeMO->getReg(); 170 for (const MachineOperand &MO : MRI->def_operands(ShapeReg)) { 171 const auto *ShapeMI = MO.getParent(); 172 MIs.insert(ShapeMI); 173 } 174 } 175 } 176 } 177 if (!MBB) 178 return nullptr; 179 // This pass is before the pass of eliminating PHI node, so it 180 // is in SSA form. 181 assert(MRI->isSSA() && "Not SSA form in pre-tile config"); 182 // Shape def should dominate tile config MBB. 
183 // def s s1 s2 184 // / \ \ / 185 // / \ \ / 186 // conf s3=phi(s1,s2) 187 // | 188 // c 189 // 190 for (const auto *MI : MIs) { 191 const MachineBasicBlock *ShapeMBB = MI->getParent(); 192 if (DomTree->dominates(ShapeMBB, MBB)) 193 continue; 194 if (MI->isMoveImmediate()) 195 continue; 196 report_fatal_error(MF->getName() + ": Failed to config tile register, " 197 "please define the shape earlier"); 198 } 199 200 // ldtilecfg should be inserted after the MI that define the shape. 201 MachineBasicBlock::reverse_instr_iterator I, E; 202 for (I = MBB->instr_rbegin(), E = MBB->instr_rend(); I != E; ++I) { 203 auto *MI = &*I; 204 if (MIs.count(MI) && (!MI->isMoveImmediate())) 205 break; 206 } 207 MachineBasicBlock::iterator MII; 208 if (I == E) 209 MII = MBB->getFirstNonPHI(); 210 else { 211 MII = MachineBasicBlock::iterator(&*I); 212 MII++; 213 } 214 return &*MII; 215 } 216 217 static bool isAMXInstruction(MachineBasicBlock::iterator MII) { 218 switch (MII->getOpcode()) { 219 default: 220 return false; 221 case X86::PTILELOADDV: 222 case X86::PTILESTOREDV: 223 case X86::PTDPBSSDV: 224 case X86::PTILEZEROV: 225 return true; 226 } 227 } 228 229 struct BBInfo { 230 bool HasAMX = false; 231 bool HasCallBeforeAMX = false; 232 bool HasAMXBeforeCallInSuccs = false; 233 MachineInstr *LastCall = nullptr; 234 235 BBInfo() = default; 236 BBInfo(SmallSet<MachineInstr *, 8> &CfgNeedInsert, MachineBasicBlock *MBB, 237 MachineInstr *MI = nullptr) { 238 MachineBasicBlock::iterator MII = MI ? 
MI->getIterator() : MBB->begin(); 239 for (auto E = MBB->end(); MII != E; ++MII) { 240 if (isAMXInstruction(MII)) { 241 HasAMX = true; 242 if (LastCall) 243 CfgNeedInsert.insert(LastCall); 244 } else if (MII->isCall()) { 245 LastCall = &*MII; 246 if (!HasAMX) 247 HasCallBeforeAMX = true; 248 } 249 } 250 } 251 }; 252 253 static void reloadTileConfig(MachineInstr *MI, int FI, 254 const TargetInstrInfo *TII, 255 const TargetRegisterInfo *TRI) { 256 SmallSet<MachineInstr *, 8> CfgNeedInsert; 257 SmallVector<MachineBasicBlock *, 8> WorkList; 258 DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo; 259 260 MachineBasicBlock *MBB = MI->getParent(); 261 BBVisitedInfo[MBB] = BBInfo(CfgNeedInsert, MBB, MI); 262 263 WorkList.push_back(MBB); 264 while (!WorkList.empty()) { 265 MBB = WorkList.pop_back_val(); 266 for (auto I = MBB->succ_begin(), E = MBB->succ_end(); I != E; ++I) { 267 if (!BBVisitedInfo.count(*I)) { 268 BBVisitedInfo[*I] = BBInfo(CfgNeedInsert, *I); 269 WorkList.push_back(*I); 270 } 271 } 272 } 273 274 WorkList.clear(); 275 for (auto I : BBVisitedInfo) { 276 WorkList.push_back(I.first); 277 while (!WorkList.empty()) { 278 MBB = WorkList.pop_back_val(); 279 if (BBVisitedInfo[MBB].HasCallBeforeAMX || 280 (!BBVisitedInfo[MBB].HasAMX && 281 !BBVisitedInfo[MBB].HasAMXBeforeCallInSuccs)) 282 continue; 283 for (auto I = MBB->pred_begin(), E = MBB->pred_end(); I != E; ++I) { 284 if (!BBVisitedInfo.count(*I) || 285 BBVisitedInfo[*I].HasAMXBeforeCallInSuccs) 286 continue; 287 if (BBVisitedInfo[*I].LastCall) 288 CfgNeedInsert.insert(BBVisitedInfo[*I].LastCall); 289 BBVisitedInfo[*I].HasAMXBeforeCallInSuccs = true; 290 WorkList.push_back(*I); 291 } 292 } 293 } 294 295 for (auto *I : CfgNeedInsert) { 296 BitVector UsableRegs(TRI->getNumRegs()); 297 const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID); 298 for (unsigned J = 0; J < RC->getNumRegs(); J++) 299 UsableRegs.set(X86::TMM0 + J); 300 for (MachineOperand &CallMO : I->operands()) { 301 if 
(CallMO.isRegMask()) 302 UsableRegs.clearBitsInMask(CallMO.getRegMask()); 303 } 304 if (!UsableRegs.none()) 305 addFrameReference(BuildMI(*I->getParent(), ++I->getIterator(), DebugLoc(), 306 TII->get(X86::LDTILECFG)), 307 FI); 308 } 309 } 310 311 bool X86PreTileConfig::runOnMachineFunction(MachineFunction &mf) { 312 MF = &mf; 313 MRI = &mf.getRegInfo(); 314 ST = &mf.getSubtarget<X86Subtarget>(); 315 TRI = ST->getRegisterInfo(); 316 TII = mf.getSubtarget().getInstrInfo(); 317 DomTree = &getAnalysis<MachineDominatorTree>(); 318 319 MachineInstr *MI = getTileConfigPoint(); 320 if (!MI) 321 return false; 322 unsigned Size = ST->getTileConfigSize(); 323 Align Alignment = ST->getTileConfigAlignment(); 324 int SS = mf.getFrameInfo().CreateStackObject(Size, Alignment, false); 325 buildConfigMI(MI, SS, TII, MRI, ST); 326 reloadTileConfig(MI, SS, TII, TRI); 327 return true; 328 } 329 330 FunctionPass *llvm::createX86PreTileConfigPass() { 331 return new X86PreTileConfig(); 332 } 333