//===-- X86PreTileConfig.cpp - Tile Register Pre-configure ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to pre-configure the shapes of AMX registers.
/// AMX registers need to be configured before use. The shapes of an AMX
/// register are encoded in the 1st and 2nd machine operands of the AMX pseudo
/// instructions.
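///
/// For example (illustrative MIR; memory operands and other details elided),
/// %row and %col below are the 1st and 2nd shape operands:
///   %row:gr16 = COPY %r
///   %col:gr16 = COPY %c
///   %t:tile = PTILELOADDV %row, %col, ...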
12 ///
/// The ldtilecfg instruction is used to configure the shapes. It must be
/// inserted at a point that all variable shape defs reach. ldtilecfg will be
/// inserted more than once if we cannot find such a dominating point for all
/// AMX instructions. For example, if the two shape operands of an AMX
/// instruction are defined separately in both predecessors of its block, an
/// ldtilecfg is inserted after each shape def rather than at their common
/// dominator.
16 ///
/// The tile config register is caller-saved according to the ABI. We need to
/// insert ldtilecfg again after a call instruction if the callee clobbers any
/// AMX registers.
20 ///
/// This pass calculates all the points where ldtilecfg needs to be inserted
/// and inserts it there. It reports a fatal error if the reachability
/// conditions aren't met.
23 //
24 //===----------------------------------------------------------------------===//
25 
26 #include "X86.h"
27 #include "X86InstrBuilder.h"
28 #include "X86RegisterInfo.h"
29 #include "X86Subtarget.h"
30 #include "llvm/CodeGen/MachineFunctionPass.h"
31 #include "llvm/CodeGen/MachineInstr.h"
32 #include "llvm/CodeGen/MachineLoopInfo.h"
33 #include "llvm/CodeGen/MachineRegisterInfo.h"
34 #include "llvm/CodeGen/Passes.h"
35 #include "llvm/CodeGen/TargetInstrInfo.h"
36 #include "llvm/CodeGen/TargetRegisterInfo.h"
37 #include "llvm/InitializePasses.h"
38 
39 using namespace llvm;
40 
41 #define DEBUG_TYPE "tile-pre-config"
42 #define ASSERT_VALID_COMPARE                                                   \
43   assert((!MBB || !RHS.MBB || MBB == RHS.MBB) &&                               \
44          "Cannot compare between different BBs");
45 #define REPORT_CONFIG_FAIL                                                     \
46   report_fatal_error(                                                          \
47       MF.getName() +                                                           \
48       ": Failed to config tile register, please define the shape earlier");
49 
50 namespace {
51 
52 struct MIRef {
53   MachineInstr *MI = nullptr;
54   MachineBasicBlock *MBB = nullptr;
  // A virtual position for the instruction that will be inserted after MI.
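  // Pos is the 1-based index of MI within its MBB (0 when MI is null), so
  // MIRefs within the same MBB are ordered by program position.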
56   size_t Pos = 0;
57   MIRef() = default;
58   MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
59     for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
60          ++I, ++Pos)
61       MI = &*I;
62   }
63   MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
64       : MI(MI), MBB(MBB),
65         Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
66   MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
67       : MI(MI), MBB(MBB), Pos(Pos) {}
68   operator bool() const { return MBB != nullptr; }
69   bool operator==(const MIRef &RHS) const {
70     return MI == RHS.MI && MBB == RHS.MBB;
71   }
72   bool operator<(const MIRef &RHS) const {
73     ASSERT_VALID_COMPARE;
74     return Pos < RHS.Pos;
75   }
76   bool operator>(const MIRef &RHS) const {
77     ASSERT_VALID_COMPARE;
78     return Pos > RHS.Pos;
79   }
80 };
81 
struct BBInfo {
  MIRef FirstAMX;                 // First AMX instruction in the MBB.
  MIRef LastCall;                 // Last AMX-clobbering call in the MBB.
  MIRef LastShape;                // Last shape def in the MBB.
  bool TileCfgForbidden = false;  // A shape def is reachable from this MBB.
  bool NeedTileCfgLiveIn = false; // Tile config must be live into this MBB.
};
89 
90 class X86PreTileConfig : public MachineFunctionPass {
91   MachineRegisterInfo *MRI;
92   const MachineLoopInfo *MLI;
93   SmallSet<MachineInstr *, 8> DefVisited;
94   SmallSet<MachineBasicBlock *, 8> ShapeBBs;
95   DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
96 
97   /// Check if the callee will clobber AMX registers.
98   bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
99     auto Iter = llvm::find_if(
100         MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
101     if (Iter == MI.operands_end())
102       return false;
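    // A register mask operand has a bit set for each register preserved
    // across the call; clearing those bits leaves the AMX registers the
    // callee clobbers. UsableRegs is passed by value so the caller's set is
    // not modified.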
103     UsableRegs.clearBitsInMask(Iter->getRegMask());
104     return !UsableRegs.none();
105   }
106 
  /// Check if MI is an AMX pseudo instruction.
108   bool isAMXInstruction(MachineInstr &MI) {
109     if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
110       return false;
111     MachineOperand &MO = MI.getOperand(0);
    // We can simply check whether it is an AMX instruction by its def, but we
    // should exclude the old API which uses physical registers.
114     if (MO.isReg() && MO.getReg().isVirtual() &&
115         MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) {
116       collectShapeInfo(MI);
117       return true;
118     }
    // PTILESTOREDV is the only exception that doesn't def an AMX register.
120     return MI.getOpcode() == X86::PTILESTOREDV;
121   }
122 
  /// Check if it is an edge from the loop bottom to the loop header.
124   bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
125     return MLI->isLoopHeader(Header) &&
126            MLI->getLoopFor(Header)->getBottomBlock() == Bottom;
127   }
128 
129   /// Collect the shape def information for later use.
130   void collectShapeInfo(MachineInstr &MI);
131 
132 public:
133   X86PreTileConfig() : MachineFunctionPass(ID) {}
134 
135   /// Return the pass name.
136   StringRef getPassName() const override {
137     return "Tile Register Pre-configure";
138   }
139 
140   /// X86PreTileConfig analysis usage.
141   void getAnalysisUsage(AnalysisUsage &AU) const override {
142     AU.setPreservesAll();
143     AU.addRequired<MachineLoopInfo>();
144     MachineFunctionPass::getAnalysisUsage(AU);
145   }
146 
147   /// Clear MF related structures.
148   void releaseMemory() override {
149     ShapeBBs.clear();
150     DefVisited.clear();
151     BBVisitedInfo.clear();
152   }
153 
  /// Insert the ldtilecfg instructions.
155   bool runOnMachineFunction(MachineFunction &MF) override;
156 
157   static char ID;
158 };
159 
160 } // end anonymous namespace
161 
162 char X86PreTileConfig::ID = 0;
163 
164 INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
165                       "Tile Register Pre-configure", false, false)
166 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
167 INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
168                     "Tile Register Pre-configure", false, false)
169 
170 void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
171   auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
172     MIRef MIR(MI, MBB);
173     if (BBVisitedInfo[MBB].LastShape < MIR)
174       BBVisitedInfo[MBB].LastShape = MIR;
175     ShapeBBs.insert(MBB);
176   };
177 
178   SmallVector<Register, 8> WorkList(
179       {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
180   while (!WorkList.empty()) {
181     Register R = WorkList.pop_back_val();
    MachineInstr *DefMI = MRI->getVRegDef(R);
    // Skip shape defs that are plain immediate moves or have already been
    // visited. Check DefMI for null before dereferencing it.
    if (!DefMI || DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)
      continue;
    MachineBasicBlock *DefMBB = DefMI->getParent();
186     if (DefMI->isPHI()) {
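      // Walk the PHI's incoming values. A value that flows in along a loop
      // back edge may change every iteration, so the PHI itself must be
      // treated as the shape def; other incoming values are traced further.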
187       for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
188         if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB()))
189           RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
190         else
191           WorkList.push_back(DefMI->getOperand(I).getReg());
192     } else {
193       RecordShape(DefMI, DefMBB);
194     }
195   }
196 }
197 
198 bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
199   const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
200   const TargetInstrInfo *TII = ST.getInstrInfo();
201   const TargetRegisterInfo *TRI = ST.getRegisterInfo();
202   const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
203 
204   BitVector AMXRegs(TRI->getNumRegs());
205   for (unsigned I = 0; I < RC->getNumRegs(); I++)
206     AMXRegs.set(X86::TMM0 + I);
207 
208   // Iterate MF to collect information.
209   MRI = &MF.getRegInfo();
210   MLI = &getAnalysis<MachineLoopInfo>();
211   SmallSet<MIRef, 8> CfgNeedInsert;
212   SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
213   for (auto &MBB : MF) {
214     size_t Pos = 0;
215     for (auto &MI : MBB) {
216       ++Pos;
217       if (isAMXInstruction(MI)) {
        // If there's a call before the AMX instruction, we need to reload the
        // tile config after that call.
        if (BBVisitedInfo[&MBB].LastCall)
          CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
        else // Otherwise, we need the tile config to be live in to this MBB.
          BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
        // Always record the first AMX instruction in case there's a shape def
        // after it.
        if (!BBVisitedInfo[&MBB].FirstAMX)
          BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
      } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
        // Record the call only if the callee clobbers any AMX register.
        BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
229       }
230     }
231     if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
232       if (&MBB == &MF.front())
233         CfgNeedInsert.insert(MIRef(&MBB));
234       else
235         CfgLiveInBBs.push_back(&MBB);
236     }
237   }
238 
239   // Update NeedTileCfgLiveIn for predecessors.
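  // This is a backward propagation: if a predecessor contains an
  // AMX-clobbering call, the config is simply reloaded after that call;
  // otherwise the predecessor itself needs the config live in.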
240   while (!CfgLiveInBBs.empty()) {
241     MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
242     for (auto *Pred : MBB->predecessors()) {
243       if (BBVisitedInfo[Pred].LastCall) {
244         CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
245       } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
246         BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
247         if (Pred == &MF.front())
248           CfgNeedInsert.insert(MIRef(Pred));
249         else
250           CfgLiveInBBs.push_back(Pred);
251       }
252     }
253   }
254 
  // There's no AMX instruction if we didn't find any tile config live-in point.
256   if (CfgNeedInsert.empty())
257     return false;
258 
  // Avoid inserting ldtilecfg before any shape def: walk backward from the
  // blocks containing shape defs and mark their transitive predecessors as
  // forbidden for insertion.
260   SmallVector<MachineBasicBlock *, 8> WorkList(
261       make_range(ShapeBBs.begin(), ShapeBBs.end()));
262   while (!WorkList.empty()) {
263     MachineBasicBlock *MBB = WorkList.pop_back_val();
264     for (auto *Pred : MBB->predecessors()) {
265       if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) {
266         BBVisitedInfo[Pred].TileCfgForbidden = true;
267         WorkList.push_back(Pred);
268       }
269     }
270   }
271 
272   DebugLoc DL;
273   SmallSet<MIRef, 8> VisitedOrInserted;
274   int SS = MF.getFrameInfo().CreateStackObject(
275       ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);
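  // SS is the stack slot holding the in-memory tile configuration that every
  // ldtilecfg inserted below will load from.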
276 
  // Try to insert ldtilecfg for each tile config live-in point.
278   for (auto I : CfgNeedInsert) {
279     SmallSet<MIRef, 8> InsertPoints;
280     SmallVector<MIRef, 8> WorkList({I});
281     while (!WorkList.empty()) {
282       MIRef I = WorkList.pop_back_val();
283       if (!VisitedOrInserted.count(I)) {
284         if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
          // All shape defs reach this MBB, so stop sinking and record the
          // insert point.
          InsertPoints.insert(I);
        } else {
          // Avoid visiting the MBB more than once.
          VisitedOrInserted.insert(I);
          // We cannot sink it across any AMX instruction.
          if (BBVisitedInfo[I.MBB].FirstAMX)
            REPORT_CONFIG_FAIL;
          // Not all shape defs reach this MBB, so keep sinking the insert
          // point along the successors that need the tile config live in.
295           for (auto *Succ : I.MBB->successors())
296             if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
297               WorkList.push_back(MIRef(Succ));
298         }
299       }
300     }
301 
    // A given point might be forked into several insert points when the shape
    // reachability conditions are not met.
    for (MIRef I : InsertPoints) {
      // Even if all shape defs reach this MBB, we still need to check whether
      // an AMX instruction appears before the last shape def in the same MBB.
      if (BBVisitedInfo[I.MBB].FirstAMX &&
          BBVisitedInfo[I.MBB].FirstAMX < BBVisitedInfo[I.MBB].LastShape)
        REPORT_CONFIG_FAIL;
      // Make sure we insert ldtilecfg after the last shape def in the MBB.
      if (I < BBVisitedInfo[I.MBB].LastShape)
        I = BBVisitedInfo[I.MBB].LastShape;
      // The MBB may be reached more than once while sinking. Record the point
      // to avoid inserting twice.
314       if (VisitedOrInserted.insert(I).second) {
        auto II = I.MI ? ++I.MI->getIterator() : I.MBB->instr_begin();
        addFrameReference(BuildMI(*I.MBB, II, DL, TII->get(X86::LDTILECFG)),
                          SS);
318       }
319     }
320   }
321 
322   // Zero stack slot.
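  // The reserved bytes of the config must be zero, so clear the whole area up
  // front with the widest vector stores available; the shape fields themselves
  // are expected to be filled in later, after register allocation.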
323   MachineBasicBlock &MBB = MF.front();
324   MachineInstr *MI = &*MBB.begin();
325   if (ST.hasAVX512()) {
326     Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
327     BuildMI(MBB, MI, DL, TII->get(X86::VPXORDZrr), Zmm)
328         .addReg(Zmm, RegState::Undef)
329         .addReg(Zmm, RegState::Undef);
330     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS)
331         .addReg(Zmm);
332   } else if (ST.hasAVX2()) {
333     Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
334     BuildMI(MBB, MI, DL, TII->get(X86::VPXORYrr), Ymm)
335         .addReg(Ymm, RegState::Undef)
336         .addReg(Ymm, RegState::Undef);
337     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS)
338         .addReg(Ymm);
339     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32)
340         .addReg(Ymm);
341   } else {
342     assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
343     Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
344     BuildMI(MBB, MI, DL, TII->get(X86::PXORrr), Xmm)
345         .addReg(Xmm, RegState::Undef)
346         .addReg(Xmm, RegState::Undef);
347     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS)
348         .addReg(Xmm);
349     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 16)
350         .addReg(Xmm);
351     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 32)
352         .addReg(Xmm);
353     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48)
354         .addReg(Xmm);
355   }
  // Fill in the palette first. Palette 1 is stored at byte 0 of the config.
357   addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);
358 
359   return true;
360 }
361 
362 FunctionPass *llvm::createX86PreTileConfigPass() {
363   return new X86PreTileConfig();
364 }
365