1 //=== WebAssemblyLateEHPrepare.cpp - WebAssembly Exception Preparation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// \brief Does various transformations for exception handling.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
15 #include "WebAssembly.h"
16 #include "WebAssemblySubtarget.h"
17 #include "WebAssemblyUtilities.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/CodeGen/MachineInstrBuilder.h"
20 #include "llvm/CodeGen/WasmEHFuncInfo.h"
21 #include "llvm/MC/MCAsmInfo.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Target/TargetMachine.h"
24 using namespace llvm;
25 
26 #define DEBUG_TYPE "wasm-late-eh-prepare"
27 
28 namespace {
29 class WebAssemblyLateEHPrepare final : public MachineFunctionPass {
30   StringRef getPassName() const override {
31     return "WebAssembly Late Prepare Exception";
32   }
33 
34   bool runOnMachineFunction(MachineFunction &MF) override;
35   bool addCatches(MachineFunction &MF);
36   bool replaceFuncletReturns(MachineFunction &MF);
37   bool removeUnnecessaryUnreachables(MachineFunction &MF);
38   bool addExceptionExtraction(MachineFunction &MF);
39   bool restoreStackPointer(MachineFunction &MF);
40 
41 public:
42   static char ID; // Pass identification, replacement for typeid
43   WebAssemblyLateEHPrepare() : MachineFunctionPass(ID) {}
44 };
45 } // end anonymous namespace
46 
// Storage for the pass identifier declared in the class; only its address is
// meaningful, the value itself is never read.
char WebAssemblyLateEHPrepare::ID = 0;
// Register the pass with the pass registry, keyed by DEBUG_TYPE
// ("wasm-late-eh-prepare") and carrying a human-readable description.
INITIALIZE_PASS(WebAssemblyLateEHPrepare, DEBUG_TYPE,
                "WebAssembly Late Exception Preparation", false, false)
