1 //===-------------- RISCVSExtWRemoval.cpp - MI sext.w Removal -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===---------------------------------------------------------------------===// 8 // 9 // This pass removes unneeded sext.w instructions at the MI level. 10 // 11 //===---------------------------------------------------------------------===// 12 13 #include "RISCV.h" 14 #include "RISCVSubtarget.h" 15 #include "llvm/ADT/Statistic.h" 16 #include "llvm/CodeGen/MachineFunctionPass.h" 17 #include "llvm/CodeGen/TargetInstrInfo.h" 18 19 using namespace llvm; 20 21 #define DEBUG_TYPE "riscv-sextw-removal" 22 23 STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions"); 24 25 static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal", 26 cl::desc("Disable removal of sext.w"), 27 cl::init(false), cl::Hidden); 28 namespace { 29 30 class RISCVSExtWRemoval : public MachineFunctionPass { 31 public: 32 static char ID; 33 34 RISCVSExtWRemoval() : MachineFunctionPass(ID) { 35 initializeRISCVSExtWRemovalPass(*PassRegistry::getPassRegistry()); 36 } 37 38 bool runOnMachineFunction(MachineFunction &MF) override; 39 40 void getAnalysisUsage(AnalysisUsage &AU) const override { 41 AU.setPreservesCFG(); 42 MachineFunctionPass::getAnalysisUsage(AU); 43 } 44 45 StringRef getPassName() const override { return "RISCV sext.w Removal"; } 46 }; 47 48 } // end anonymous namespace 49 50 char RISCVSExtWRemoval::ID = 0; 51 INITIALIZE_PASS(RISCVSExtWRemoval, DEBUG_TYPE, "RISCV sext.w Removal", false, 52 false) 53 54 FunctionPass *llvm::createRISCVSExtWRemovalPass() { 55 return new RISCVSExtWRemoval(); 56 } 57 58 // This function returns true if the machine instruction always outputs a value 59 // where bits 63:32 match bit 31. 60 // TODO: Allocate a bit in TSFlags for the W instructions? 61 // TODO: Add other W instructions. 62 static bool isSignExtendingOpW(const MachineInstr &MI) { 63 switch (MI.getOpcode()) { 64 case RISCV::LUI: 65 case RISCV::LW: 66 case RISCV::ADDW: 67 case RISCV::ADDIW: 68 case RISCV::SUBW: 69 case RISCV::MULW: 70 case RISCV::SLLW: 71 case RISCV::SLLIW: 72 case RISCV::SRAW: 73 case RISCV::SRAIW: 74 case RISCV::SRLW: 75 case RISCV::SRLIW: 76 case RISCV::DIVW: 77 case RISCV::DIVUW: 78 case RISCV::REMW: 79 case RISCV::REMUW: 80 case RISCV::ROLW: 81 case RISCV::RORW: 82 case RISCV::RORIW: 83 case RISCV::CLZW: 84 case RISCV::CTZW: 85 case RISCV::CPOPW: 86 case RISCV::FCVT_W_H: 87 case RISCV::FCVT_WU_H: 88 case RISCV::FCVT_W_S: 89 case RISCV::FCVT_WU_S: 90 case RISCV::FCVT_W_D: 91 case RISCV::FCVT_WU_D: 92 // The following aren't W instructions, but are either sign extended from a 93 // smaller size or put zeros in bits 63:31. 94 case RISCV::LBU: 95 case RISCV::LHU: 96 case RISCV::LB: 97 case RISCV::LH: 98 case RISCV::SEXTB: 99 case RISCV::SEXTH: 100 case RISCV::ZEXTH_RV64: 101 return true; 102 } 103 104 // The LI pattern ADDI rd, X0, imm is sign extended. 105 if (MI.getOpcode() == RISCV::ADDI && MI.getOperand(1).isReg() && 106 MI.getOperand(1).getReg() == RISCV::X0) 107 return true; 108 109 // An ANDI with an 11 bit immediate will zero bits 63:11. 110 if (MI.getOpcode() == RISCV::ANDI && isUInt<11>(MI.getOperand(2).getImm())) 111 return true; 112 113 // Copying from X0 produces zero. 114 if (MI.getOpcode() == RISCV::COPY && MI.getOperand(1).getReg() == RISCV::X0) 115 return true; 116 117 return false; 118 } 119 120 static bool isSignExtendedW(const MachineInstr &OrigMI, 121 MachineRegisterInfo &MRI) { 122 123 SmallPtrSet<const MachineInstr *, 4> Visited; 124 SmallVector<const MachineInstr *, 4> Worklist; 125 126 Worklist.push_back(&OrigMI); 127 128 while (!Worklist.empty()) { 129 const MachineInstr *MI = Worklist.pop_back_val(); 130 131 // If we already visited this instruction, we don't need to check it again. 132 if (!Visited.insert(MI).second) 133 continue; 134 135 // If this is a sign extending operation we don't need to look any further. 136 if (isSignExtendingOpW(*MI)) 137 continue; 138 139 // Is this an instruction that propagates sign extend. 140 switch (MI->getOpcode()) { 141 default: 142 // Unknown opcode, give up. 143 return false; 144 case RISCV::COPY: { 145 Register SrcReg = MI->getOperand(1).getReg(); 146 147 // TODO: Handle arguments and returns from calls? 148 149 // If this is a copy from another register, check its source instruction. 150 if (!SrcReg.isVirtual()) 151 return false; 152 const MachineInstr *SrcMI = MRI.getVRegDef(SrcReg); 153 if (!SrcMI) 154 return false; 155 156 // Add SrcMI to the worklist. 157 Worklist.push_back(SrcMI); 158 break; 159 } 160 case RISCV::ANDI: 161 case RISCV::ORI: 162 case RISCV::XORI: { 163 // Logical operations use a sign extended 12-bit immediate. We just need 164 // to check if the other operand is sign extended. 165 Register SrcReg = MI->getOperand(1).getReg(); 166 if (!SrcReg.isVirtual()) 167 return false; 168 const MachineInstr *SrcMI = MRI.getVRegDef(SrcReg); 169 if (!SrcMI) 170 return false; 171 172 // Add SrcMI to the worklist. 173 Worklist.push_back(SrcMI); 174 break; 175 } 176 case RISCV::AND: 177 case RISCV::OR: 178 case RISCV::XOR: 179 case RISCV::ANDN: 180 case RISCV::ORN: 181 case RISCV::XNOR: 182 case RISCV::MAX: 183 case RISCV::MAXU: 184 case RISCV::MIN: 185 case RISCV::MINU: 186 case RISCV::PHI: { 187 // If all incoming values are sign-extended, the output of AND, OR, XOR, 188 // MIN, MAX, or PHI is also sign-extended. 189 190 // The input registers for PHI are operand 1, 3, ... 191 // The input registers for others are operand 1 and 2. 192 unsigned E = 3, D = 1; 193 if (MI->getOpcode() == RISCV::PHI) { 194 E = MI->getNumOperands(); 195 D = 2; 196 } 197 198 for (unsigned I = 1; I != E; I += D) { 199 if (!MI->getOperand(I).isReg()) 200 return false; 201 202 Register SrcReg = MI->getOperand(I).getReg(); 203 if (!SrcReg.isVirtual()) 204 return false; 205 const MachineInstr *SrcMI = MRI.getVRegDef(SrcReg); 206 if (!SrcMI) 207 return false; 208 209 // Add SrcMI to the worklist. 210 Worklist.push_back(SrcMI); 211 } 212 213 break; 214 } 215 } 216 } 217 218 // If we get here, then every node we visited produces a sign extended value 219 // or propagated sign extended values. So the result must be sign extended. 220 return true; 221 } 222 223 bool RISCVSExtWRemoval::runOnMachineFunction(MachineFunction &MF) { 224 if (skipFunction(MF.getFunction()) || DisableSExtWRemoval) 225 return false; 226 227 MachineRegisterInfo &MRI = MF.getRegInfo(); 228 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); 229 230 if (!ST.is64Bit()) 231 return false; 232 233 bool MadeChange = false; 234 for (MachineBasicBlock &MBB : MF) { 235 for (auto I = MBB.begin(), IE = MBB.end(); I != IE;) { 236 MachineInstr *MI = &*I++; 237 238 // We're looking for the sext.w pattern ADDIW rd, rs1, 0. 239 if (MI->getOpcode() != RISCV::ADDIW || !MI->getOperand(2).isImm() || 240 MI->getOperand(2).getImm() != 0 || !MI->getOperand(1).isReg()) 241 continue; 242 243 // Input should be a virtual register. 244 Register SrcReg = MI->getOperand(1).getReg(); 245 if (!SrcReg.isVirtual()) 246 continue; 247 248 const MachineInstr &SrcMI = *MRI.getVRegDef(SrcReg); 249 if (!isSignExtendedW(SrcMI, MRI)) 250 continue; 251 252 Register DstReg = MI->getOperand(0).getReg(); 253 if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg))) 254 continue; 255 256 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n"); 257 MRI.replaceRegWith(DstReg, SrcReg); 258 MRI.clearKillFlags(SrcReg); 259 MI->eraseFromParent(); 260 ++NumRemovedSExtW; 261 MadeChange = true; 262 } 263 } 264 265 return MadeChange; 266 } 267