1 //===-------------- RISCVSExtWRemoval.cpp - MI sext.w Removal -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===---------------------------------------------------------------------===// 8 // 9 // This pass removes unneeded sext.w instructions at the MI level. 10 // 11 //===---------------------------------------------------------------------===// 12 13 #include "RISCV.h" 14 #include "RISCVSubtarget.h" 15 #include "llvm/ADT/Statistic.h" 16 #include "llvm/CodeGen/MachineFunctionPass.h" 17 #include "llvm/CodeGen/TargetInstrInfo.h" 18 19 using namespace llvm; 20 21 #define DEBUG_TYPE "riscv-sextw-removal" 22 23 STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions"); 24 STATISTIC(NumTransformedToWInstrs, 25 "Number of instructions transformed to W-ops"); 26 27 static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal", 28 cl::desc("Disable removal of sext.w"), 29 cl::init(false), cl::Hidden); 30 namespace { 31 32 class RISCVSExtWRemoval : public MachineFunctionPass { 33 public: 34 static char ID; 35 36 RISCVSExtWRemoval() : MachineFunctionPass(ID) { 37 initializeRISCVSExtWRemovalPass(*PassRegistry::getPassRegistry()); 38 } 39 40 bool runOnMachineFunction(MachineFunction &MF) override; 41 42 void getAnalysisUsage(AnalysisUsage &AU) const override { 43 AU.setPreservesCFG(); 44 MachineFunctionPass::getAnalysisUsage(AU); 45 } 46 47 StringRef getPassName() const override { return "RISCV sext.w Removal"; } 48 }; 49 50 } // end anonymous namespace 51 52 char RISCVSExtWRemoval::ID = 0; 53 INITIALIZE_PASS(RISCVSExtWRemoval, DEBUG_TYPE, "RISCV sext.w Removal", false, 54 false) 55 56 FunctionPass *llvm::createRISCVSExtWRemovalPass() { 57 return new RISCVSExtWRemoval(); 58 } 59 60 // add uses of MI to the Worklist 61 static void addUses(const MachineInstr &MI, 62 SmallVectorImpl<const MachineInstr *> &Worklist, 63 MachineRegisterInfo &MRI) { 64 for (auto &UserOp : MRI.reg_operands(MI.getOperand(0).getReg())) { 65 const auto *User = UserOp.getParent(); 66 if (User == &MI) // ignore the def, current MI 67 continue; 68 Worklist.push_back(User); 69 } 70 } 71 72 // returns true if all uses of OrigMI only depend on the lower word of its 73 // output, so we can transform OrigMI to the corresponding W-version. 74 // TODO: handle multiple interdependent transformations 75 static bool isAllUsesReadW(const MachineInstr &OrigMI, 76 MachineRegisterInfo &MRI) { 77 78 SmallPtrSet<const MachineInstr *, 4> Visited; 79 SmallVector<const MachineInstr *, 4> Worklist; 80 81 Visited.insert(&OrigMI); 82 addUses(OrigMI, Worklist, MRI); 83 84 while (!Worklist.empty()) { 85 const MachineInstr *MI = Worklist.pop_back_val(); 86 87 if (!Visited.insert(MI).second) { 88 // If we've looped back to OrigMI through a PHI cycle, we can't transform 89 // LD or LWU, because these operations use all 64 bits of input. 90 if (MI == &OrigMI) { 91 unsigned opcode = MI->getOpcode(); 92 if (opcode == RISCV::LD || opcode == RISCV::LWU) 93 return false; 94 } 95 continue; 96 } 97 98 switch (MI->getOpcode()) { 99 case RISCV::ADDIW: 100 case RISCV::ADDW: 101 case RISCV::DIVUW: 102 case RISCV::DIVW: 103 case RISCV::MULW: 104 case RISCV::REMUW: 105 case RISCV::REMW: 106 case RISCV::SLLIW: 107 case RISCV::SLLW: 108 case RISCV::SRAIW: 109 case RISCV::SRAW: 110 case RISCV::SRLIW: 111 case RISCV::SRLW: 112 case RISCV::SUBW: 113 case RISCV::ROLW: 114 case RISCV::RORW: 115 case RISCV::RORIW: 116 case RISCV::CLZW: 117 case RISCV::CTZW: 118 case RISCV::CPOPW: 119 case RISCV::SLLI_UW: 120 case RISCV::FCVT_S_W: 121 case RISCV::FCVT_S_WU: 122 case RISCV::FCVT_D_W: 123 case RISCV::FCVT_D_WU: 124 continue; 125 126 // these overwrite higher input bits, otherwise the lower word of output 127 // depends only on the lower word of input. So check their uses read W. 128 case RISCV::SLLI: 129 if (MI->getOperand(2).getImm() >= 32) 130 continue; 131 addUses(*MI, Worklist, MRI); 132 continue; 133 case RISCV::ANDI: 134 if (isUInt<11>(MI->getOperand(2).getImm())) 135 continue; 136 addUses(*MI, Worklist, MRI); 137 continue; 138 case RISCV::ORI: 139 if (!isUInt<11>(MI->getOperand(2).getImm())) 140 continue; 141 addUses(*MI, Worklist, MRI); 142 continue; 143 144 case RISCV::BEXTI: 145 if (MI->getOperand(2).getImm() >= 32) 146 return false; 147 continue; 148 149 // For these, lower word of output in these operations, depends only on 150 // the lower word of input. So, we check all uses only read lower word. 151 case RISCV::COPY: 152 case RISCV::PHI: 153 154 case RISCV::ADD: 155 case RISCV::ADDI: 156 case RISCV::AND: 157 case RISCV::MUL: 158 case RISCV::OR: 159 case RISCV::SLL: 160 case RISCV::SUB: 161 case RISCV::XOR: 162 case RISCV::XORI: 163 164 case RISCV::ADD_UW: 165 case RISCV::ANDN: 166 case RISCV::CLMUL: 167 case RISCV::ORC_B: 168 case RISCV::ORN: 169 case RISCV::SEXT_B: 170 case RISCV::SEXT_H: 171 case RISCV::SH1ADD: 172 case RISCV::SH1ADD_UW: 173 case RISCV::SH2ADD: 174 case RISCV::SH2ADD_UW: 175 case RISCV::SH3ADD: 176 case RISCV::SH3ADD_UW: 177 case RISCV::XNOR: 178 case RISCV::ZEXT_H_RV64: 179 addUses(*MI, Worklist, MRI); 180 continue; 181 default: 182 return false; 183 } 184 } 185 return true; 186 } 187 188 // This function returns true if the machine instruction always outputs a value 189 // where bits 63:32 match bit 31. 190 // Alternatively, if the instruction can be converted to W variant 191 // (e.g. ADD->ADDW) and all of its uses only use the lower word of its output, 192 // then return true and add the instr to FixableDef to be convereted later 193 // TODO: Allocate a bit in TSFlags for the W instructions? 194 // TODO: Add other W instructions. 195 static bool isSignExtendingOpW(MachineInstr &MI, MachineRegisterInfo &MRI, 196 SmallPtrSetImpl<MachineInstr *> &FixableDef) { 197 switch (MI.getOpcode()) { 198 case RISCV::LUI: 199 case RISCV::LW: 200 case RISCV::ADDW: 201 case RISCV::ADDIW: 202 case RISCV::SUBW: 203 case RISCV::MULW: 204 case RISCV::SLLW: 205 case RISCV::SLLIW: 206 case RISCV::SRAW: 207 case RISCV::SRAIW: 208 case RISCV::SRLW: 209 case RISCV::SRLIW: 210 case RISCV::DIVW: 211 case RISCV::DIVUW: 212 case RISCV::REMW: 213 case RISCV::REMUW: 214 case RISCV::ROLW: 215 case RISCV::RORW: 216 case RISCV::RORIW: 217 case RISCV::CLZW: 218 case RISCV::CTZW: 219 case RISCV::CPOPW: 220 case RISCV::FCVT_W_H: 221 case RISCV::FCVT_WU_H: 222 case RISCV::FCVT_W_S: 223 case RISCV::FCVT_WU_S: 224 case RISCV::FCVT_W_D: 225 case RISCV::FCVT_WU_D: 226 case RISCV::FMV_X_W: 227 // The following aren't W instructions, but are either sign extended from a 228 // smaller size, always outputs a small integer, or put zeros in bits 63:31. 229 case RISCV::LBU: 230 case RISCV::LHU: 231 case RISCV::LB: 232 case RISCV::LH: 233 case RISCV::SLT: 234 case RISCV::SLTI: 235 case RISCV::SLTU: 236 case RISCV::SLTIU: 237 case RISCV::SEXT_B: 238 case RISCV::SEXT_H: 239 case RISCV::ZEXT_H_RV64: 240 case RISCV::FMV_X_H: 241 case RISCV::BEXT: 242 case RISCV::BEXTI: 243 case RISCV::CLZ: 244 case RISCV::CPOP: 245 case RISCV::CTZ: 246 return true; 247 // shifting right sufficiently makes the value 32-bit sign-extended 248 case RISCV::SRAI: 249 return MI.getOperand(2).getImm() >= 32; 250 case RISCV::SRLI: 251 return MI.getOperand(2).getImm() > 32; 252 // The LI pattern ADDI rd, X0, imm is sign extended. 253 case RISCV::ADDI: 254 if (MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0) 255 return true; 256 if (isAllUsesReadW(MI, MRI)) { 257 // transform to ADDIW 258 FixableDef.insert(&MI); 259 return true; 260 } 261 return false; 262 // An ANDI with an 11 bit immediate will zero bits 63:11. 263 case RISCV::ANDI: 264 return isUInt<11>(MI.getOperand(2).getImm()); 265 // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11. 266 case RISCV::ORI: 267 return !isUInt<11>(MI.getOperand(2).getImm()); 268 // Copying from X0 produces zero. 269 case RISCV::COPY: 270 return MI.getOperand(1).getReg() == RISCV::X0; 271 272 // With these opcode, we can "fix" them with the W-version 273 // if we know all users of the result only rely on bits 31:0 274 case RISCV::SLLI: 275 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits 276 if (MI.getOperand(2).getImm() >= 32) 277 return false; 278 LLVM_FALLTHROUGH; 279 case RISCV::ADD: 280 case RISCV::LD: 281 case RISCV::LWU: 282 case RISCV::MUL: 283 case RISCV::SUB: 284 if (isAllUsesReadW(MI, MRI)) { 285 FixableDef.insert(&MI); 286 return true; 287 } 288 } 289 290 return false; 291 } 292 293 static bool isSignExtendedW(MachineInstr &OrigMI, MachineRegisterInfo &MRI, 294 SmallPtrSetImpl<MachineInstr *> &FixableDef) { 295 296 SmallPtrSet<const MachineInstr *, 4> Visited; 297 SmallVector<MachineInstr *, 4> Worklist; 298 299 Worklist.push_back(&OrigMI); 300 301 while (!Worklist.empty()) { 302 MachineInstr *MI = Worklist.pop_back_val(); 303 304 // If we already visited this instruction, we don't need to check it again. 305 if (!Visited.insert(MI).second) 306 continue; 307 308 // If this is a sign extending operation we don't need to look any further. 309 if (isSignExtendingOpW(*MI, MRI, FixableDef)) 310 continue; 311 312 // Is this an instruction that propagates sign extend. 313 switch (MI->getOpcode()) { 314 default: 315 // Unknown opcode, give up. 316 return false; 317 case RISCV::COPY: { 318 Register SrcReg = MI->getOperand(1).getReg(); 319 320 // TODO: Handle arguments and returns from calls? 321 322 // If this is a copy from another register, check its source instruction. 323 if (!SrcReg.isVirtual()) 324 return false; 325 MachineInstr *SrcMI = MRI.getVRegDef(SrcReg); 326 if (!SrcMI) 327 return false; 328 329 // Add SrcMI to the worklist. 330 Worklist.push_back(SrcMI); 331 break; 332 } 333 334 // For these, we just need to check if the 1st operand is sign extended. 335 case RISCV::BCLRI: 336 case RISCV::BINVI: 337 case RISCV::BSETI: 338 if (MI->getOperand(2).getImm() >= 31) 339 return false; 340 LLVM_FALLTHROUGH; 341 case RISCV::REM: 342 case RISCV::ANDI: 343 case RISCV::ORI: 344 case RISCV::XORI: { 345 // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R. 346 // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1 347 // Logical operations use a sign extended 12-bit immediate. 348 Register SrcReg = MI->getOperand(1).getReg(); 349 if (!SrcReg.isVirtual()) 350 return false; 351 MachineInstr *SrcMI = MRI.getVRegDef(SrcReg); 352 if (!SrcMI) 353 return false; 354 355 // Add SrcMI to the worklist. 356 Worklist.push_back(SrcMI); 357 break; 358 } 359 case RISCV::REMU: 360 case RISCV::AND: 361 case RISCV::OR: 362 case RISCV::XOR: 363 case RISCV::ANDN: 364 case RISCV::ORN: 365 case RISCV::XNOR: 366 case RISCV::MAX: 367 case RISCV::MAXU: 368 case RISCV::MIN: 369 case RISCV::MINU: 370 case RISCV::PHI: { 371 // If all incoming values are sign-extended, the output of AND, OR, XOR, 372 // MIN, MAX, or PHI is also sign-extended. 373 374 // The input registers for PHI are operand 1, 3, ... 375 // The input registers for others are operand 1 and 2. 376 unsigned E = 3, D = 1; 377 if (MI->getOpcode() == RISCV::PHI) { 378 E = MI->getNumOperands(); 379 D = 2; 380 } 381 382 for (unsigned I = 1; I != E; I += D) { 383 if (!MI->getOperand(I).isReg()) 384 return false; 385 386 Register SrcReg = MI->getOperand(I).getReg(); 387 if (!SrcReg.isVirtual()) 388 return false; 389 MachineInstr *SrcMI = MRI.getVRegDef(SrcReg); 390 if (!SrcMI) 391 return false; 392 393 // Add SrcMI to the worklist. 394 Worklist.push_back(SrcMI); 395 } 396 397 break; 398 } 399 } 400 } 401 402 // If we get here, then every node we visited produces a sign extended value 403 // or propagated sign extended values. So the result must be sign extended. 404 return true; 405 } 406 407 static unsigned getWOp(unsigned Opcode) { 408 switch (Opcode) { 409 case RISCV::ADDI: 410 return RISCV::ADDIW; 411 case RISCV::ADD: 412 return RISCV::ADDW; 413 case RISCV::LD: 414 case RISCV::LWU: 415 return RISCV::LW; 416 case RISCV::MUL: 417 return RISCV::MULW; 418 case RISCV::SLLI: 419 return RISCV::SLLIW; 420 case RISCV::SUB: 421 return RISCV::SUBW; 422 default: 423 llvm_unreachable("Unexpected opcode for replacement with W variant"); 424 } 425 } 426 427 bool RISCVSExtWRemoval::runOnMachineFunction(MachineFunction &MF) { 428 if (skipFunction(MF.getFunction()) || DisableSExtWRemoval) 429 return false; 430 431 MachineRegisterInfo &MRI = MF.getRegInfo(); 432 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); 433 434 if (!ST.is64Bit()) 435 return false; 436 437 SmallPtrSet<MachineInstr *, 4> SExtWRemovalCands; 438 439 // Replacing instructions invalidates the MI iterator 440 // we collect the candidates, then iterate over them separately. 441 for (MachineBasicBlock &MBB : MF) { 442 for (auto I = MBB.begin(), IE = MBB.end(); I != IE;) { 443 MachineInstr *MI = &*I++; 444 445 // We're looking for the sext.w pattern ADDIW rd, rs1, 0. 446 if (MI->getOpcode() != RISCV::ADDIW || !MI->getOperand(2).isImm() || 447 MI->getOperand(2).getImm() != 0 || !MI->getOperand(1).isReg()) 448 continue; 449 450 // Input should be a virtual register. 451 Register SrcReg = MI->getOperand(1).getReg(); 452 if (!SrcReg.isVirtual()) 453 continue; 454 455 SExtWRemovalCands.insert(MI); 456 } 457 } 458 459 bool MadeChange = false; 460 for (auto MI : SExtWRemovalCands) { 461 SmallPtrSet<MachineInstr *, 4> FixableDef; 462 Register SrcReg = MI->getOperand(1).getReg(); 463 MachineInstr &SrcMI = *MRI.getVRegDef(SrcReg); 464 465 // If all definitions reaching MI sign-extend their output, 466 // then sext.w is redundant 467 if (!isSignExtendedW(SrcMI, MRI, FixableDef)) 468 continue; 469 470 Register DstReg = MI->getOperand(0).getReg(); 471 if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg))) 472 continue; 473 // Replace Fixable instructions with their W versions. 474 for (MachineInstr *Fixable : FixableDef) { 475 MachineBasicBlock &MBB = *Fixable->getParent(); 476 const DebugLoc &DL = Fixable->getDebugLoc(); 477 unsigned Code = getWOp(Fixable->getOpcode()); 478 MachineInstrBuilder Replacement = 479 BuildMI(MBB, Fixable, DL, ST.getInstrInfo()->get(Code)); 480 for (auto Op : Fixable->operands()) 481 Replacement.add(Op); 482 for (auto Op : Fixable->memoperands()) 483 Replacement.addMemOperand(Op); 484 485 LLVM_DEBUG(dbgs() << "Replacing " << *Fixable); 486 LLVM_DEBUG(dbgs() << " with " << *Replacement); 487 488 Fixable->eraseFromParent(); 489 ++NumTransformedToWInstrs; 490 } 491 492 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n"); 493 MRI.replaceRegWith(DstReg, SrcReg); 494 MRI.clearKillFlags(SrcReg); 495 MI->eraseFromParent(); 496 ++NumRemovedSExtW; 497 MadeChange = true; 498 } 499 500 return MadeChange; 501 } 502