1f80b2987SLuo, Yuanke //===-- X86TileConfig.cpp - Tile Register Configure----------------------===//
2f80b2987SLuo, Yuanke //
3f80b2987SLuo, Yuanke // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f80b2987SLuo, Yuanke // See https://llvm.org/LICENSE.txt for license information.
5f80b2987SLuo, Yuanke // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f80b2987SLuo, Yuanke //
7f80b2987SLuo, Yuanke //===----------------------------------------------------------------------===//
8f80b2987SLuo, Yuanke //
/// \file Pass to configure the shape of AMX physical registers.
/// AMX registers need to be configured before use. The X86PreTileConfig pass
/// inserts the pldtilecfg instruction, but at that time we don't know the
/// shape of each physical tile register, because register allocation is not
/// done yet. This pass runs after the register allocation pass. It collects
/// the shape information of each physical tile register and stores the shape
/// in the stack slot that is allocated for loading the config into the tile
/// config register.
17f80b2987SLuo, Yuanke //
18f80b2987SLuo, Yuanke //===----------------------------------------------------------------------===//
19f80b2987SLuo, Yuanke 
20f80b2987SLuo, Yuanke #include "X86.h"
21f80b2987SLuo, Yuanke #include "X86InstrBuilder.h"
22f80b2987SLuo, Yuanke #include "X86MachineFunctionInfo.h"
23f80b2987SLuo, Yuanke #include "X86RegisterInfo.h"
24f80b2987SLuo, Yuanke #include "X86Subtarget.h"
25f80b2987SLuo, Yuanke #include "llvm/CodeGen/LiveIntervals.h"
26f80b2987SLuo, Yuanke #include "llvm/CodeGen/MachineFrameInfo.h"
27f80b2987SLuo, Yuanke #include "llvm/CodeGen/MachineFunctionPass.h"
28f80b2987SLuo, Yuanke #include "llvm/CodeGen/MachineInstr.h"
29f80b2987SLuo, Yuanke #include "llvm/CodeGen/MachineRegisterInfo.h"
30f80b2987SLuo, Yuanke #include "llvm/CodeGen/Passes.h"
31f80b2987SLuo, Yuanke #include "llvm/CodeGen/TargetInstrInfo.h"
32f80b2987SLuo, Yuanke #include "llvm/CodeGen/TargetRegisterInfo.h"
33f80b2987SLuo, Yuanke #include "llvm/CodeGen/TileShapeInfo.h"
34f80b2987SLuo, Yuanke #include "llvm/CodeGen/VirtRegMap.h"
35f80b2987SLuo, Yuanke #include "llvm/InitializePasses.h"
36f80b2987SLuo, Yuanke 
37f80b2987SLuo, Yuanke using namespace llvm;
38f80b2987SLuo, Yuanke 
39*5cb09798SLuo, Yuanke #define DEBUG_TYPE "tileconfig"
40f80b2987SLuo, Yuanke 
41f80b2987SLuo, Yuanke namespace {
42f80b2987SLuo, Yuanke 
43a3b52a9dSWang, Pengfei struct X86TileConfig : public MachineFunctionPass {
44f80b2987SLuo, Yuanke 
X86TileConfig__anon4ca03bf50111::X86TileConfig45f80b2987SLuo, Yuanke   X86TileConfig() : MachineFunctionPass(ID) {}
46f80b2987SLuo, Yuanke 
47f80b2987SLuo, Yuanke   /// Return the pass name.
getPassName__anon4ca03bf50111::X86TileConfig48f80b2987SLuo, Yuanke   StringRef getPassName() const override { return "Tile Register Configure"; }
49f80b2987SLuo, Yuanke 
50f80b2987SLuo, Yuanke   /// X86TileConfig analysis usage.
getAnalysisUsage__anon4ca03bf50111::X86TileConfig51a3b52a9dSWang, Pengfei   void getAnalysisUsage(AnalysisUsage &AU) const override {
52a3b52a9dSWang, Pengfei     AU.setPreservesAll();
53a3b52a9dSWang, Pengfei     AU.addRequired<VirtRegMap>();
54a3b52a9dSWang, Pengfei     AU.addRequired<LiveIntervals>();
55a3b52a9dSWang, Pengfei     MachineFunctionPass::getAnalysisUsage(AU);
56a3b52a9dSWang, Pengfei   }
57f80b2987SLuo, Yuanke 
58f80b2987SLuo, Yuanke   /// Perform register allocation.
59f80b2987SLuo, Yuanke   bool runOnMachineFunction(MachineFunction &mf) override;
60f80b2987SLuo, Yuanke 
getRequiredProperties__anon4ca03bf50111::X86TileConfig61f80b2987SLuo, Yuanke   MachineFunctionProperties getRequiredProperties() const override {
62f80b2987SLuo, Yuanke     return MachineFunctionProperties().set(
63f80b2987SLuo, Yuanke         MachineFunctionProperties::Property::NoPHIs);
64f80b2987SLuo, Yuanke   }
65f80b2987SLuo, Yuanke 
66f80b2987SLuo, Yuanke   static char ID;
67f80b2987SLuo, Yuanke };
68f80b2987SLuo, Yuanke 
69f80b2987SLuo, Yuanke } // end anonymous namespace
70f80b2987SLuo, Yuanke 
71f80b2987SLuo, Yuanke char X86TileConfig::ID = 0;
72f80b2987SLuo, Yuanke 
73*5cb09798SLuo, Yuanke INITIALIZE_PASS_BEGIN(X86TileConfig, DEBUG_TYPE, "Tile Register Configure",
74f80b2987SLuo, Yuanke                       false, false)
INITIALIZE_PASS_DEPENDENCY(VirtRegMap)75f80b2987SLuo, Yuanke INITIALIZE_PASS_DEPENDENCY(VirtRegMap)
76*5cb09798SLuo, Yuanke INITIALIZE_PASS_END(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", false,
77*5cb09798SLuo, Yuanke                     false)
78f80b2987SLuo, Yuanke 
79a3b52a9dSWang, Pengfei bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {
80a3b52a9dSWang, Pengfei   const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
81a3b52a9dSWang, Pengfei   const TargetRegisterInfo *TRI = ST.getRegisterInfo();
82a3b52a9dSWang, Pengfei   const TargetInstrInfo *TII = ST.getInstrInfo();
83a3b52a9dSWang, Pengfei   MachineRegisterInfo &MRI = MF.getRegInfo();
84a3b52a9dSWang, Pengfei   LiveIntervals &LIS = getAnalysis<LiveIntervals>();
85a3b52a9dSWang, Pengfei   VirtRegMap &VRM = getAnalysis<VirtRegMap>();
86a3b52a9dSWang, Pengfei 
87a3b52a9dSWang, Pengfei   if (VRM.isShapeMapEmpty())
88a3b52a9dSWang, Pengfei     return false;
89a3b52a9dSWang, Pengfei 
90a3b52a9dSWang, Pengfei   int SS = INT_MAX;
91a3b52a9dSWang, Pengfei   for (MachineBasicBlock &MBB : MF) {
92a3b52a9dSWang, Pengfei     for (MachineInstr &MI : MBB) {
93aaaf9cedSLuo, Yuanke       if (MI.getOpcode() == X86::PLDTILECFGV) {
94a3b52a9dSWang, Pengfei         SS = MI.getOperand(0).getIndex();
95a3b52a9dSWang, Pengfei         break;
96a3b52a9dSWang, Pengfei       }
97a3b52a9dSWang, Pengfei     }
98a3b52a9dSWang, Pengfei     if (SS != INT_MAX)
99a3b52a9dSWang, Pengfei       break;
100f80b2987SLuo, Yuanke   }
101aaaf9cedSLuo, Yuanke   // Didn't find PLDTILECFGV, just return false;
102f3ad7ea0SLuo, Yuanke   if (SS == INT_MAX)
103f3ad7ea0SLuo, Yuanke     return false;
104f80b2987SLuo, Yuanke 
105a3b52a9dSWang, Pengfei   // Try to find a point to insert MIs for constant shapes.
106a3b52a9dSWang, Pengfei   // Here we are leveraging the palette id inserted in PreRA pass.
107a3b52a9dSWang, Pengfei   unsigned ConstPos = 0;
108a3b52a9dSWang, Pengfei   MachineInstr *ConstMI = nullptr;
109a3b52a9dSWang, Pengfei   for (MachineInstr &MI : MF.front()) {
110a3b52a9dSWang, Pengfei     if (MI.getOpcode() == X86::MOV8mi && SS == MI.getOperand(0).getIndex()) {
111a3b52a9dSWang, Pengfei       ConstMI = &MI;
112a3b52a9dSWang, Pengfei       break;
113a3b52a9dSWang, Pengfei     }
114a3b52a9dSWang, Pengfei     ++ConstPos;
115a3b52a9dSWang, Pengfei   }
116a3b52a9dSWang, Pengfei   assert(ConstMI && "Cannot find an insertion point");
117a3b52a9dSWang, Pengfei 
118a3b52a9dSWang, Pengfei   unsigned AMXRegNum = TRI->getRegClass(X86::TILERegClassID)->getNumRegs();
119a3b52a9dSWang, Pengfei   SmallVector<Register, 8> Phys2Virt(AMXRegNum, 0);
120a3b52a9dSWang, Pengfei   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
121a3b52a9dSWang, Pengfei     Register VirtReg = Register::index2VirtReg(I);
122a3b52a9dSWang, Pengfei     if (MRI.reg_nodbg_empty(VirtReg))
123a3b52a9dSWang, Pengfei       continue;
124a3b52a9dSWang, Pengfei     if (MRI.getRegClass(VirtReg)->getID() != X86::TILERegClassID)
125a3b52a9dSWang, Pengfei       continue;
126*5cb09798SLuo, Yuanke     if (VRM.getPhys(VirtReg) == VirtRegMap::NO_PHYS_REG)
127*5cb09798SLuo, Yuanke       continue;
128a3b52a9dSWang, Pengfei     unsigned Index = VRM.getPhys(VirtReg) - X86::TMM0;
129a3b52a9dSWang, Pengfei     if (!Phys2Virt[Index])
130a3b52a9dSWang, Pengfei       Phys2Virt[Index] = VirtReg;
131f80b2987SLuo, Yuanke   }
132f80b2987SLuo, Yuanke 
133f80b2987SLuo, Yuanke   // Fill in the shape of each tile physical register.
134a3b52a9dSWang, Pengfei   for (unsigned I = 0; I < AMXRegNum; ++I) {
135a3b52a9dSWang, Pengfei     if (!Phys2Virt[I])
136f80b2987SLuo, Yuanke       continue;
137a3b52a9dSWang, Pengfei     DebugLoc DL;
138a3b52a9dSWang, Pengfei     bool IsRow = true;
139a3b52a9dSWang, Pengfei     MachineInstr *NewMI = nullptr;
140a3b52a9dSWang, Pengfei     ShapeT Shape = VRM.getShape(Phys2Virt[I]);
141a3b52a9dSWang, Pengfei     for (auto &R : {Shape.getRow()->getReg(), Shape.getCol()->getReg()}) {
142f80b2987SLuo, Yuanke       // Here is the data format for the tile config.
143f80b2987SLuo, Yuanke       // 0      palette
144f80b2987SLuo, Yuanke       // 1      start_row
145f80b2987SLuo, Yuanke       // 2-15   reserved, must be zero
146f80b2987SLuo, Yuanke       // 16-17  tile0.colsb Tile 0 bytes per row.
147f80b2987SLuo, Yuanke       // 18-19  tile1.colsb Tile 1 bytes per row.
148f80b2987SLuo, Yuanke       // 20-21  tile2.colsb Tile 2 bytes per row.
149f80b2987SLuo, Yuanke       // ... (sequence continues)
150f80b2987SLuo, Yuanke       // 30-31  tile7.colsb Tile 7 bytes per row.
151f80b2987SLuo, Yuanke       // 32-47  reserved, must be zero
152f80b2987SLuo, Yuanke       // 48     tile0.rows Tile 0 rows.
153f80b2987SLuo, Yuanke       // 49     tile1.rows Tile 1 rows.
154f80b2987SLuo, Yuanke       // 50     tile2.rows Tile 2 rows.
155f80b2987SLuo, Yuanke       // ... (sequence continues)
156f80b2987SLuo, Yuanke       // 55     tile7.rows Tile 7 rows.
157f80b2987SLuo, Yuanke       // 56-63  reserved, must be zero
15853673fd1SWang, Pengfei       int64_t Imm = INT64_MAX;
159a3b52a9dSWang, Pengfei       int Offset = IsRow ? 48 + I : 16 + I * 2;
160a3b52a9dSWang, Pengfei       for (auto &DefMI : MRI.def_instructions(R)) {
161a3b52a9dSWang, Pengfei         MachineBasicBlock &MBB = *DefMI.getParent();
162a3b52a9dSWang, Pengfei         if (DefMI.isMoveImmediate()) {
16353673fd1SWang, Pengfei           if (Imm != INT64_MAX) {
164a3b52a9dSWang, Pengfei             // FIXME: We should handle this case in future.
16553673fd1SWang, Pengfei             assert(Imm == DefMI.getOperand(1).getImm() &&
16653673fd1SWang, Pengfei                    "Cannot initialize with different shapes");
16753673fd1SWang, Pengfei             continue;
16853673fd1SWang, Pengfei           }
16953673fd1SWang, Pengfei           Imm = DefMI.getOperand(1).getImm();
170a3b52a9dSWang, Pengfei           NewMI = addFrameReference(
171a3b52a9dSWang, Pengfei                       BuildMI(MF.front(), ++ConstMI->getIterator(), DL,
172a3b52a9dSWang, Pengfei                               TII->get(IsRow ? X86::MOV8mi : X86::MOV16mi)),
173a3b52a9dSWang, Pengfei                       SS, Offset)
17453673fd1SWang, Pengfei                       .addImm(Imm);
175a3b52a9dSWang, Pengfei           ConstMI = NewMI;
176a3b52a9dSWang, Pengfei           LIS.InsertMachineInstrInMaps(*NewMI);
177a3b52a9dSWang, Pengfei         } else {
178a3b52a9dSWang, Pengfei           unsigned SubIdx = IsRow ? X86::sub_8bit : X86::sub_16bit;
179a3b52a9dSWang, Pengfei           unsigned RegSize = TRI->getRegSizeInBits(*MRI.getRegClass(R));
180a3b52a9dSWang, Pengfei           if ((IsRow && RegSize == 8) || (!IsRow && RegSize == 16))
181a3b52a9dSWang, Pengfei             SubIdx = 0;
182a3b52a9dSWang, Pengfei           auto Iter = DefMI.getIterator();
183a3b52a9dSWang, Pengfei           if (&MBB == &MF.front() &&
184e2815398SLuke Benes               (unsigned)std::distance(MBB.instr_begin(), Iter) < ConstPos)
185a3b52a9dSWang, Pengfei             Iter = ConstMI->getIterator();
186a3b52a9dSWang, Pengfei           NewMI = addFrameReference(
187a3b52a9dSWang, Pengfei                       BuildMI(MBB, ++Iter, DL,
188a3b52a9dSWang, Pengfei                               TII->get(IsRow ? X86::MOV8mr : X86::MOV16mr)),
189a3b52a9dSWang, Pengfei                       SS, Offset)
190a3b52a9dSWang, Pengfei                       .addReg(R, 0, SubIdx);
191a3b52a9dSWang, Pengfei           SlotIndex SIdx = LIS.InsertMachineInstrInMaps(*NewMI);
192a3b52a9dSWang, Pengfei           LIS.extendToIndices(LIS.getInterval(R), {SIdx.getRegSlot()});
193f80b2987SLuo, Yuanke         }
194f80b2987SLuo, Yuanke       }
195a3b52a9dSWang, Pengfei       IsRow = false;
196f80b2987SLuo, Yuanke     }
197f80b2987SLuo, Yuanke   }
198f80b2987SLuo, Yuanke   return true;
199f80b2987SLuo, Yuanke }
200f80b2987SLuo, Yuanke 
createX86TileConfigPass()201f80b2987SLuo, Yuanke FunctionPass *llvm::createX86TileConfigPass() { return new X86TileConfig(); }
202