1 //===-- X86TileConfig.cpp - Tile Register Configure----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to configure the shape of AMX physical registers.
/// AMX registers need to be configured before use. The X86PreTileConfig pass
/// inserts the pldtilecfg instruction, but at that point the shape of each
/// physical tile register is not yet known, because register allocation has
/// not been done. This pass runs after register allocation. It collects the
/// shape information of each physical tile register and stores it in the
/// stack slot that is allocated for loading the tile config register.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "X86.h"
21 #include "X86InstrBuilder.h"
22 #include "X86MachineFunctionInfo.h"
23 #include "X86RegisterInfo.h"
24 #include "X86Subtarget.h"
25 #include "llvm/ADT/PostOrderIterator.h"
26 #include "llvm/CodeGen/LiveIntervals.h"
27 #include "llvm/CodeGen/MachineDominators.h"
28 #include "llvm/CodeGen/MachineFrameInfo.h"
29 #include "llvm/CodeGen/MachineFunctionPass.h"
30 #include "llvm/CodeGen/MachineInstr.h"
31 #include "llvm/CodeGen/MachineRegisterInfo.h"
32 #include "llvm/CodeGen/Passes.h"
33 #include "llvm/CodeGen/TargetInstrInfo.h"
34 #include "llvm/CodeGen/TargetRegisterInfo.h"
35 #include "llvm/CodeGen/TileShapeInfo.h"
36 #include "llvm/CodeGen/VirtRegMap.h"
37 #include "llvm/InitializePasses.h"
38 
39 using namespace llvm;
40 
41 #define DEBUG_TYPE "tile-config"
42 
43 namespace {
44 
45 class X86TileConfig : public MachineFunctionPass {
46   // context
47   MachineFunction *MF = nullptr;
48   const X86Subtarget *ST = nullptr;
49   const TargetRegisterInfo *TRI;
50   const TargetInstrInfo *TII;
51   MachineDominatorTree *DomTree = nullptr;
52   MachineRegisterInfo *MRI = nullptr;
53   VirtRegMap *VRM = nullptr;
54   LiveIntervals *LIS = nullptr;
55 
56   MachineInstr *getTileConfigPoint();
57   void tileConfig();
58 
59 public:
60   X86TileConfig() : MachineFunctionPass(ID) {}
61 
62   /// Return the pass name.
63   StringRef getPassName() const override { return "Tile Register Configure"; }
64 
65   /// X86TileConfig analysis usage.
66   void getAnalysisUsage(AnalysisUsage &AU) const override;
67 
68   /// Perform register allocation.
69   bool runOnMachineFunction(MachineFunction &mf) override;
70 
71   MachineFunctionProperties getRequiredProperties() const override {
72     return MachineFunctionProperties().set(
73         MachineFunctionProperties::Property::NoPHIs);
74   }
75 
76   static char ID;
77 };
78 
79 } // end anonymous namespace
80 
81 char X86TileConfig::ID = 0;
82 
83 INITIALIZE_PASS_BEGIN(X86TileConfig, "tileconfig", "Tile Register Configure",
84                       false, false)
85 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
86 INITIALIZE_PASS_DEPENDENCY(VirtRegMap)
87 INITIALIZE_PASS_END(X86TileConfig, "tileconfig", "Tile Register Configure",
88                     false, false)
89 
/// Declare the analyses this pass requests (MachineDominatorTree,
/// LiveIntervals, VirtRegMap) and report that it preserves all analyses.
void X86TileConfig::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<MachineDominatorTree>();
  AU.addRequired<LiveIntervals>();
  // NOTE(review): setPreservesAll() below appears to subsume this
  // addPreserved<SlotIndexes>(); confirm before removing it.
  AU.addPreserved<SlotIndexes>();
  AU.addRequired<VirtRegMap>();
  // tileConfig() registers every instruction it inserts with LiveIntervals
  // (and extends affected live intervals), which is presumably why all
  // analyses are declared preserved.
  AU.setPreservesAll();
  MachineFunctionPass::getAnalysisUsage(AU);
}
98 
99 static unsigned getTilePhysRegIndex(Register PhysReg) {
100   assert((PhysReg >= X86::TMM0 && X86::TMM0 <= X86::TMM7) &&
101          "Tile register number is invalid");
102   return (PhysReg - X86::TMM0);
103 }
104 
105 static MachineInstr *
106 storeRegToStackSlot(MachineBasicBlock &MBB, MachineBasicBlock::iterator MI,
107                     Register SrcReg, unsigned BitSize, int FrameIdx, int Offset,
108                     const TargetInstrInfo *TII, const TargetRegisterClass *RC,
109                     const TargetRegisterInfo *TRI) {
110 
111   unsigned SubIdx = (BitSize == 8) ? X86::sub_8bit : X86::sub_16bit;
112   unsigned Opc = (BitSize == 8) ? X86::MOV8mr : X86::MOV16mr;
113   if (BitSize == TRI->getRegSizeInBits(*RC))
114     SubIdx = 0;
115   MachineInstr *NewMI =
116       addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)), FrameIdx,
117                         Offset)
118           .addReg(SrcReg, 0, SubIdx);
119   return NewMI;
120 }
121 
122 static MachineInstr *storeImmToStackSlot(MachineBasicBlock &MBB,
123                                          MachineBasicBlock::iterator MI,
124                                          int64_t Imm, unsigned BitSize,
125                                          int FrameIdx, int Offset,
126                                          const TargetInstrInfo *TII) {
127   unsigned Opc = (BitSize == 8) ? X86::MOV8mi : X86::MOV16mi;
128   return addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)),
129                            FrameIdx, Offset)
130       .addImm(Imm);
131 }
132 
133 MachineInstr *X86TileConfig::getTileConfigPoint() {
134   MachineBasicBlock *Entry = &*MF->begin();
135   ReversePostOrderTraversal<MachineBasicBlock *> RPOT(Entry);
136   for (MachineBasicBlock *MBB : RPOT) {
137     for (MachineInstr &MI : *MBB)
138       // Refer X86PreTileConfig.cpp.
139       // We only support one tile config for now. The other ldtilecfg
140       // is for spill purpose and is dominated by the first ldtilecfg.
141       if (MI.getOpcode() == X86::LDTILECFG)
142         return &MI;
143   }
144 
145   return nullptr;
146 }
147 
148 void X86TileConfig::tileConfig() {
149   MachineInstr *MI = getTileConfigPoint();
150   if (!MI)
151     return;
152   MachineBasicBlock *MBB = MI->getParent();
153   int SS = MI->getOperand(0).getIndex();
154   BitVector PhysRegs(TRI->getNumRegs());
155 
156   // Fill in the palette first.
157   auto *NewMI = storeImmToStackSlot(*MBB, *MI, 1, 8, SS, 0, TII);
158   LIS->InsertMachineInstrInMaps(*NewMI);
159   // Fill in the shape of each tile physical register.
160   for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) {
161     Register VirtReg = Register::index2VirtReg(i);
162     if (MRI->reg_nodbg_empty(VirtReg))
163       continue;
164     const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
165     if (RC.getID() != X86::TILERegClassID)
166       continue;
167     Register PhysReg = VRM->getPhys(VirtReg);
168     if (PhysRegs.test(PhysReg))
169       continue;
170     PhysRegs.set(PhysReg);
171     ShapeT Shape = VRM->getShape(VirtReg);
172     Register RowReg = Shape.getRow()->getReg();
173     Register ColReg = Shape.getCol()->getReg();
174 
175     // Here is the data format for the tile config.
176     // 0      palette
177     // 1      start_row
178     // 2-15   reserved, must be zero
179     // 16-17  tile0.colsb Tile 0 bytes per row.
180     // 18-19  tile1.colsb Tile 1 bytes per row.
181     // 20-21  tile2.colsb Tile 2 bytes per row.
182     // ... (sequence continues)
183     // 30-31  tile7.colsb Tile 7 bytes per row.
184     // 32-47  reserved, must be zero
185     // 48     tile0.rows Tile 0 rows.
186     // 49     tile1.rows Tile 1 rows.
187     // 50     tile2.rows Tile 2 rows.
188     // ... (sequence continues)
189     // 55     tile7.rows Tile 7 rows.
190     // 56-63  reserved, must be zero
191     unsigned Index = getTilePhysRegIndex(PhysReg);
192     int RowOffset = 48 + Index;
193     int ColOffset = 16 + Index * 2;
194 
195     unsigned BitSize = 8;
196     for (const auto &Pair : {std::make_pair(RowReg, RowOffset),
197                              std::make_pair(ColReg, ColOffset)}) {
198       int64_t Imm;
199       int ImmCount = 0;
200       // All def must be the same value, otherwise it is invalid MIs.
201       // Immediate is prefered.
202       for (const MachineOperand &MO : MRI->def_operands(Pair.first)) {
203         const auto *Inst = MO.getParent();
204         if (Inst->isMoveImmediate()) {
205           ImmCount++;
206           Imm = Inst->getOperand(1).getImm();
207           break;
208         }
209       }
210       auto StoreConfig = [&](int Offset) {
211         MachineInstr *NewMI = nullptr;
212         if (ImmCount)
213           NewMI = storeImmToStackSlot(*MBB, *MI, Imm, BitSize, SS, Offset, TII);
214         else {
215           const TargetRegisterClass *RC = MRI->getRegClass(Pair.first);
216           NewMI = storeRegToStackSlot(*MBB, *MI, Pair.first, BitSize, SS,
217                                       Offset, TII, RC, TRI);
218         }
219         SlotIndex SIdx = LIS->InsertMachineInstrInMaps(*NewMI);
220         if (!ImmCount) {
221           // Extend the live interval.
222           SmallVector<SlotIndex, 8> EndPoints = {SIdx.getRegSlot()};
223           LiveInterval &Int = LIS->getInterval(Pair.first);
224           LIS->extendToIndices(Int, EndPoints);
225         }
226       };
227       StoreConfig(Pair.second);
228       BitSize += 8;
229     }
230   }
231 }
232 
233 bool X86TileConfig::runOnMachineFunction(MachineFunction &mf) {
234   MF = &mf;
235   MRI = &mf.getRegInfo();
236   ST = &mf.getSubtarget<X86Subtarget>();
237   TRI = ST->getRegisterInfo();
238   TII = mf.getSubtarget().getInstrInfo();
239   DomTree = &getAnalysis<MachineDominatorTree>();
240   VRM = &getAnalysis<VirtRegMap>();
241   LIS = &getAnalysis<LiveIntervals>();
242 
243   if (VRM->isShapeMapEmpty())
244     return false;
245 
246   tileConfig();
247   return true;
248 }
249 
250 FunctionPass *llvm::createX86TileConfigPass() { return new X86TileConfig(); }
251