1 //===-- X86TileConfig.cpp - Tile Register Configure----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Pass to config the shape of AMX physical registers 10 /// AMX register need to be configured before use. In X86PreTileConfig pass 11 /// the pldtilecfg instruction is inserted, however at that time we don't 12 /// know the shape of each physical tile registers, because the register 13 /// allocation is not done yet. This pass runs after egister allocation 14 /// pass. It collects the shape information of each physical tile register 15 /// and store the shape in the stack slot that is allocated for load config 16 /// to tile config register. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "X86.h" 21 #include "X86InstrBuilder.h" 22 #include "X86MachineFunctionInfo.h" 23 #include "X86RegisterInfo.h" 24 #include "X86Subtarget.h" 25 #include "llvm/ADT/PostOrderIterator.h" 26 #include "llvm/CodeGen/LiveIntervals.h" 27 #include "llvm/CodeGen/MachineDominators.h" 28 #include "llvm/CodeGen/MachineFrameInfo.h" 29 #include "llvm/CodeGen/MachineFunctionPass.h" 30 #include "llvm/CodeGen/MachineInstr.h" 31 #include "llvm/CodeGen/MachineRegisterInfo.h" 32 #include "llvm/CodeGen/Passes.h" 33 #include "llvm/CodeGen/TargetInstrInfo.h" 34 #include "llvm/CodeGen/TargetRegisterInfo.h" 35 #include "llvm/CodeGen/TileShapeInfo.h" 36 #include "llvm/CodeGen/VirtRegMap.h" 37 #include "llvm/InitializePasses.h" 38 39 using namespace llvm; 40 41 #define DEBUG_TYPE "tile-config" 42 43 namespace { 44 45 class X86TileConfig : public MachineFunctionPass { 46 // context 47 MachineFunction *MF = nullptr; 48 const X86Subtarget *ST = nullptr; 49 const TargetRegisterInfo *TRI; 50 const TargetInstrInfo *TII; 51 MachineDominatorTree *DomTree = nullptr; 52 MachineRegisterInfo *MRI = nullptr; 53 VirtRegMap *VRM = nullptr; 54 LiveIntervals *LIS = nullptr; 55 56 MachineInstr *getTileConfigPoint(); 57 void tileConfig(); 58 59 public: 60 X86TileConfig() : MachineFunctionPass(ID) {} 61 62 /// Return the pass name. 63 StringRef getPassName() const override { return "Tile Register Configure"; } 64 65 /// X86TileConfig analysis usage. 66 void getAnalysisUsage(AnalysisUsage &AU) const override; 67 68 /// Perform register allocation. 69 bool runOnMachineFunction(MachineFunction &mf) override; 70 71 MachineFunctionProperties getRequiredProperties() const override { 72 return MachineFunctionProperties().set( 73 MachineFunctionProperties::Property::NoPHIs); 74 } 75 76 static char ID; 77 }; 78 79 } // end anonymous namespace 80 81 char X86TileConfig::ID = 0; 82 83 INITIALIZE_PASS_BEGIN(X86TileConfig, "tileconfig", "Tile Register Configure", 84 false, false) 85 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree) 86 INITIALIZE_PASS_DEPENDENCY(VirtRegMap) 87 INITIALIZE_PASS_END(X86TileConfig, "tileconfig", "Tile Register Configure", 88 false, false) 89 90 void X86TileConfig::getAnalysisUsage(AnalysisUsage &AU) const { 91 AU.addRequired<MachineDominatorTree>(); 92 AU.addRequired<LiveIntervals>(); 93 AU.addPreserved<SlotIndexes>(); 94 AU.addRequired<VirtRegMap>(); 95 AU.setPreservesAll(); 96 MachineFunctionPass::getAnalysisUsage(AU); 97 } 98 99 static unsigned getTilePhysRegIndex(Register PhysReg) { 100 assert((PhysReg >= X86::TMM0 && X86::TMM0 <= X86::TMM7) && 101 "Tile register number is invalid"); 102 return (PhysReg - X86::TMM0); 103 } 104 105 static MachineInstr * 106 storeRegToStackSlot(MachineBasicBlock &MBB, MachineBasicBlock::iterator MI, 107 Register SrcReg, unsigned BitSize, int FrameIdx, int Offset, 108 const TargetInstrInfo *TII, const TargetRegisterClass *RC, 109 const TargetRegisterInfo *TRI) { 110 111 unsigned SubIdx = (BitSize == 8) ? X86::sub_8bit : X86::sub_16bit; 112 unsigned Opc = (BitSize == 8) ? X86::MOV8mr : X86::MOV16mr; 113 if (BitSize == TRI->getRegSizeInBits(*RC)) 114 SubIdx = 0; 115 MachineInstr *NewMI = 116 addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)), FrameIdx, 117 Offset) 118 .addReg(SrcReg, 0, SubIdx); 119 return NewMI; 120 } 121 122 static MachineInstr *storeImmToStackSlot(MachineBasicBlock &MBB, 123 MachineBasicBlock::iterator MI, 124 int64_t Imm, unsigned BitSize, 125 int FrameIdx, int Offset, 126 const TargetInstrInfo *TII) { 127 unsigned Opc = (BitSize == 8) ? X86::MOV8mi : X86::MOV16mi; 128 return addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)), 129 FrameIdx, Offset) 130 .addImm(Imm); 131 } 132 133 MachineInstr *X86TileConfig::getTileConfigPoint() { 134 MachineBasicBlock *Entry = &*MF->begin(); 135 ReversePostOrderTraversal<MachineBasicBlock *> RPOT(Entry); 136 for (MachineBasicBlock *MBB : RPOT) { 137 for (MachineInstr &MI : *MBB) 138 // Refer X86PreTileConfig.cpp. 139 // We only support one tile config for now. The other ldtilecfg 140 // is for spill purpose and is dominated by the first ldtilecfg. 141 if (MI.getOpcode() == X86::LDTILECFG) 142 return &MI; 143 } 144 145 return nullptr; 146 } 147 148 void X86TileConfig::tileConfig() { 149 MachineInstr *MI = getTileConfigPoint(); 150 if (!MI) 151 return; 152 MachineBasicBlock *MBB = MI->getParent(); 153 int SS = MI->getOperand(0).getIndex(); 154 BitVector PhysRegs(TRI->getNumRegs()); 155 156 // Fill in the palette first. 157 auto *NewMI = storeImmToStackSlot(*MBB, *MI, 1, 8, SS, 0, TII); 158 LIS->InsertMachineInstrInMaps(*NewMI); 159 // Fill in the shape of each tile physical register. 160 for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) { 161 Register VirtReg = Register::index2VirtReg(i); 162 if (MRI->reg_nodbg_empty(VirtReg)) 163 continue; 164 const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg); 165 if (RC.getID() != X86::TILERegClassID) 166 continue; 167 Register PhysReg = VRM->getPhys(VirtReg); 168 if (PhysRegs.test(PhysReg)) 169 continue; 170 PhysRegs.set(PhysReg); 171 ShapeT Shape = VRM->getShape(VirtReg); 172 Register RowReg = Shape.getRow()->getReg(); 173 Register ColReg = Shape.getCol()->getReg(); 174 175 // Here is the data format for the tile config. 176 // 0 palette 177 // 1 start_row 178 // 2-15 reserved, must be zero 179 // 16-17 tile0.colsb Tile 0 bytes per row. 180 // 18-19 tile1.colsb Tile 1 bytes per row. 181 // 20-21 tile2.colsb Tile 2 bytes per row. 182 // ... (sequence continues) 183 // 30-31 tile7.colsb Tile 7 bytes per row. 184 // 32-47 reserved, must be zero 185 // 48 tile0.rows Tile 0 rows. 186 // 49 tile1.rows Tile 1 rows. 187 // 50 tile2.rows Tile 2 rows. 188 // ... (sequence continues) 189 // 55 tile7.rows Tile 7 rows. 190 // 56-63 reserved, must be zero 191 unsigned Index = getTilePhysRegIndex(PhysReg); 192 int RowOffset = 48 + Index; 193 int ColOffset = 16 + Index * 2; 194 195 unsigned BitSize = 8; 196 for (const auto &Pair : {std::make_pair(RowReg, RowOffset), 197 std::make_pair(ColReg, ColOffset)}) { 198 int64_t Imm; 199 int ImmCount = 0; 200 // All def must be the same value, otherwise it is invalid MIs. 201 // Immediate is prefered. 202 for (const MachineOperand &MO : MRI->def_operands(Pair.first)) { 203 const auto *Inst = MO.getParent(); 204 if (Inst->isMoveImmediate()) { 205 ImmCount++; 206 Imm = Inst->getOperand(1).getImm(); 207 break; 208 } 209 } 210 auto StoreConfig = [&](int Offset) { 211 MachineInstr *NewMI = nullptr; 212 if (ImmCount) 213 NewMI = storeImmToStackSlot(*MBB, *MI, Imm, BitSize, SS, Offset, TII); 214 else { 215 const TargetRegisterClass *RC = MRI->getRegClass(Pair.first); 216 NewMI = storeRegToStackSlot(*MBB, *MI, Pair.first, BitSize, SS, 217 Offset, TII, RC, TRI); 218 } 219 SlotIndex SIdx = LIS->InsertMachineInstrInMaps(*NewMI); 220 if (!ImmCount) { 221 // Extend the live interval. 222 SmallVector<SlotIndex, 8> EndPoints = {SIdx.getRegSlot()}; 223 LiveInterval &Int = LIS->getInterval(Pair.first); 224 LIS->extendToIndices(Int, EndPoints); 225 } 226 }; 227 StoreConfig(Pair.second); 228 BitSize += 8; 229 } 230 } 231 } 232 233 bool X86TileConfig::runOnMachineFunction(MachineFunction &mf) { 234 MF = &mf; 235 MRI = &mf.getRegInfo(); 236 ST = &mf.getSubtarget<X86Subtarget>(); 237 TRI = ST->getRegisterInfo(); 238 TII = mf.getSubtarget().getInstrInfo(); 239 DomTree = &getAnalysis<MachineDominatorTree>(); 240 VRM = &getAnalysis<VirtRegMap>(); 241 LIS = &getAnalysis<LiveIntervals>(); 242 243 if (VRM->isShapeMapEmpty()) 244 return false; 245 246 tileConfig(); 247 return true; 248 } 249 250 FunctionPass *llvm::createX86TileConfigPass() { return new X86TileConfig(); } 251