1 //===-------------- RISCVSExtWRemoval.cpp - MI sext.w Removal -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===---------------------------------------------------------------------===// 8 // 9 // This pass removes unneeded sext.w instructions at the MI level. 10 // 11 //===---------------------------------------------------------------------===// 12 13 #include "RISCV.h" 14 #include "RISCVSubtarget.h" 15 #include "llvm/ADT/Statistic.h" 16 #include "llvm/CodeGen/MachineFunctionPass.h" 17 #include "llvm/CodeGen/TargetInstrInfo.h" 18 19 using namespace llvm; 20 21 #define DEBUG_TYPE "riscv-sextw-removal" 22 23 STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions"); 24 25 static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal", 26 cl::desc("Disable removal of sext.w"), 27 cl::init(false), cl::Hidden); 28 namespace { 29 30 class RISCVSExtWRemoval : public MachineFunctionPass { 31 public: 32 static char ID; 33 34 RISCVSExtWRemoval() : MachineFunctionPass(ID) { 35 initializeRISCVSExtWRemovalPass(*PassRegistry::getPassRegistry()); 36 } 37 38 bool runOnMachineFunction(MachineFunction &MF) override; 39 40 void getAnalysisUsage(AnalysisUsage &AU) const override { 41 AU.setPreservesCFG(); 42 MachineFunctionPass::getAnalysisUsage(AU); 43 } 44 45 StringRef getPassName() const override { return "RISCV sext.w Removal"; } 46 }; 47 48 } // end anonymous namespace 49 50 char RISCVSExtWRemoval::ID = 0; 51 INITIALIZE_PASS(RISCVSExtWRemoval, DEBUG_TYPE, "RISCV sext.w Removal", false, 52 false) 53 54 FunctionPass *llvm::createRISCVSExtWRemovalPass() { 55 return new RISCVSExtWRemoval(); 56 } 57 58 // This function returns true if the machine instruction always outputs a value 59 // where bits 63:32 match bit 31. 60 // TODO: Allocate a bit in TSFlags for the W instructions? 61 // TODO: Add other W instructions. 62 static bool isSignExtendingOpW(const MachineInstr &MI) { 63 switch (MI.getOpcode()) { 64 case RISCV::LUI: 65 case RISCV::LW: 66 case RISCV::ADDW: 67 case RISCV::ADDIW: 68 case RISCV::SUBW: 69 case RISCV::MULW: 70 case RISCV::SLLW: 71 case RISCV::SLLIW: 72 case RISCV::SRAW: 73 case RISCV::SRAIW: 74 case RISCV::SRLW: 75 case RISCV::SRLIW: 76 case RISCV::DIVW: 77 case RISCV::DIVUW: 78 case RISCV::REMW: 79 case RISCV::REMUW: 80 case RISCV::ROLW: 81 case RISCV::RORW: 82 case RISCV::RORIW: 83 case RISCV::CLZW: 84 case RISCV::CTZW: 85 case RISCV::CPOPW: 86 case RISCV::FCVT_W_H: 87 case RISCV::FCVT_WU_H: 88 case RISCV::FCVT_W_S: 89 case RISCV::FCVT_WU_S: 90 case RISCV::FCVT_W_D: 91 case RISCV::FCVT_WU_D: 92 case RISCV::FMV_X_W: 93 // The following aren't W instructions, but are either sign extended from a 94 // smaller size or put zeros in bits 63:31. 95 case RISCV::LBU: 96 case RISCV::LHU: 97 case RISCV::LB: 98 case RISCV::LH: 99 case RISCV::SLT: 100 case RISCV::SLTI: 101 case RISCV::SLTU: 102 case RISCV::SLTIU: 103 case RISCV::SEXT_B: 104 case RISCV::SEXT_H: 105 case RISCV::ZEXT_H_RV64: 106 case RISCV::FMV_X_H: 107 return true; 108 // shifting right sufficiently makes the value 32-bit sign-extended 109 case RISCV::SRAI: 110 return MI.getOperand(2).getImm() >= 32; 111 case RISCV::SRLI: 112 return MI.getOperand(2).getImm() > 32; 113 // The LI pattern ADDI rd, X0, imm is sign extended. 114 case RISCV::ADDI: 115 return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0; 116 // An ANDI with an 11 bit immediate will zero bits 63:11. 117 case RISCV::ANDI: 118 return isUInt<11>(MI.getOperand(2).getImm()); 119 // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11. 120 case RISCV::ORI: 121 return !isUInt<11>(MI.getOperand(2).getImm()); 122 // Copying from X0 produces zero. 123 case RISCV::COPY: 124 return MI.getOperand(1).getReg() == RISCV::X0; 125 } 126 127 return false; 128 } 129 130 static bool isSignExtendedW(const MachineInstr &OrigMI, 131 MachineRegisterInfo &MRI) { 132 133 SmallPtrSet<const MachineInstr *, 4> Visited; 134 SmallVector<const MachineInstr *, 4> Worklist; 135 136 Worklist.push_back(&OrigMI); 137 138 while (!Worklist.empty()) { 139 const MachineInstr *MI = Worklist.pop_back_val(); 140 141 // If we already visited this instruction, we don't need to check it again. 142 if (!Visited.insert(MI).second) 143 continue; 144 145 // If this is a sign extending operation we don't need to look any further. 146 if (isSignExtendingOpW(*MI)) 147 continue; 148 149 // Is this an instruction that propagates sign extend. 150 switch (MI->getOpcode()) { 151 default: 152 // Unknown opcode, give up. 153 return false; 154 case RISCV::COPY: { 155 Register SrcReg = MI->getOperand(1).getReg(); 156 157 // TODO: Handle arguments and returns from calls? 158 159 // If this is a copy from another register, check its source instruction. 160 if (!SrcReg.isVirtual()) 161 return false; 162 const MachineInstr *SrcMI = MRI.getVRegDef(SrcReg); 163 if (!SrcMI) 164 return false; 165 166 // Add SrcMI to the worklist. 167 Worklist.push_back(SrcMI); 168 break; 169 } 170 case RISCV::REM: 171 case RISCV::ANDI: 172 case RISCV::ORI: 173 case RISCV::XORI: { 174 // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R. 175 // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1 176 // Logical operations use a sign extended 12-bit immediate. We just need 177 // to check if the other operand is sign extended. 178 Register SrcReg = MI->getOperand(1).getReg(); 179 if (!SrcReg.isVirtual()) 180 return false; 181 const MachineInstr *SrcMI = MRI.getVRegDef(SrcReg); 182 if (!SrcMI) 183 return false; 184 185 // Add SrcMI to the worklist. 186 Worklist.push_back(SrcMI); 187 break; 188 } 189 case RISCV::REMU: 190 case RISCV::AND: 191 case RISCV::OR: 192 case RISCV::XOR: 193 case RISCV::ANDN: 194 case RISCV::ORN: 195 case RISCV::XNOR: 196 case RISCV::MAX: 197 case RISCV::MAXU: 198 case RISCV::MIN: 199 case RISCV::MINU: 200 case RISCV::PHI: { 201 // If all incoming values are sign-extended, the output of AND, OR, XOR, 202 // MIN, MAX, or PHI is also sign-extended. 203 204 // The input registers for PHI are operand 1, 3, ... 205 // The input registers for others are operand 1 and 2. 206 unsigned E = 3, D = 1; 207 if (MI->getOpcode() == RISCV::PHI) { 208 E = MI->getNumOperands(); 209 D = 2; 210 } 211 212 for (unsigned I = 1; I != E; I += D) { 213 if (!MI->getOperand(I).isReg()) 214 return false; 215 216 Register SrcReg = MI->getOperand(I).getReg(); 217 if (!SrcReg.isVirtual()) 218 return false; 219 const MachineInstr *SrcMI = MRI.getVRegDef(SrcReg); 220 if (!SrcMI) 221 return false; 222 223 // Add SrcMI to the worklist. 224 Worklist.push_back(SrcMI); 225 } 226 227 break; 228 } 229 } 230 } 231 232 // If we get here, then every node we visited produces a sign extended value 233 // or propagated sign extended values. So the result must be sign extended. 234 return true; 235 } 236 237 bool RISCVSExtWRemoval::runOnMachineFunction(MachineFunction &MF) { 238 if (skipFunction(MF.getFunction()) || DisableSExtWRemoval) 239 return false; 240 241 MachineRegisterInfo &MRI = MF.getRegInfo(); 242 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); 243 244 if (!ST.is64Bit()) 245 return false; 246 247 bool MadeChange = false; 248 for (MachineBasicBlock &MBB : MF) { 249 for (auto I = MBB.begin(), IE = MBB.end(); I != IE;) { 250 MachineInstr *MI = &*I++; 251 252 // We're looking for the sext.w pattern ADDIW rd, rs1, 0. 253 if (MI->getOpcode() != RISCV::ADDIW || !MI->getOperand(2).isImm() || 254 MI->getOperand(2).getImm() != 0 || !MI->getOperand(1).isReg()) 255 continue; 256 257 // Input should be a virtual register. 258 Register SrcReg = MI->getOperand(1).getReg(); 259 if (!SrcReg.isVirtual()) 260 continue; 261 262 const MachineInstr &SrcMI = *MRI.getVRegDef(SrcReg); 263 if (!isSignExtendedW(SrcMI, MRI)) 264 continue; 265 266 Register DstReg = MI->getOperand(0).getReg(); 267 if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg))) 268 continue; 269 270 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n"); 271 MRI.replaceRegWith(DstReg, SrcReg); 272 MRI.clearKillFlags(SrcReg); 273 MI->eraseFromParent(); 274 ++NumRemovedSExtW; 275 MadeChange = true; 276 } 277 } 278 279 return MadeChange; 280 } 281