//===-- X86PreTileConfig.cpp - Tile Register Configure---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-config the shape of AMX registers
/// AMX registers need to be configured before use. The shape of an AMX
/// register is encoded in the 1st and 2nd machine operands of the AMX pseudo
/// instructions. The pldtilecfg instruction configures the tile registers and
/// should dominate all AMX instructions. It produces a virtual cfg register
/// that is used by all AMX instructions.
/// This pass finds the common dominator of all AMX instructions and inserts
/// the pldtilecfg instruction there. Besides, the cfg register that
/// pldtilecfg produces is appended as the last operand of each AMX
/// instruction. We use this scheme to model the def-use relationship between
/// the AMX config instruction and the other AMX instructions. Below is an
/// example.
///
///                         ----B1----
///                        /          \
///                       /            \
///                      B2             B3
///     %1:tile = PTILELOADDV      %2:tile = PTILELOADDV
///
/// is transformed to
///
///                             B1
///                  %25:tilecfg = PLDTILECFG
///                        /          \
///                       /            \
///   %1:tile = PTILELOADDV %25    %2:tile = PTILELOADDV %25
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TileShapeInfo.h"
#include "llvm/InitializePasses.h"

using namespace llvm;

#define DEBUG_TYPE "tile-pre-config"

namespace {

class X86PreTileConfig : public MachineFunctionPass {
  // Context.
  MachineFunction *MF = nullptr;
  const X86Subtarget *ST = nullptr;
  const TargetRegisterInfo *TRI = nullptr;
  const TargetInstrInfo *TII = nullptr;
  MachineDominatorTree *DomTree = nullptr;
  MachineRegisterInfo *MRI = nullptr;

  MachineInstr *getTileConfigPoint();

public:
  X86PreTileConfig() : MachineFunctionPass(ID) {}

  /// Return the pass name.
  StringRef getPassName() const override {
    return "Tile Register Pre-configure";
  }

  /// X86PreTileConfig analysis usage.
  void getAnalysisUsage(AnalysisUsage &AU) const override;

  /// Perform the tile pre-configuration.
  bool runOnMachineFunction(MachineFunction &mf) override;

  static char ID;
};

} // end anonymous namespace

char X86PreTileConfig::ID = 0;

INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Configure", false, false)

void X86PreTileConfig::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<MachineDominatorTree>();
  MachineFunctionPass::getAnalysisUsage(AU);
}

static void buildConfigMI(MachineBasicBlock::iterator MI, int FrameIdx,
                          const TargetInstrInfo *TII, MachineRegisterInfo *MRI,
                          const X86Subtarget *ST) {
  auto *MBB = MI->getParent();

  // Zero stack slot.
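  // The tile config area is ST->getTileConfigSize() (64) bytes on the stack.
  // Clear it with the widest vector store available: one 64-byte zmm store
  // with AVX-512, two 32-byte ymm stores with AVX2, or four 16-byte xmm
  // stores with SSE2, so every field of the config data starts out as zero.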
  if (ST->hasAVX512()) {
    Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
    BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VPXORDZrr), Zmm)
        .addReg(Zmm, RegState::Undef)
        .addReg(Zmm, RegState::Undef);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSZmr)),
                      FrameIdx)
        .addReg(Zmm);
  } else if (ST->hasAVX2()) {
    Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
    BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VPXORYrr), Ymm)
        .addReg(Ymm, RegState::Undef)
        .addReg(Ymm, RegState::Undef);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSYmr)),
                      FrameIdx)
        .addReg(Ymm);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSYmr)),
                      FrameIdx, 32)
        .addReg(Ymm);
  } else {
    assert(ST->hasSSE2() && "AMX should assume SSE2 enabled");
    Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
    BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::PXORrr), Xmm)
        .addReg(Xmm, RegState::Undef)
        .addReg(Xmm, RegState::Undef);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
                      FrameIdx)
        .addReg(Xmm);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
                      FrameIdx, 16)
        .addReg(Xmm);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
                      FrameIdx, 32)
        .addReg(Xmm);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
                      FrameIdx, 48)
        .addReg(Xmm);
  }

  // Build pseudo ldtilecfg.
  addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::LDTILECFG)),
                    FrameIdx);
}

static ShapeT getShape(const MachineInstr &MI, MachineRegisterInfo *MRI) {
  unsigned Opcode = MI.getOpcode();
  switch (Opcode) {
  default:
    llvm_unreachable("Unexpected machine instruction on tile");
  case X86::PTILELOADDV:
  case X86::PTDPBSSDV:
  case X86::PTILEZEROV:
    MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1));
    MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2));
    ShapeT Shape(&MO1, &MO2, MRI);
    return Shape;
  }
}

MachineInstr *X86PreTileConfig::getTileConfigPoint() {
  DenseMap<Register, ShapeT> PhysShapeInfo;
  MachineBasicBlock *MBB = nullptr;
  DenseSet<const MachineInstr *> MIs;
  for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) {
    Register VirtReg = Register::index2VirtReg(i);
    if (MRI->reg_nodbg_empty(VirtReg))
      continue;
    const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
    if (RC.getID() != X86::TILERegClassID)
      continue;

    // Find the common dominator of all MIs that define tile registers.
    for (const MachineOperand &MO : MRI->def_operands(VirtReg)) {
      if (MO.isUndef())
        continue;
      const auto *MI = MO.getParent();
      // Skip PHI or IMPLICIT_DEF instructions.
      // There must be an input tile before a PHI instruction.
      if (MI->isTransient())
        continue;
      if (!MBB)
        MBB = const_cast<MachineBasicBlock *>(MI->getParent());
      MBB = DomTree->findNearestCommonDominator(
          MBB, const_cast<MachineBasicBlock *>(MI->getParent()));

      // Collect the instructions that define the shape.
      ShapeT Shape = getShape(*MI, MRI);
      std::array<MachineOperand *, 2> ShapeMOs = {Shape.getRow(),
                                                  Shape.getCol()};
      for (auto *ShapeMO : ShapeMOs) {
        Register ShapeReg = ShapeMO->getReg();
        for (const MachineOperand &MO : MRI->def_operands(ShapeReg)) {
          const auto *ShapeMI = MO.getParent();
          MIs.insert(ShapeMI);
        }
      }
    }
  }
  if (!MBB)
    return nullptr;
  // This pass runs before PHI elimination, so the function is still in SSA
  // form.
  assert(MRI->isSSA() && "Not SSA form in pre-tile config");
  // Shape defs should dominate the tile config MBB.
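  // Shape defs that are move-immediates are exempt from this check; any other
  // shape def that does not dominate the chosen config block is a fatal
  // error.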
  //    def s           s1    s2
  //     / \             \   /
  //    /   \             \ /
  //  conf             s3=phi(s1,s2)
  //                       |
  //                       c
  //
  for (const auto *MI : MIs) {
    const MachineBasicBlock *ShapeMBB = MI->getParent();
    if (DomTree->dominates(ShapeMBB, MBB))
      continue;
    if (MI->isMoveImmediate())
      continue;
    report_fatal_error(MF->getName() + ": Failed to config tile register, "
                                       "please define the shape earlier");
  }

  // ldtilecfg should be inserted after the MI that defines the shape.
  MachineBasicBlock::reverse_instr_iterator I, E;
  for (I = MBB->instr_rbegin(), E = MBB->instr_rend(); I != E; ++I) {
    auto *MI = &*I;
    if (MIs.count(MI) && (!MI->isMoveImmediate()))
      break;
  }
  MachineBasicBlock::iterator MII;
  if (I == E)
    MII = MBB->getFirstNonPHI();
  else {
    MII = MachineBasicBlock::iterator(&*I);
    MII++;
  }
  return &*MII;
}

static bool isAMXInstruction(MachineBasicBlock::iterator MII) {
  switch (MII->getOpcode()) {
  default:
    return false;
  case X86::PTILELOADDV:
  case X86::PTILESTOREDV:
  case X86::PTDPBSSDV:
  case X86::PTILEZEROV:
    return true;
  }
}

struct BBInfo {
  bool HasAMX = false;
  bool HasCallBeforeAMX = false;
  bool HasAMXBeforeCallInSuccs = false;
  MachineInstr *LastCall = nullptr;

  BBInfo() = default;
  BBInfo(SmallSet<MachineInstr *, 8> &CfgNeedInsert, MachineBasicBlock *MBB,
         MachineInstr *MI = nullptr) {
    MachineBasicBlock::iterator MII = MI ? MI->getIterator() : MBB->begin();
    for (auto E = MBB->end(); MII != E; ++MII) {
      if (isAMXInstruction(MII)) {
        HasAMX = true;
        // A call before this AMX instruction may clobber the tile config, so
        // the config needs to be reloaded after that call.
        if (LastCall)
          CfgNeedInsert.insert(LastCall);
      } else if (MII->isCall()) {
        LastCall = &*MII;
        if (!HasAMX)
          HasCallBeforeAMX = true;
      }
    }
  }
};

// Reload the tile config after each call that may clobber it and that is
// followed, in its block or in a successor, by an AMX instruction.
static void reloadTileConfig(MachineInstr *MI, int FI,
                             const TargetInstrInfo *TII,
                             const TargetRegisterInfo *TRI) {
  SmallSet<MachineInstr *, 8> CfgNeedInsert;
  SmallVector<MachineBasicBlock *, 8> WorkList;
  DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;

  MachineBasicBlock *MBB = MI->getParent();
  BBVisitedInfo[MBB] = BBInfo(CfgNeedInsert, MBB, MI);

  WorkList.push_back(MBB);
  while (!WorkList.empty()) {
    MBB = WorkList.pop_back_val();
    for (auto I = MBB->succ_begin(), E = MBB->succ_end(); I != E; ++I) {
      if (!BBVisitedInfo.count(*I)) {
        BBVisitedInfo[*I] = BBInfo(CfgNeedInsert, *I);
        WorkList.push_back(*I);
      }
    }
  }

  WorkList.clear();
  for (auto I : BBVisitedInfo) {
    WorkList.push_back(I.first);
    while (!WorkList.empty()) {
      MBB = WorkList.pop_back_val();
      if (BBVisitedInfo[MBB].HasCallBeforeAMX ||
          (!BBVisitedInfo[MBB].HasAMX &&
           !BBVisitedInfo[MBB].HasAMXBeforeCallInSuccs))
        continue;
      for (auto I = MBB->pred_begin(), E = MBB->pred_end(); I != E; ++I) {
        if (!BBVisitedInfo.count(*I) ||
            BBVisitedInfo[*I].HasAMXBeforeCallInSuccs)
          continue;
        if (BBVisitedInfo[*I].LastCall)
          CfgNeedInsert.insert(BBVisitedInfo[*I].LastCall);
        BBVisitedInfo[*I].HasAMXBeforeCallInSuccs = true;
        WorkList.push_back(*I);
      }
    }
  }

  for (auto *I : CfgNeedInsert) {
    BitVector UsableRegs(TRI->getNumRegs());
    const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
    for (unsigned J = 0; J < RC->getNumRegs(); J++)
      UsableRegs.set(X86::TMM0 + J);
    for (MachineOperand &CallMO : I->operands()) {
      if (CallMO.isRegMask())
        UsableRegs.clearBitsInMask(CallMO.getRegMask());
    }
    // Insert the reload only if the call clobbers at least one tile register.
    if (!UsableRegs.none())
      addFrameReference(BuildMI(*I->getParent(), ++I->getIterator(),
                                DebugLoc(), TII->get(X86::LDTILECFG)),
                        FI);
  }
}

bool X86PreTileConfig::runOnMachineFunction(MachineFunction &mf) {
  MF = &mf;
  MRI = &mf.getRegInfo();
  ST = &mf.getSubtarget<X86Subtarget>();
  TRI = ST->getRegisterInfo();
  TII = mf.getSubtarget().getInstrInfo();
  DomTree = &getAnalysis<MachineDominatorTree>();

  MachineInstr *MI = getTileConfigPoint();
  if (!MI)
    return false;
  unsigned Size = ST->getTileConfigSize();
  Align Alignment = ST->getTileConfigAlignment();
  int SS = mf.getFrameInfo().CreateStackObject(Size, Alignment, false);
  buildConfigMI(MI, SS, TII, MRI, ST);
  reloadTileConfig(MI, SS, TII, TRI);
  return true;
}

FunctionPass *llvm::createX86PreTileConfigPass() {
  return new X86PreTileConfig();
}