1 //=== WebAssemblyLateEHPrepare.cpp - WebAssembly Exception Preparation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// \brief Does various transformations for exception handling.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
15 #include "WebAssembly.h"
16 #include "WebAssemblySubtarget.h"
17 #include "WebAssemblyUtilities.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/CodeGen/MachineInstrBuilder.h"
20 #include "llvm/CodeGen/WasmEHFuncInfo.h"
21 #include "llvm/MC/MCAsmInfo.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Target/TargetMachine.h"
24 using namespace llvm;
25 
26 #define DEBUG_TYPE "wasm-late-eh-prepare"
27 
28 namespace {
29 class WebAssemblyLateEHPrepare final : public MachineFunctionPass {
30   StringRef getPassName() const override {
31     return "WebAssembly Late Prepare Exception";
32   }
33 
34   bool runOnMachineFunction(MachineFunction &MF) override;
35   void recordCatchRetBBs(MachineFunction &MF);
36   bool addCatches(MachineFunction &MF);
37   bool replaceFuncletReturns(MachineFunction &MF);
38   bool removeUnnecessaryUnreachables(MachineFunction &MF);
39   bool addExceptionExtraction(MachineFunction &MF);
40   bool restoreStackPointer(MachineFunction &MF);
41 
42   MachineBasicBlock *getMatchingEHPad(MachineInstr *MI);
43   SmallSet<MachineBasicBlock *, 8> CatchRetBBs;
44 
45 public:
46   static char ID; // Pass identification, replacement for typeid
47   WebAssemblyLateEHPrepare() : MachineFunctionPass(ID) {}
48 };
49 } // end anonymous namespace
50 
// Storage for the pass identifier; MachineFunctionPass(ID) uses its address.
char WebAssemblyLateEHPrepare::ID = 0;
// Register the pass under DEBUG_TYPE ("wasm-late-eh-prepare") with the given
// human-readable description.
INITIALIZE_PASS(WebAssemblyLateEHPrepare, DEBUG_TYPE,
                "WebAssembly Late Exception Preparation", false, false)

