1 //===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to pre-configure the shape of AMX registers.
/// AMX registers need to be configured before use. The shape of an AMX
/// register is encoded in the 1st and 2nd machine operands of AMX pseudo
/// instructions.
/// The pldtilecfg instruction configures the tile registers. It should
/// dominate all AMX instructions. The pldtilecfg produces a virtual cfg
/// register, and the cfg register is used by all AMX instructions.
/// This pass finds the common dominator of all AMX instructions and inserts
/// the pldtilecfg instruction there. In addition, the cfg register that
/// pldtilecfg produces is inserted as the last operand of each AMX
/// instruction. We use this scheme to model the def-use relationship between
/// the AMX config instruction and the other AMX instructions. Below is an
/// example.
20 ///
21 ///                        ----B1----
22 ///                       /           \
23 ///                      /             \
24 ///                    B2               B3
25 ///    %1:tile = PTILELOADDV        %2:tile = PTILELOADDV
26 ///
27 ///  is transformed to
28 ///
29 ///                            B1
30 ///                 %25:tilecfg = PLDTILECFG
31 ///                       /           \
32 ///                      /             \
33 ///  %1:tile = PTILELOADDV %25    %2:tile = PTILELOADDV %25
34 //
35 //===----------------------------------------------------------------------===//
36 
37 #include "X86.h"
38 #include "X86InstrBuilder.h"
39 #include "X86RegisterInfo.h"
40 #include "X86Subtarget.h"
41 #include "llvm/CodeGen/MachineDominators.h"
42 #include "llvm/CodeGen/MachineFunctionPass.h"
43 #include "llvm/CodeGen/MachineInstr.h"
44 #include "llvm/CodeGen/MachineRegisterInfo.h"
45 #include "llvm/CodeGen/Passes.h"
46 #include "llvm/CodeGen/TargetInstrInfo.h"
47 #include "llvm/CodeGen/TargetRegisterInfo.h"
48 #include "llvm/CodeGen/TileShapeInfo.h"
49 #include "llvm/InitializePasses.h"
50 
51 using namespace llvm;
52 
53 #define DEBUG_TYPE "tile-pre-config"
54 
55 namespace {
56 
57 class X86PreTileConfig : public MachineFunctionPass {
58   // context
59   MachineFunction *MF = nullptr;
60   const X86Subtarget *ST = nullptr;
61   const TargetRegisterInfo *TRI;
62   const TargetInstrInfo *TII;
63   MachineDominatorTree *DomTree = nullptr;
64   MachineRegisterInfo *MRI = nullptr;
65 
66   MachineInstr *getTileConfigPoint();
67 
68 public:
69   X86PreTileConfig() : MachineFunctionPass(ID) {}
70 
71   /// Return the pass name.
72   StringRef getPassName() const override {
73     return "Tile Register Pre-configure";
74   }
75 
76   /// X86PreTileConfig analysis usage.
77   void getAnalysisUsage(AnalysisUsage &AU) const override;
78 
79   /// Perform register allocation.
80   bool runOnMachineFunction(MachineFunction &mf) override;
81 
82   static char ID;
83 };
84 
85 } // end anonymous namespace
86 
// Pass identity; its address (not its value) is what identifies the pass.
char X86PreTileConfig::ID = 0;

