1 //===-------------- RISCVSExtWRemoval.cpp - MI sext.w Removal -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 //
9 // This pass removes unneeded sext.w instructions at the MI level.
10 //
11 //===---------------------------------------------------------------------===//
12
13 #include "RISCV.h"
14 #include "RISCVSubtarget.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/CodeGen/MachineFunctionPass.h"
17 #include "llvm/CodeGen/TargetInstrInfo.h"
18
19 using namespace llvm;
20
21 #define DEBUG_TYPE "riscv-sextw-removal"
22
23 STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
24 STATISTIC(NumTransformedToWInstrs,
25 "Number of instructions transformed to W-ops");
26
27 static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
28 cl::desc("Disable removal of sext.w"),
29 cl::init(false), cl::Hidden);
30 namespace {
31
32 class RISCVSExtWRemoval : public MachineFunctionPass {
33 public:
34 static char ID;
35
RISCVSExtWRemoval()36 RISCVSExtWRemoval() : MachineFunctionPass(ID) {
37 initializeRISCVSExtWRemovalPass(*PassRegistry::getPassRegistry());
38 }
39
40 bool runOnMachineFunction(MachineFunction &MF) override;
41
getAnalysisUsage(AnalysisUsage & AU) const42 void getAnalysisUsage(AnalysisUsage &AU) const override {
43 AU.setPreservesCFG();
44 MachineFunctionPass::getAnalysisUsage(AU);
45 }
46
getPassName() const47 StringRef getPassName() const override { return "RISCV sext.w Removal"; }
48 };
49
50 } // end anonymous namespace
51
52 char RISCVSExtWRemoval::ID = 0;
53 INITIALIZE_PASS(RISCVSExtWRemoval, DEBUG_TYPE, "RISCV sext.w Removal", false,
54 false)
55
createRISCVSExtWRemovalPass()56 FunctionPass *llvm::createRISCVSExtWRemovalPass() {
57 return new RISCVSExtWRemoval();
58 }
59
60 // add uses of MI to the Worklist
addUses(const MachineInstr & MI,SmallVectorImpl<const MachineInstr * > & Worklist,MachineRegisterInfo & MRI)61 static void addUses(const MachineInstr &MI,
62 SmallVectorImpl<const MachineInstr *> &Worklist,
63 MachineRegisterInfo &MRI) {
64 for (auto &UserOp : MRI.reg_operands(MI.getOperand(0).getReg())) {
65 const auto *User = UserOp.getParent();
66 if (User == &MI) // ignore the def, current MI
67 continue;
68 Worklist.push_back(User);
69 }
70 }
71
72 // returns true if all uses of OrigMI only depend on the lower word of its
73 // output, so we can transform OrigMI to the corresponding W-version.
74 // TODO: handle multiple interdependent transformations
isAllUsesReadW(const MachineInstr & OrigMI,MachineRegisterInfo & MRI)75 static bool isAllUsesReadW(const MachineInstr &OrigMI,
76 MachineRegisterInfo &MRI) {
77
78 SmallPtrSet<const MachineInstr *, 4> Visited;
79 SmallVector<const MachineInstr *, 4> Worklist;
80
81 Visited.insert(&OrigMI);
82 addUses(OrigMI, Worklist, MRI);
83
84 while (!Worklist.empty()) {
85 const MachineInstr *MI = Worklist.pop_back_val();
86
87 if (!Visited.insert(MI).second) {
88 // If we've looped back to OrigMI through a PHI cycle, we can't transform
89 // LD or LWU, because these operations use all 64 bits of input.
90 if (MI == &OrigMI) {
91 unsigned opcode = MI->getOpcode();
92 if (opcode == RISCV::LD || opcode == RISCV::LWU)
93 return false;
94 }
95 continue;
96 }
97
98 switch (MI->getOpcode()) {
99 case RISCV::ADDIW:
100 case RISCV::ADDW:
101 case RISCV::DIVUW:
102 case RISCV::DIVW:
103 case RISCV::MULW:
104 case RISCV::REMUW:
105 case RISCV::REMW:
106 case RISCV::SLLIW:
107 case RISCV::SLLW:
108 case RISCV::SRAIW:
109 case RISCV::SRAW:
110 case RISCV::SRLIW:
111 case RISCV::SRLW:
112 case RISCV::SUBW:
113 case RISCV::ROLW:
114 case RISCV::RORW:
115 case RISCV::RORIW:
116 case RISCV::CLZW:
117 case RISCV::CTZW:
118 case RISCV::CPOPW:
119 case RISCV::SLLI_UW:
120 case RISCV::FCVT_S_W:
121 case RISCV::FCVT_S_WU:
122 case RISCV::FCVT_D_W:
123 case RISCV::FCVT_D_WU:
124 continue;
125
126 // these overwrite higher input bits, otherwise the lower word of output
127 // depends only on the lower word of input. So check their uses read W.
128 case RISCV::SLLI:
129 if (MI->getOperand(2).getImm() >= 32)
130 continue;
131 addUses(*MI, Worklist, MRI);
132 continue;
133 case RISCV::ANDI:
134 if (isUInt<11>(MI->getOperand(2).getImm()))
135 continue;
136 addUses(*MI, Worklist, MRI);
137 continue;
138 case RISCV::ORI:
139 if (!isUInt<11>(MI->getOperand(2).getImm()))
140 continue;
141 addUses(*MI, Worklist, MRI);
142 continue;
143
144 case RISCV::BEXTI:
145 if (MI->getOperand(2).getImm() >= 32)
146 return false;
147 continue;
148
149 // For these, lower word of output in these operations, depends only on
150 // the lower word of input. So, we check all uses only read lower word.
151 case RISCV::COPY:
152 case RISCV::PHI:
153
154 case RISCV::ADD:
155 case RISCV::ADDI:
156 case RISCV::AND:
157 case RISCV::MUL:
158 case RISCV::OR:
159 case RISCV::SLL:
160 case RISCV::SUB:
161 case RISCV::XOR:
162 case RISCV::XORI:
163
164 case RISCV::ADD_UW:
165 case RISCV::ANDN:
166 case RISCV::CLMUL:
167 case RISCV::ORC_B:
168 case RISCV::ORN:
169 case RISCV::SEXT_B:
170 case RISCV::SEXT_H:
171 case RISCV::SH1ADD:
172 case RISCV::SH1ADD_UW:
173 case RISCV::SH2ADD:
174 case RISCV::SH2ADD_UW:
175 case RISCV::SH3ADD:
176 case RISCV::SH3ADD_UW:
177 case RISCV::XNOR:
178 case RISCV::ZEXT_H_RV64:
179 addUses(*MI, Worklist, MRI);
180 continue;
181 default:
182 return false;
183 }
184 }
185 return true;
186 }
187
188 // This function returns true if the machine instruction always outputs a value
189 // where bits 63:32 match bit 31.
190 // Alternatively, if the instruction can be converted to W variant
191 // (e.g. ADD->ADDW) and all of its uses only use the lower word of its output,
192 // then return true and add the instr to FixableDef to be convereted later
193 // TODO: Allocate a bit in TSFlags for the W instructions?
194 // TODO: Add other W instructions.
isSignExtendingOpW(MachineInstr & MI,MachineRegisterInfo & MRI,SmallPtrSetImpl<MachineInstr * > & FixableDef)195 static bool isSignExtendingOpW(MachineInstr &MI, MachineRegisterInfo &MRI,
196 SmallPtrSetImpl<MachineInstr *> &FixableDef) {
197 switch (MI.getOpcode()) {
198 case RISCV::LUI:
199 case RISCV::LW:
200 case RISCV::ADDW:
201 case RISCV::ADDIW:
202 case RISCV::SUBW:
203 case RISCV::MULW:
204 case RISCV::SLLW:
205 case RISCV::SLLIW:
206 case RISCV::SRAW:
207 case RISCV::SRAIW:
208 case RISCV::SRLW:
209 case RISCV::SRLIW:
210 case RISCV::DIVW:
211 case RISCV::DIVUW:
212 case RISCV::REMW:
213 case RISCV::REMUW:
214 case RISCV::ROLW:
215 case RISCV::RORW:
216 case RISCV::RORIW:
217 case RISCV::CLZW:
218 case RISCV::CTZW:
219 case RISCV::CPOPW:
220 case RISCV::FCVT_W_H:
221 case RISCV::FCVT_WU_H:
222 case RISCV::FCVT_W_S:
223 case RISCV::FCVT_WU_S:
224 case RISCV::FCVT_W_D:
225 case RISCV::FCVT_WU_D:
226 case RISCV::FMV_X_W:
227 // The following aren't W instructions, but are either sign extended from a
228 // smaller size, always outputs a small integer, or put zeros in bits 63:31.
229 case RISCV::LBU:
230 case RISCV::LHU:
231 case RISCV::LB:
232 case RISCV::LH:
233 case RISCV::SLT:
234 case RISCV::SLTI:
235 case RISCV::SLTU:
236 case RISCV::SLTIU:
237 case RISCV::SEXT_B:
238 case RISCV::SEXT_H:
239 case RISCV::ZEXT_H_RV64:
240 case RISCV::FMV_X_H:
241 case RISCV::BEXT:
242 case RISCV::BEXTI:
243 case RISCV::CLZ:
244 case RISCV::CPOP:
245 case RISCV::CTZ:
246 return true;
247 // shifting right sufficiently makes the value 32-bit sign-extended
248 case RISCV::SRAI:
249 return MI.getOperand(2).getImm() >= 32;
250 case RISCV::SRLI:
251 return MI.getOperand(2).getImm() > 32;
252 // The LI pattern ADDI rd, X0, imm is sign extended.
253 case RISCV::ADDI:
254 if (MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0)
255 return true;
256 if (isAllUsesReadW(MI, MRI)) {
257 // transform to ADDIW
258 FixableDef.insert(&MI);
259 return true;
260 }
261 return false;
262 // An ANDI with an 11 bit immediate will zero bits 63:11.
263 case RISCV::ANDI:
264 return isUInt<11>(MI.getOperand(2).getImm());
265 // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11.
266 case RISCV::ORI:
267 return !isUInt<11>(MI.getOperand(2).getImm());
268 // Copying from X0 produces zero.
269 case RISCV::COPY:
270 return MI.getOperand(1).getReg() == RISCV::X0;
271
272 // With these opcode, we can "fix" them with the W-version
273 // if we know all users of the result only rely on bits 31:0
274 case RISCV::SLLI:
275 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
276 if (MI.getOperand(2).getImm() >= 32)
277 return false;
278 LLVM_FALLTHROUGH;
279 case RISCV::ADD:
280 case RISCV::LD:
281 case RISCV::LWU:
282 case RISCV::MUL:
283 case RISCV::SUB:
284 if (isAllUsesReadW(MI, MRI)) {
285 FixableDef.insert(&MI);
286 return true;
287 }
288 }
289
290 return false;
291 }
292
isSignExtendedW(MachineInstr & OrigMI,MachineRegisterInfo & MRI,SmallPtrSetImpl<MachineInstr * > & FixableDef)293 static bool isSignExtendedW(MachineInstr &OrigMI, MachineRegisterInfo &MRI,
294 SmallPtrSetImpl<MachineInstr *> &FixableDef) {
295
296 SmallPtrSet<const MachineInstr *, 4> Visited;
297 SmallVector<MachineInstr *, 4> Worklist;
298
299 Worklist.push_back(&OrigMI);
300
301 while (!Worklist.empty()) {
302 MachineInstr *MI = Worklist.pop_back_val();
303
304 // If we already visited this instruction, we don't need to check it again.
305 if (!Visited.insert(MI).second)
306 continue;
307
308 // If this is a sign extending operation we don't need to look any further.
309 if (isSignExtendingOpW(*MI, MRI, FixableDef))
310 continue;
311
312 // Is this an instruction that propagates sign extend.
313 switch (MI->getOpcode()) {
314 default:
315 // Unknown opcode, give up.
316 return false;
317 case RISCV::COPY: {
318 Register SrcReg = MI->getOperand(1).getReg();
319
320 // TODO: Handle arguments and returns from calls?
321
322 // If this is a copy from another register, check its source instruction.
323 if (!SrcReg.isVirtual())
324 return false;
325 MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
326 if (!SrcMI)
327 return false;
328
329 // Add SrcMI to the worklist.
330 Worklist.push_back(SrcMI);
331 break;
332 }
333
334 // For these, we just need to check if the 1st operand is sign extended.
335 case RISCV::BCLRI:
336 case RISCV::BINVI:
337 case RISCV::BSETI:
338 if (MI->getOperand(2).getImm() >= 31)
339 return false;
340 LLVM_FALLTHROUGH;
341 case RISCV::REM:
342 case RISCV::ANDI:
343 case RISCV::ORI:
344 case RISCV::XORI: {
345 // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R.
346 // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1
347 // Logical operations use a sign extended 12-bit immediate.
348 Register SrcReg = MI->getOperand(1).getReg();
349 if (!SrcReg.isVirtual())
350 return false;
351 MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
352 if (!SrcMI)
353 return false;
354
355 // Add SrcMI to the worklist.
356 Worklist.push_back(SrcMI);
357 break;
358 }
359 case RISCV::REMU:
360 case RISCV::AND:
361 case RISCV::OR:
362 case RISCV::XOR:
363 case RISCV::ANDN:
364 case RISCV::ORN:
365 case RISCV::XNOR:
366 case RISCV::MAX:
367 case RISCV::MAXU:
368 case RISCV::MIN:
369 case RISCV::MINU:
370 case RISCV::PHI: {
371 // If all incoming values are sign-extended, the output of AND, OR, XOR,
372 // MIN, MAX, or PHI is also sign-extended.
373
374 // The input registers for PHI are operand 1, 3, ...
375 // The input registers for others are operand 1 and 2.
376 unsigned E = 3, D = 1;
377 if (MI->getOpcode() == RISCV::PHI) {
378 E = MI->getNumOperands();
379 D = 2;
380 }
381
382 for (unsigned I = 1; I != E; I += D) {
383 if (!MI->getOperand(I).isReg())
384 return false;
385
386 Register SrcReg = MI->getOperand(I).getReg();
387 if (!SrcReg.isVirtual())
388 return false;
389 MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
390 if (!SrcMI)
391 return false;
392
393 // Add SrcMI to the worklist.
394 Worklist.push_back(SrcMI);
395 }
396
397 break;
398 }
399 }
400 }
401
402 // If we get here, then every node we visited produces a sign extended value
403 // or propagated sign extended values. So the result must be sign extended.
404 return true;
405 }
406
getWOp(unsigned Opcode)407 static unsigned getWOp(unsigned Opcode) {
408 switch (Opcode) {
409 case RISCV::ADDI:
410 return RISCV::ADDIW;
411 case RISCV::ADD:
412 return RISCV::ADDW;
413 case RISCV::LD:
414 case RISCV::LWU:
415 return RISCV::LW;
416 case RISCV::MUL:
417 return RISCV::MULW;
418 case RISCV::SLLI:
419 return RISCV::SLLIW;
420 case RISCV::SUB:
421 return RISCV::SUBW;
422 default:
423 llvm_unreachable("Unexpected opcode for replacement with W variant");
424 }
425 }
426
runOnMachineFunction(MachineFunction & MF)427 bool RISCVSExtWRemoval::runOnMachineFunction(MachineFunction &MF) {
428 if (skipFunction(MF.getFunction()) || DisableSExtWRemoval)
429 return false;
430
431 MachineRegisterInfo &MRI = MF.getRegInfo();
432 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
433
434 if (!ST.is64Bit())
435 return false;
436
437 SmallPtrSet<MachineInstr *, 4> SExtWRemovalCands;
438
439 // Replacing instructions invalidates the MI iterator
440 // we collect the candidates, then iterate over them separately.
441 for (MachineBasicBlock &MBB : MF) {
442 for (auto I = MBB.begin(), IE = MBB.end(); I != IE;) {
443 MachineInstr *MI = &*I++;
444
445 // We're looking for the sext.w pattern ADDIW rd, rs1, 0.
446 if (!RISCV::isSEXT_W(*MI))
447 continue;
448
449 // Input should be a virtual register.
450 Register SrcReg = MI->getOperand(1).getReg();
451 if (!SrcReg.isVirtual())
452 continue;
453
454 SExtWRemovalCands.insert(MI);
455 }
456 }
457
458 bool MadeChange = false;
459 for (auto MI : SExtWRemovalCands) {
460 SmallPtrSet<MachineInstr *, 4> FixableDef;
461 Register SrcReg = MI->getOperand(1).getReg();
462 MachineInstr &SrcMI = *MRI.getVRegDef(SrcReg);
463
464 // If all definitions reaching MI sign-extend their output,
465 // then sext.w is redundant
466 if (!isSignExtendedW(SrcMI, MRI, FixableDef))
467 continue;
468
469 Register DstReg = MI->getOperand(0).getReg();
470 if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg)))
471 continue;
472 // Replace Fixable instructions with their W versions.
473 for (MachineInstr *Fixable : FixableDef) {
474 MachineBasicBlock &MBB = *Fixable->getParent();
475 const DebugLoc &DL = Fixable->getDebugLoc();
476 unsigned Code = getWOp(Fixable->getOpcode());
477 MachineInstrBuilder Replacement =
478 BuildMI(MBB, Fixable, DL, ST.getInstrInfo()->get(Code));
479 for (auto Op : Fixable->operands())
480 Replacement.add(Op);
481 for (auto Op : Fixable->memoperands())
482 Replacement.addMemOperand(Op);
483
484 LLVM_DEBUG(dbgs() << "Replacing " << *Fixable);
485 LLVM_DEBUG(dbgs() << " with " << *Replacement);
486
487 Fixable->eraseFromParent();
488 ++NumTransformedToWInstrs;
489 }
490
491 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
492 MRI.replaceRegWith(DstReg, SrcReg);
493 MRI.clearKillFlags(SrcReg);
494 MI->eraseFromParent();
495 ++NumRemovedSExtW;
496 MadeChange = true;
497 }
498
499 return MadeChange;
500 }
501