// Factory for this pass; returns a newly heap-allocated instance.
// NOTE(review): ownership presumably passes to the caller / pass manager.
FunctionPass *llvm::createWebAssemblyLateEHPrepare() {
  return new WebAssemblyLateEHPrepare();
}
58 
59 // Returns the nearest EH pad that dominates this instruction. This does not use
60 // dominator analysis; it just does BFS on its predecessors until arriving at an
61 // EH pad. This assumes valid EH scopes so the first EH pad it arrives in all
62 // possible search paths should be the same.
63 // Returns nullptr in case it does not find any EH pad in the search, or finds
64 // multiple different EH pads.
65 MachineBasicBlock *
66 WebAssemblyLateEHPrepare::getMatchingEHPad(MachineInstr *MI) {
67   MachineFunction *MF = MI->getParent()->getParent();
68   SmallVector<MachineBasicBlock *, 2> WL;
69   SmallPtrSet<MachineBasicBlock *, 2> Visited;
70   WL.push_back(MI->getParent());
71   MachineBasicBlock *EHPad = nullptr;
72   while (!WL.empty()) {
73     MachineBasicBlock *MBB = WL.pop_back_val();
74     if (Visited.count(MBB))
75       continue;
76     Visited.insert(MBB);
77     if (MBB->isEHPad()) {
78       if (EHPad && EHPad != MBB)
79         return nullptr;
80       EHPad = MBB;
81       continue;
82     }
83     if (MBB == &MF->front())
84       return nullptr;
85     for (auto *Pred : MBB->predecessors())
86       if (!CatchRetBBs.count(Pred)) // We don't go into child scopes
87         WL.push_back(Pred);
88   }
89   return EHPad;
90 }
91 
92 // Erase the specified BBs if the BB does not have any remaining predecessors,
93 // and also all its dead children.
94 template <typename Container>
95 static void eraseDeadBBsAndChildren(const Container &MBBs) {
96   SmallVector<MachineBasicBlock *, 8> WL(MBBs.begin(), MBBs.end());
97   while (!WL.empty()) {
98     MachineBasicBlock *MBB = WL.pop_back_val();
99     if (!MBB->pred_empty())
100       continue;
101     SmallVector<MachineBasicBlock *, 4> Succs(MBB->successors());
102     WL.append(MBB->succ_begin(), MBB->succ_end());
103     for (auto *Succ : Succs)
104       MBB->removeSuccessor(Succ);
105     MBB->eraseFromParent();
106   }
107 }
108 
109 bool WebAssemblyLateEHPrepare::runOnMachineFunction(MachineFunction &MF) {
110   LLVM_DEBUG(dbgs() << "********** Late EH Prepare **********\n"
111                        "********** Function: "
112                     << MF.getName() << '\n');
113 
114   if (MF.getTarget().getMCAsmInfo()->getExceptionHandlingType() !=
115       ExceptionHandling::Wasm)
116     return false;
117 
118   bool Changed = false;
119   if (MF.getFunction().hasPersonalityFn()) {
120     recordCatchRetBBs(MF);
121     Changed |= addCatches(MF);
122     Changed |= replaceFuncletReturns(MF);
123   }
124   Changed |= removeUnnecessaryUnreachables(MF);
125   if (MF.getFunction().hasPersonalityFn()) {
126     Changed |= addExceptionExtraction(MF);
127     Changed |= restoreStackPointer(MF);
128   }
129   return Changed;
130 }
131 
132 // Record which BB ends with 'CATCHRET' instruction, because this will be
133 // replaced with BRs later. This set of 'CATCHRET' BBs is necessary in
134 // 'getMatchingEHPad' function.
135 void WebAssemblyLateEHPrepare::recordCatchRetBBs(MachineFunction &MF) {
136   CatchRetBBs.clear();
137   for (auto &MBB : MF) {
138     auto Pos = MBB.getFirstTerminator();
139     if (Pos == MBB.end())
140       continue;
141     MachineInstr *TI = &*Pos;
142     if (TI->getOpcode() == WebAssembly::CATCHRET)
143       CatchRetBBs.insert(&MBB);
144   }
145 }
146 
147 // Add catch instruction to beginning of catchpads and cleanuppads.
148 bool WebAssemblyLateEHPrepare::addCatches(MachineFunction &MF) {
149   bool Changed = false;
150   const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
151   MachineRegisterInfo &MRI = MF.getRegInfo();
152   for (auto &MBB : MF) {
153     if (MBB.isEHPad()) {
154       Changed = true;
155       auto InsertPos = MBB.begin();
156       if (InsertPos->isEHLabel()) // EH pad starts with an EH label
157         ++InsertPos;
158       Register DstReg = MRI.createVirtualRegister(&WebAssembly::EXNREFRegClass);
159       BuildMI(MBB, InsertPos, MBB.begin()->getDebugLoc(),
160               TII.get(WebAssembly::CATCH), DstReg);
161     }
162   }
163   return Changed;
164 }
165 
166 bool WebAssemblyLateEHPrepare::replaceFuncletReturns(MachineFunction &MF) {
167   bool Changed = false;
168   const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
169 
170   for (auto &MBB : MF) {
171     auto Pos = MBB.getFirstTerminator();
172     if (Pos == MBB.end())
173       continue;
174     MachineInstr *TI = &*Pos;
175 
176     switch (TI->getOpcode()) {
177     case WebAssembly::CATCHRET: {
178       // Replace a catchret with a branch
179       MachineBasicBlock *TBB = TI->getOperand(0).getMBB();
180       if (!MBB.isLayoutSuccessor(TBB))
181         BuildMI(MBB, TI, TI->getDebugLoc(), TII.get(WebAssembly::BR))
182             .addMBB(TBB);
183       TI->eraseFromParent();
184       Changed = true;
185       break;
186     }
187     case WebAssembly::CLEANUPRET:
188     case WebAssembly::RETHROW_IN_CATCH: {
189       // Replace a cleanupret/rethrow_in_catch with a rethrow
190       auto *EHPad = getMatchingEHPad(TI);
191       auto CatchPos = EHPad->begin();
192       if (CatchPos->isEHLabel()) // EH pad starts with an EH label
193         ++CatchPos;
194       MachineInstr *Catch = &*CatchPos;
195       Register ExnReg = Catch->getOperand(0).getReg();
196       BuildMI(MBB, TI, TI->getDebugLoc(), TII.get(WebAssembly::RETHROW))
197           .addReg(ExnReg);
198       TI->eraseFromParent();
199       Changed = true;
200       break;
201     }
202     }
203   }
204   return Changed;
205 }
206 
207 bool WebAssemblyLateEHPrepare::removeUnnecessaryUnreachables(
208     MachineFunction &MF) {
209   bool Changed = false;
210   for (auto &MBB : MF) {
211     for (auto &MI : MBB) {
212       if (MI.getOpcode() != WebAssembly::THROW &&
213           MI.getOpcode() != WebAssembly::RETHROW)
214         continue;
215       Changed = true;
216 
217       // The instruction after the throw should be an unreachable or a branch to
218       // another BB that should eventually lead to an unreachable. Delete it
219       // because throw itself is a terminator, and also delete successors if
220       // any.
221       MBB.erase(std::next(MI.getIterator()), MBB.end());
222       SmallVector<MachineBasicBlock *, 8> Succs(MBB.successors());
223       for (auto *Succ : Succs)
224         if (!Succ->isEHPad())
225           MBB.removeSuccessor(Succ);
226       eraseDeadBBsAndChildren(Succs);
227     }
228   }
229 
230   return Changed;
231 }
232 
// Wasm uses the 'br_on_exn' instruction to check the tag of an exception. It
// takes an exnref object returned by 'catch', and branches to the destination
// if it matches a given tag. We currently use the __cpp_exception symbol to
// represent the tag for all C++ exceptions.
//
// block $l (result i32)
//   ...
//   ;; exnref $e is on the stack at this point
//   br_on_exn $l $e ;; branch to $l with $e's arguments
//   ...
// end
// ;; Here we expect the extracted values are on top of the wasm value stack
// ... Handle exception using values ...
//
// br_on_exn takes an exnref object and branches if it matches the given tag.
// There can be multiple br_on_exn instructions if we want to match for another
// tag, but for now we only test for __cpp_exception tag, and if it does not
// match, i.e., it is a foreign exception, we rethrow it.
//
// In the destination BB that's the target of br_on_exn, extracted exception
// values (in C++'s case a single i32, which represents an exception pointer)
// are placed on top of the wasm stack. Because we can't model the wasm stack
// in LLVM instructions, we use the 'extract_exception' pseudo instruction to
// retrieve it. The pseudo instruction will be deleted later.
//
// This function:
//  1. erases every EXTRACT_EXCEPTION_I32 whose result is dead;
//  2. for each remaining one, moves it directly after the CATCH of its
//     matching EH pad, and splits the pad into the br_on_exn / else / then
//     structure illustrated in the comments inside the loop;
//  3. fills the else block with either a rethrow of the exnref or, for pads
//     that call __clang_call_terminate, a call to __clang_call_terminate(0).
// Returns true if any extract_exception instruction was erased or lowered.
bool WebAssemblyLateEHPrepare::addExceptionExtraction(MachineFunction &MF) {
  const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
  MachineRegisterInfo &MRI = MF.getRegInfo();
  // Queried below for each pad's unwind destination.
  // NOTE(review): dereferenced without a null check; assumes WasmEHFuncInfo
  // exists whenever this runs (runOnMachineFunction only calls this for
  // functions with a personality function).
  auto *EHInfo = MF.getWasmEHFuncInfo();
  SmallVector<MachineInstr *, 16> ExtractInstrs;
  SmallVector<MachineInstr *, 8> ToDelete;
  // Partition the extract_exception instructions. Those whose result is dead
  // are simply deleted; the rest need lowering. Collect first, then mutate,
  // so the scan is not disturbed by the erasures.
  for (auto &MBB : MF) {
    for (auto &MI : MBB) {
      if (MI.getOpcode() == WebAssembly::EXTRACT_EXCEPTION_I32) {
        if (MI.getOperand(0).isDead())
          ToDelete.push_back(&MI);
        else
          ExtractInstrs.push_back(&MI);
      }
    }
  }
  bool Changed = !ToDelete.empty() || !ExtractInstrs.empty();
  for (auto *MI : ToDelete)
    MI->eraseFromParent();
  if (ExtractInstrs.empty())
    return Changed;

  // Find terminate pads: the EH pads matching a call whose operand 0 is the
  // global __clang_call_terminate. Calls whose operand 0 is not a global are
  // ignored. getMatchingEHPad may return nullptr here; that entry never equals
  // the (asserted non-null) EHPad looked up in the loop below.
  SmallSet<MachineBasicBlock *, 8> TerminatePads;
  for (auto &MBB : MF) {
    for (auto &MI : MBB) {
      if (MI.isCall()) {
        const MachineOperand &CalleeOp = MI.getOperand(0);
        if (CalleeOp.isGlobal() && CalleeOp.getGlobal()->getName() ==
                                       WebAssembly::ClangCallTerminateFn)
          TerminatePads.insert(getMatchingEHPad(&MI));
      }
    }
  }

  for (auto *Extract : ExtractInstrs) {
    MachineBasicBlock *EHPad = getMatchingEHPad(Extract);
    assert(EHPad && "No matching EH pad for extract_exception");
    // Locate the CATCH, which follows the pad's optional EH label.
    auto CatchPos = EHPad->begin();
    if (CatchPos->isEHLabel()) // EH pad starts with an EH label
      ++CatchPos;
    MachineInstr *Catch = &*CatchPos;

    // Make the extract_exception the instruction right after the CATCH.
    // The split below moves everything from the extract onward into ThenMBB.
    if (Catch->getNextNode() != Extract)
      EHPad->insert(Catch->getNextNode(), Extract->removeFromParent());

    // - Before:
    // ehpad:
    //   %exnref:exnref = catch
    //   %exn:i32 = extract_exception
    //   ... use exn ...
    //
    // - After:
    // ehpad:
    //   %exnref:exnref = catch
    //   br_on_exn %thenbb, $__cpp_exception, %exnref
    //   br %elsebb
    // elsebb:
    //   rethrow
    // thenbb:
    //   %exn:i32 = extract_exception
    //   ... use exn ...
    Register ExnReg = Catch->getOperand(0).getReg();
    auto *ThenMBB = MF.CreateMachineBasicBlock();
    auto *ElseMBB = MF.CreateMachineBasicBlock();
    // Layout order: EHPad, ElseMBB, ThenMBB.
    MF.insert(std::next(MachineFunction::iterator(EHPad)), ElseMBB);
    MF.insert(std::next(MachineFunction::iterator(ElseMBB)), ThenMBB);
    // ThenMBB takes the extract and everything after it, along with all of
    // the pad's successors. The pad itself now branches to Then or Else.
    ThenMBB->splice(ThenMBB->end(), EHPad, Extract, EHPad->end());
    ThenMBB->transferSuccessors(EHPad);
    EHPad->addSuccessor(ThenMBB);
    EHPad->addSuccessor(ElseMBB);

    DebugLoc DL = Extract->getDebugLoc();
    const char *CPPExnSymbol = MF.createExternalSymbolName("__cpp_exception");
    BuildMI(EHPad, DL, TII.get(WebAssembly::BR_ON_EXN))
        .addMBB(ThenMBB)
        .addExternalSymbol(CPPExnSymbol)
        .addReg(ExnReg);
    BuildMI(EHPad, DL, TII.get(WebAssembly::BR)).addMBB(ElseMBB);

    // When this is a terminate pad with __clang_call_terminate() call, we don't
    // rethrow it anymore and call __clang_call_terminate() with a nullptr
    // argument, which will call std::terminate().
    //
    // - Before:
    // ehpad:
    //   %exnref:exnref = catch
    //   %exn:i32 = extract_exception
    //   call @__clang_call_terminate(%exn)
    //   unreachable
    //
    // - After:
    // ehpad:
    //   %exnref:exnref = catch
    //   br_on_exn %thenbb, $__cpp_exception, %exnref
    //   br %elsebb
    // elsebb:
    //   call @__clang_call_terminate(0)
    //   unreachable
    // thenbb:
    //   %exn:i32 = extract_exception
    //   call @__clang_call_terminate(%exn)
    //   unreachable
    if (TerminatePads.count(EHPad)) {
      Function *ClangCallTerminateFn =
          MF.getFunction().getParent()->getFunction(
              WebAssembly::ClangCallTerminateFn);
      assert(ClangCallTerminateFn &&
             "There is no __clang_call_terminate() function");
      // Materialize the i32 0 argument, then call and end with unreachable.
      Register Reg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);
      BuildMI(ElseMBB, DL, TII.get(WebAssembly::CONST_I32), Reg).addImm(0);
      BuildMI(ElseMBB, DL, TII.get(WebAssembly::CALL))
          .addGlobalAddress(ClangCallTerminateFn)
          .addReg(Reg);
      BuildMI(ElseMBB, DL, TII.get(WebAssembly::UNREACHABLE));

    } else {
      // Foreign exception: rethrow it. Since ElseMBB rethrows, record the
      // pad's unwind destination (if any) as ElseMBB's successor.
      BuildMI(ElseMBB, DL, TII.get(WebAssembly::RETHROW)).addReg(ExnReg);
      if (EHInfo->hasEHPadUnwindDest(EHPad))
        ElseMBB->addSuccessor(EHInfo->getEHPadUnwindDest(EHPad));
    }
  }

  return true;
}
382 
383 // After the stack is unwound due to a thrown exception, the __stack_pointer
384 // global can point to an invalid address. This inserts instructions that
385 // restore __stack_pointer global.
386 bool WebAssemblyLateEHPrepare::restoreStackPointer(MachineFunction &MF) {
387   const auto *FrameLowering = static_cast<const WebAssemblyFrameLowering *>(
388       MF.getSubtarget().getFrameLowering());
389   if (!FrameLowering->needsPrologForEH(MF))
390     return false;
391   bool Changed = false;
392 
393   for (auto &MBB : MF) {
394     if (!MBB.isEHPad())
395       continue;
396     Changed = true;
397 
398     // Insert __stack_pointer restoring instructions at the beginning of each EH
399     // pad, after the catch instruction. Here it is safe to assume that SP32
400     // holds the latest value of __stack_pointer, because the only exception for
401     // this case is when a function uses the red zone, but that only happens
402     // with leaf functions, and we don't restore __stack_pointer in leaf
403     // functions anyway.
404     auto InsertPos = MBB.begin();
405     if (InsertPos->isEHLabel()) // EH pad starts with an EH label
406       ++InsertPos;
407     if (InsertPos->getOpcode() == WebAssembly::CATCH)
408       ++InsertPos;
409     FrameLowering->writeSPToGlobal(FrameLowering->getSPReg(MF), MF, MBB,
410                                    InsertPos, MBB.begin()->getDebugLoc());
411   }
412   return Changed;
413 }
414