// Register the pass under the "tilepreconfig" name and declare its
// dependency on the machine dominator tree analysis.
INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)
94 
95 void X86PreTileConfig::getAnalysisUsage(AnalysisUsage &AU) const {
96   AU.setPreservesAll();
97   AU.addRequired<MachineDominatorTree>();
98   MachineFunctionPass::getAnalysisUsage(AU);
99 }
100 
101 static void buildConfigMI(MachineBasicBlock::iterator MI, int FrameIdx,
102                           const TargetInstrInfo *TII, MachineRegisterInfo *MRI,
103                           const X86Subtarget *ST) {
104   auto *MBB = MI->getParent();
105 
106   // Zero stack slot.
107   if (ST->hasAVX512()) {
108     Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
109     BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VPXORDZrr), Zmm)
110         .addReg(Zmm, RegState::Undef)
111         .addReg(Zmm, RegState::Undef);
112     addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSZmr)),
113                       FrameIdx)
114         .addReg(Zmm);
115   } else if (ST->hasAVX2()) {
116     Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
117     BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VPXORYrr), Ymm)
118         .addReg(Ymm, RegState::Undef)
119         .addReg(Ymm, RegState::Undef);
120     addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSYmr)),
121                       FrameIdx)
122         .addReg(Ymm);
123     addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSYmr)),
124                       FrameIdx, 32)
125         .addReg(Ymm);
126   } else {
127     assert(ST->hasSSE2() && "AMX should assume SSE2 enabled");
128     Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
129     BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::PXORrr), Xmm)
130         .addReg(Xmm, RegState::Undef)
131         .addReg(Xmm, RegState::Undef);
132     addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
133                       FrameIdx)
134         .addReg(Xmm);
135     addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
136                       FrameIdx, 16)
137         .addReg(Xmm);
138     addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
139                       FrameIdx, 32)
140         .addReg(Xmm);
141     addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
142                       FrameIdx, 48)
143         .addReg(Xmm);
144   }
145 
146   // build psuedo ldtilecfg
147   addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::LDTILECFG)),
148                     FrameIdx);
149 }
150 
151 static ShapeT getShape(const MachineInstr &MI, MachineRegisterInfo *MRI) {
152   unsigned Opcode = MI.getOpcode();
153   switch (Opcode) {
154   default:
155     llvm_unreachable("Unexpected machine instruction on tile");
156   case X86::PTILELOADDV:
157   case X86::PTDPBSSDV:
158   case X86::PTDPBSUDV:
159   case X86::PTDPBUSDV:
160   case X86::PTDPBUUDV:
161   case X86::PTILEZEROV:
162   case X86::PTDPBF16PSV:
163     MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1));
164     MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2));
165     ShapeT Shape(&MO1, &MO2, MRI);
166     return Shape;
167   }
168 }
169 
170 MachineInstr *X86PreTileConfig::getTileConfigPoint() {
171   DenseMap<Register, ShapeT> PhysShapeInfo;
172   MachineBasicBlock *MBB = nullptr;
173   DenseSet<const MachineInstr *> MIs;
174   for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) {
175     Register VirtReg = Register::index2VirtReg(i);
176     if (MRI->reg_nodbg_empty(VirtReg))
177       continue;
178     const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
179     if (RC.getID() != X86::TILERegClassID)
180       continue;
181 
182     // Find the common dominator for all MI that define tile register.
183     for (const MachineOperand &MO : MRI->def_operands(VirtReg)) {
184       if (MO.isUndef())
185         continue;
186       const auto *MI = MO.getParent();
187       // PHI or IMPLICIT_DEF instructiion.
188       // There must be a input tile before PHI instruction.
189       if (MI->isTransient())
190         continue;
191       if (!MBB)
192         MBB = const_cast<MachineBasicBlock *>(MI->getParent());
193       MBB = DomTree->findNearestCommonDominator(
194           MBB, const_cast<MachineBasicBlock *>(MI->getParent()));
195 
196       // Collect the instructions that define shape.
197       ShapeT Shape = getShape(*MI, MRI);
198       std::array<MachineOperand *, 2> ShapeMOs = {Shape.getRow(),
199                                                   Shape.getCol()};
200       for (auto *ShapeMO : ShapeMOs) {
201         Register ShapeReg = ShapeMO->getReg();
202         for (const MachineOperand &MO : MRI->def_operands(ShapeReg)) {
203           const auto *ShapeMI = MO.getParent();
204           MIs.insert(ShapeMI);
205         }
206       }
207     }
208   }
209   if (!MBB)
210     return nullptr;
211   // This pass is before the pass of eliminating PHI node, so it
212   // is in SSA form.
213   assert(MRI->isSSA() && "Not SSA form in pre-tile config");
214   // Shape def should dominate tile config MBB.
215   //    def s           s1    s2
216   //     / \             \   /
217   //    /   \             \ /
218   //  conf               s3=phi(s1,s2)
219   //                       |
220   //                       c
221   //
222   for (const auto *MI : MIs) {
223     const MachineBasicBlock *ShapeMBB = MI->getParent();
224     if (DomTree->dominates(ShapeMBB, MBB))
225       continue;
226     if (MI->isMoveImmediate())
227       continue;
228     report_fatal_error(MF->getName() + ": Failed to config tile register, "
229                                        "please define the shape earlier");
230   }
231 
232   // ldtilecfg should be inserted after the MI that define the shape.
233   MachineBasicBlock::reverse_instr_iterator I, E;
234   for (I = MBB->instr_rbegin(), E = MBB->instr_rend(); I != E; ++I) {
235     auto *MI = &*I;
236     if (MIs.count(MI) && (!MI->isMoveImmediate()))
237       break;
238   }
239   MachineBasicBlock::iterator MII;
240   if (I == E)
241     MII = MBB->getFirstNonPHI();
242   else {
243     MII = MachineBasicBlock::iterator(&*I);
244     MII++;
245   }
246   return &*MII;
247 }
248 
249 static bool isAMXInstruction(MachineBasicBlock::iterator MII) {
250   switch (MII->getOpcode()) {
251   default:
252     return false;
253   case X86::PTILELOADDV:
254   case X86::PTILESTOREDV:
255   case X86::PTDPBSSDV:
256   case X86::PTDPBSUDV:
257   case X86::PTDPBUSDV:
258   case X86::PTDPBUUDV:
259   case X86::PTILEZEROV:
260   case X86::PTDPBF16PSV:
261     return true;
262   }
263 }
264 
265 struct BBInfo {
266   bool HasAMX = false;
267   bool HasCallBeforeAMX = false;
268   bool HasAMXBeforeCallInSuccs = false;
269   MachineInstr *LastCall = nullptr;
270 
271   BBInfo() = default;
272   BBInfo(SmallSet<MachineInstr *, 8> &CfgNeedInsert, MachineBasicBlock *MBB,
273          MachineInstr *MI = nullptr) {
274     MachineBasicBlock::iterator MII = MI ? MI->getIterator() : MBB->begin();
275     for (auto E = MBB->end(); MII != E; ++MII) {
276       if (isAMXInstruction(MII)) {
277         HasAMX = true;
278         if (LastCall)
279           CfgNeedInsert.insert(LastCall);
280       } else if (MII->isCall()) {
281         LastCall = &*MII;
282         if (!HasAMX)
283           HasCallBeforeAMX = true;
284       }
285     }
286   }
287 };
288 
289 static void reloadTileConfig(MachineInstr *MI, int FI,
290                              const TargetInstrInfo *TII,
291                              const TargetRegisterInfo *TRI) {
292   SmallSet<MachineInstr *, 8> CfgNeedInsert;
293   SmallVector<MachineBasicBlock *, 8> WorkList;
294   DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
295 
296   MachineBasicBlock *MBB = MI->getParent();
297   BBVisitedInfo[MBB] = BBInfo(CfgNeedInsert, MBB, MI);
298 
299   WorkList.push_back(MBB);
300   while (!WorkList.empty()) {
301     MBB = WorkList.pop_back_val();
302     for (auto I = MBB->succ_begin(), E = MBB->succ_end(); I != E; ++I) {
303       if (!BBVisitedInfo.count(*I)) {
304         BBVisitedInfo[*I] = BBInfo(CfgNeedInsert, *I);
305         WorkList.push_back(*I);
306       }
307     }
308   }
309 
310   WorkList.clear();
311   for (auto I : BBVisitedInfo) {
312     WorkList.push_back(I.first);
313     while (!WorkList.empty()) {
314       MBB = WorkList.pop_back_val();
315       if (BBVisitedInfo[MBB].HasCallBeforeAMX ||
316           (!BBVisitedInfo[MBB].HasAMX &&
317            !BBVisitedInfo[MBB].HasAMXBeforeCallInSuccs))
318         continue;
319       for (auto I = MBB->pred_begin(), E = MBB->pred_end(); I != E; ++I) {
320         if (!BBVisitedInfo.count(*I) ||
321             BBVisitedInfo[*I].HasAMXBeforeCallInSuccs)
322           continue;
323         if (BBVisitedInfo[*I].LastCall)
324           CfgNeedInsert.insert(BBVisitedInfo[*I].LastCall);
325         BBVisitedInfo[*I].HasAMXBeforeCallInSuccs = true;
326         WorkList.push_back(*I);
327       }
328     }
329   }
330 
331   for (auto *I : CfgNeedInsert) {
332     BitVector UsableRegs(TRI->getNumRegs());
333     const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
334     for (unsigned J = 0; J < RC->getNumRegs(); J++)
335       UsableRegs.set(X86::TMM0 + J);
336     for (MachineOperand &CallMO : I->operands()) {
337       if (CallMO.isRegMask())
338         UsableRegs.clearBitsInMask(CallMO.getRegMask());
339     }
340     if (!UsableRegs.none())
341       addFrameReference(BuildMI(*I->getParent(), ++I->getIterator(), DebugLoc(),
342                                 TII->get(X86::LDTILECFG)),
343                         FI);
344   }
345 }
346 
347 bool X86PreTileConfig::runOnMachineFunction(MachineFunction &mf) {
348   MF = &mf;
349   MRI = &mf.getRegInfo();
350   ST = &mf.getSubtarget<X86Subtarget>();
351   TRI = ST->getRegisterInfo();
352   TII = mf.getSubtarget().getInstrInfo();
353   DomTree = &getAnalysis<MachineDominatorTree>();
354 
355   MachineInstr *MI = getTileConfigPoint();
356   if (!MI)
357     return false;
358   unsigned Size = ST->getTileConfigSize();
359   Align Alignment = ST->getTileConfigAlignment();
360   int SS = mf.getFrameInfo().CreateStackObject(Size, Alignment, false);
361   buildConfigMI(MI, SS, TII, MRI, ST);
362   reloadTileConfig(MI, SS, TII, TRI);
363   return true;
364 }
365 
366 FunctionPass *llvm::createX86PreTileConfigPass() {
367   return new X86PreTileConfig();
368 }
369