50 
51 FunctionPass *llvm::createWebAssemblyLateEHPrepare() {
52   return new WebAssemblyLateEHPrepare();
53 }
54 
55 // Returns the nearest EH pad that dominates this instruction. This does not use
56 // dominator analysis; it just does BFS on its predecessors until arriving at an
57 // EH pad. This assumes valid EH scopes so the first EH pad it arrives in all
58 // possible search paths should be the same.
59 // Returns nullptr in case it does not find any EH pad in the search, or finds
60 // multiple different EH pads.
61 static MachineBasicBlock *getMatchingEHPad(MachineInstr *MI) {
62   MachineFunction *MF = MI->getParent()->getParent();
63   SmallVector<MachineBasicBlock *, 2> WL;
64   SmallPtrSet<MachineBasicBlock *, 2> Visited;
65   WL.push_back(MI->getParent());
66   MachineBasicBlock *EHPad = nullptr;
67   while (!WL.empty()) {
68     MachineBasicBlock *MBB = WL.pop_back_val();
69     if (Visited.count(MBB))
70       continue;
71     Visited.insert(MBB);
72     if (MBB->isEHPad()) {
73       if (EHPad && EHPad != MBB)
74         return nullptr;
75       EHPad = MBB;
76       continue;
77     }
78     if (MBB == &MF->front())
79       return nullptr;
80     WL.append(MBB->pred_begin(), MBB->pred_end());
81   }
82   return EHPad;
83 }
84 
85 // Erase the specified BBs if the BB does not have any remaining predecessors,
86 // and also all its dead children.
87 template <typename Container>
88 static void eraseDeadBBsAndChildren(const Container &MBBs) {
89   SmallVector<MachineBasicBlock *, 8> WL(MBBs.begin(), MBBs.end());
90   while (!WL.empty()) {
91     MachineBasicBlock *MBB = WL.pop_back_val();
92     if (!MBB->pred_empty())
93       continue;
94     SmallVector<MachineBasicBlock *, 4> Succs(MBB->succ_begin(),
95                                               MBB->succ_end());
96     WL.append(MBB->succ_begin(), MBB->succ_end());
97     for (auto *Succ : Succs)
98       MBB->removeSuccessor(Succ);
99     MBB->eraseFromParent();
100   }
101 }
102 
103 bool WebAssemblyLateEHPrepare::runOnMachineFunction(MachineFunction &MF) {
104   LLVM_DEBUG(dbgs() << "********** Late EH Prepare **********\n"
105                        "********** Function: "
106                     << MF.getName() << '\n');
107 
108   if (MF.getTarget().getMCAsmInfo()->getExceptionHandlingType() !=
109       ExceptionHandling::Wasm)
110     return false;
111 
112   bool Changed = false;
113   if (MF.getFunction().hasPersonalityFn()) {
114     Changed |= addCatches(MF);
115     Changed |= replaceFuncletReturns(MF);
116   }
117   Changed |= removeUnnecessaryUnreachables(MF);
118   if (MF.getFunction().hasPersonalityFn()) {
119     Changed |= addExceptionExtraction(MF);
120     Changed |= restoreStackPointer(MF);
121   }
122   return Changed;
123 }
124 
125 // Add catch instruction to beginning of catchpads and cleanuppads.
126 bool WebAssemblyLateEHPrepare::addCatches(MachineFunction &MF) {
127   bool Changed = false;
128   const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
129   MachineRegisterInfo &MRI = MF.getRegInfo();
130   for (auto &MBB : MF) {
131     if (MBB.isEHPad()) {
132       Changed = true;
133       auto InsertPos = MBB.begin();
134       if (InsertPos->isEHLabel()) // EH pad starts with an EH label
135         ++InsertPos;
136       Register DstReg = MRI.createVirtualRegister(&WebAssembly::EXNREFRegClass);
137       BuildMI(MBB, InsertPos, MBB.begin()->getDebugLoc(),
138               TII.get(WebAssembly::CATCH), DstReg);
139     }
140   }
141   return Changed;
142 }
143 
144 bool WebAssemblyLateEHPrepare::replaceFuncletReturns(MachineFunction &MF) {
145   bool Changed = false;
146   const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
147 
148   for (auto &MBB : MF) {
149     auto Pos = MBB.getFirstTerminator();
150     if (Pos == MBB.end())
151       continue;
152     MachineInstr *TI = &*Pos;
153 
154     switch (TI->getOpcode()) {
155     case WebAssembly::CATCHRET: {
156       // Replace a catchret with a branch
157       MachineBasicBlock *TBB = TI->getOperand(0).getMBB();
158       if (!MBB.isLayoutSuccessor(TBB))
159         BuildMI(MBB, TI, TI->getDebugLoc(), TII.get(WebAssembly::BR))
160             .addMBB(TBB);
161       TI->eraseFromParent();
162       Changed = true;
163       break;
164     }
165     case WebAssembly::CLEANUPRET:
166     case WebAssembly::RETHROW_IN_CATCH: {
167       // Replace a cleanupret/rethrow_in_catch with a rethrow
168       auto *EHPad = getMatchingEHPad(TI);
169       auto CatchPos = EHPad->begin();
170       if (CatchPos->isEHLabel()) // EH pad starts with an EH label
171         ++CatchPos;
172       MachineInstr *Catch = &*CatchPos;
173       Register ExnReg = Catch->getOperand(0).getReg();
174       BuildMI(MBB, TI, TI->getDebugLoc(), TII.get(WebAssembly::RETHROW))
175           .addReg(ExnReg);
176       TI->eraseFromParent();
177       Changed = true;
178       break;
179     }
180     }
181   }
182   return Changed;
183 }
184 
185 bool WebAssemblyLateEHPrepare::removeUnnecessaryUnreachables(
186     MachineFunction &MF) {
187   bool Changed = false;
188   for (auto &MBB : MF) {
189     for (auto &MI : MBB) {
190       if (MI.getOpcode() != WebAssembly::THROW &&
191           MI.getOpcode() != WebAssembly::RETHROW)
192         continue;
193       Changed = true;
194 
195       // The instruction after the throw should be an unreachable or a branch to
196       // another BB that should eventually lead to an unreachable. Delete it
197       // because throw itself is a terminator, and also delete successors if
198       // any.
199       MBB.erase(std::next(MI.getIterator()), MBB.end());
200       SmallVector<MachineBasicBlock *, 8> Succs(MBB.succ_begin(),
201                                                 MBB.succ_end());
202       for (auto *Succ : Succs)
203         if (!Succ->isEHPad())
204           MBB.removeSuccessor(Succ);
205       eraseDeadBBsAndChildren(Succs);
206     }
207   }
208 
209   return Changed;
210 }
211 
// Lowers every live 'extract_exception' pseudo instruction into an explicit
// br_on_exn tag check, and deletes extract_exception instructions whose result
// is dead.
//
// Wasm uses the 'br_on_exn' instruction to check the tag of an exception. It
// takes the exnref object returned by 'catch' and branches to the destination
// if the exception matches a given tag. We currently use the __cpp_exception
// symbol to represent the tag for all C++ exceptions.
//
// block $l (result i32)
//   ...
//   ;; exnref $e is on the stack at this point
//   br_on_exn $l $e ;; branch to $l with $e's arguments
//   ...
// end
// ;; Here we expect the extracted values are on top of the wasm value stack
// ... Handle exception using values ...
//
// There can be multiple br_on_exn instructions if we want to match other tags,
// but for now we only test for the __cpp_exception tag. If it does not match,
// i.e., it is a foreign exception, we rethrow it (or, in a terminate pad, call
// __clang_call_terminate() with a null argument; see below).
//
// In the destination BB that is the target of br_on_exn, the extracted
// exception values (in C++'s case a single i32, which represents an exception
// pointer) are placed on top of the wasm value stack. Because the wasm value
// stack cannot be modeled with LLVM machine instructions, we use the
// 'extract_exception' pseudo instruction to retrieve them. The pseudo
// instruction will be deleted later.
//
// Returns true if any extract_exception instruction was found.
bool WebAssemblyLateEHPrepare::addExceptionExtraction(MachineFunction &MF) {
  const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
  MachineRegisterInfo &MRI = MF.getRegInfo();
  // Only consulted on the rethrow path, to find where the rethrow unwinds to.
  auto *EHInfo = MF.getWasmEHFuncInfo();
  SmallVector<MachineInstr *, 16> ExtractInstrs;
  SmallVector<MachineInstr *, 8> ToDelete;
  // Partition extract_exception instructions: those whose result is dead are
  // simply deleted; the rest need a br_on_exn check inserted.
  for (auto &MBB : MF) {
    for (auto &MI : MBB) {
      if (MI.getOpcode() == WebAssembly::EXTRACT_EXCEPTION_I32) {
        if (MI.getOperand(0).isDead())
          ToDelete.push_back(&MI);
        else
          ExtractInstrs.push_back(&MI);
      }
    }
  }
  bool Changed = !ToDelete.empty() || !ExtractInstrs.empty();
  for (auto *MI : ToDelete)
    MI->eraseFromParent();
  if (ExtractInstrs.empty())
    return Changed;

  // Find terminate pads.
  // A terminate pad is the EH pad enclosing a call to
  // WebAssembly::ClangCallTerminateFn. NOTE(review): getMatchingEHPad may
  // return nullptr here, in which case nullptr is inserted into the set; this
  // looks harmless because the later TerminatePads.count(EHPad) lookup is only
  // reached with an EHPad that is asserted to be non-null.
  SmallSet<MachineBasicBlock *, 8> TerminatePads;
  for (auto &MBB : MF) {
    for (auto &MI : MBB) {
      if (MI.isCall()) {
        const MachineOperand &CalleeOp = MI.getOperand(0);
        if (CalleeOp.isGlobal() && CalleeOp.getGlobal()->getName() ==
                                       WebAssembly::ClangCallTerminateFn)
          TerminatePads.insert(getMatchingEHPad(&MI));
      }
    }
  }

  for (auto *Extract : ExtractInstrs) {
    MachineBasicBlock *EHPad = getMatchingEHPad(Extract);
    assert(EHPad && "No matching EH pad for extract_exception");
    auto CatchPos = EHPad->begin();
    if (CatchPos->isEHLabel()) // EH pad starts with an EH label
      ++CatchPos;
    MachineInstr *Catch = &*CatchPos;

    // Move the extract_exception (possibly from another block) so that it
    // immediately follows the catch; the splice below then moves it and
    // everything after it in the pad into ThenMBB.
    if (Catch->getNextNode() != Extract)
      EHPad->insert(Catch->getNextNode(), Extract->removeFromParent());

    // - Before:
    // ehpad:
    //   %exnref:exnref = catch
    //   %exn:i32 = extract_exception
    //   ... use exn ...
    //
    // - After:
    // ehpad:
    //   %exnref:exnref = catch
    //   br_on_exn %thenbb, $__cpp_exception, %exnref
    //   br %elsebb
    // elsebb:
    //   rethrow
    // thenbb:
    //   %exn:i32 = extract_exception
    //   ... use exn ...
    Register ExnReg = Catch->getOperand(0).getReg();
    auto *ThenMBB = MF.CreateMachineBasicBlock();
    auto *ElseMBB = MF.CreateMachineBasicBlock();
    // Layout order becomes: EHPad, ElseMBB, ThenMBB.
    MF.insert(std::next(MachineFunction::iterator(EHPad)), ElseMBB);
    MF.insert(std::next(MachineFunction::iterator(ElseMBB)), ThenMBB);
    // ThenMBB takes over the tail of the pad and all of its former successors;
    // the pad itself now only branches to ThenMBB or ElseMBB.
    ThenMBB->splice(ThenMBB->end(), EHPad, Extract, EHPad->end());
    ThenMBB->transferSuccessors(EHPad);
    EHPad->addSuccessor(ThenMBB);
    EHPad->addSuccessor(ElseMBB);

    DebugLoc DL = Extract->getDebugLoc();
    const char *CPPExnSymbol = MF.createExternalSymbolName("__cpp_exception");
    BuildMI(EHPad, DL, TII.get(WebAssembly::BR_ON_EXN))
        .addMBB(ThenMBB)
        .addExternalSymbol(CPPExnSymbol)
        .addReg(ExnReg);
    BuildMI(EHPad, DL, TII.get(WebAssembly::BR)).addMBB(ElseMBB);

    // When this is a terminate pad with __clang_call_terminate() call, we don't
    // rethrow it anymore and call __clang_call_terminate() with a nullptr
    // argument, which will call std::terminate().
    //
    // - Before:
    // ehpad:
    //   %exnref:exnref = catch
    //   %exn:i32 = extract_exception
    //   call @__clang_call_terminate(%exn)
    //   unreachable
    //
    // - After:
    // ehpad:
    //   %exnref:exnref = catch
    //   br_on_exn %thenbb, $__cpp_exception, %exnref
    //   br %elsebb
    // elsebb:
    //   call @__clang_call_terminate(0)
    //   unreachable
    // thenbb:
    //   %exn:i32 = extract_exception
    //   call @__clang_call_terminate(%exn)
    //   unreachable
    if (TerminatePads.count(EHPad)) {
      Function *ClangCallTerminateFn =
          MF.getFunction().getParent()->getFunction(
              WebAssembly::ClangCallTerminateFn);
      assert(ClangCallTerminateFn &&
             "There is no __clang_call_terminate() function");
      Register Reg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);
      BuildMI(ElseMBB, DL, TII.get(WebAssembly::CONST_I32), Reg).addImm(0);
      BuildMI(ElseMBB, DL, TII.get(WebAssembly::CALL))
          .addGlobalAddress(ClangCallTerminateFn)
          .addReg(Reg);
      BuildMI(ElseMBB, DL, TII.get(WebAssembly::UNREACHABLE));

    } else {
      BuildMI(ElseMBB, DL, TII.get(WebAssembly::RETHROW)).addReg(ExnReg);
      // If the pad has a recorded unwind destination, the rethrow can unwind
      // there, so record it as a CFG successor of ElseMBB.
      if (EHInfo->hasEHPadUnwindDest(EHPad))
        ElseMBB->addSuccessor(EHInfo->getEHPadUnwindDest(EHPad));
    }
  }

  // ExtractInstrs is non-empty here, so the function was modified.
  return true;
}
361 
362 // After the stack is unwound due to a thrown exception, the __stack_pointer
363 // global can point to an invalid address. This inserts instructions that
364 // restore __stack_pointer global.
365 bool WebAssemblyLateEHPrepare::restoreStackPointer(MachineFunction &MF) {
366   const auto *FrameLowering = static_cast<const WebAssemblyFrameLowering *>(
367       MF.getSubtarget().getFrameLowering());
368   if (!FrameLowering->needsPrologForEH(MF))
369     return false;
370   bool Changed = false;
371 
372   for (auto &MBB : MF) {
373     if (!MBB.isEHPad())
374       continue;
375     Changed = true;
376 
377     // Insert __stack_pointer restoring instructions at the beginning of each EH
378     // pad, after the catch instruction. Here it is safe to assume that SP32
379     // holds the latest value of __stack_pointer, because the only exception for
380     // this case is when a function uses the red zone, but that only happens
381     // with leaf functions, and we don't restore __stack_pointer in leaf
382     // functions anyway.
383     auto InsertPos = MBB.begin();
384     if (InsertPos->isEHLabel()) // EH pad starts with an EH label
385       ++InsertPos;
386     if (InsertPos->getOpcode() == WebAssembly::CATCH)
387       ++InsertPos;
388     FrameLowering->writeSPToGlobal(WebAssembly::SP32, MF, MBB, InsertPos,
389                                    MBB.begin()->getDebugLoc());
390   }
391   return Changed;
392 }
393