xref: /wasmtime-44.0.1/crates/fiber/src/windows.rs (revision f3156fe0)
1 use crate::{RunResult, RuntimeFiberStack};
2 use alloc::boxed::Box;
3 use std::cell::Cell;
4 use std::ffi::c_void;
5 use std::io;
6 use std::mem::needs_drop;
7 use std::ops::Range;
8 use std::ptr;
9 use windows_sys::Win32::Foundation::*;
10 use windows_sys::Win32::System::Threading::*;
11 
/// Error type reported by the Windows fiber backend: a plain OS-level I/O
/// error, as produced by the `io::Error` constructors used throughout this file.
pub type Error = std::io::Error;
13 
14 #[derive(Debug)]
15 pub struct FiberStack(usize);
16 
17 impl FiberStack {
new(size: usize, zeroed: bool) -> io::Result<Self>18     pub fn new(size: usize, zeroed: bool) -> io::Result<Self> {
19         // We don't support fiber stack zeroing on windows.
20         let _ = zeroed;
21 
22         Ok(Self(size))
23     }
24 
from_raw_parts( _base: *mut u8, _guard_size: usize, _len: usize, ) -> io::Result<Self>25     pub unsafe fn from_raw_parts(
26         _base: *mut u8,
27         _guard_size: usize,
28         _len: usize,
29     ) -> io::Result<Self> {
30         Err(io::Error::from_raw_os_error(ERROR_NOT_SUPPORTED as i32))
31     }
32 
is_from_raw_parts(&self) -> bool33     pub fn is_from_raw_parts(&self) -> bool {
34         false
35     }
36 
from_custom(_custom: Box<dyn RuntimeFiberStack>) -> io::Result<Self>37     pub fn from_custom(_custom: Box<dyn RuntimeFiberStack>) -> io::Result<Self> {
38         Err(io::Error::from_raw_os_error(ERROR_NOT_SUPPORTED as i32))
39     }
40 
top(&self) -> Option<*mut u8>41     pub fn top(&self) -> Option<*mut u8> {
42         None
43     }
44 
range(&self) -> Option<Range<usize>>45     pub fn range(&self) -> Option<Range<usize>> {
46         None
47     }
48 
guard_range(&self) -> Option<Range<*mut u8>>49     pub fn guard_range(&self) -> Option<Range<*mut u8>> {
50         None
51     }
52 }
53 
/// Handle to an OS fiber created with `CreateFiberEx`, together with the state
/// shared with the code running on it.
pub struct Fiber {
    /// Raw fiber handle returned by `CreateFiberEx`.
    fiber: *mut c_void,
    /// Shared state. Boxed so its heap address stays put when `Fiber` moves:
    /// that address is what the fiber receives as its start parameter.
    state: Box<StartState>,
}

/// Handle used by code running *inside* a fiber to talk back to its resumer.
pub struct Suspend {
    /// Points at the owning `Fiber`'s `StartState`.
    state: *const StartState,
}

/// State shared between a `Fiber` (the resuming side) and `fiber_start`
/// (the running side). Every field uses `Cell` because both sides mutate it
/// through shared/raw pointers.
struct StartState {
    /// Type-erased `Box<F>` holding the entry closure; taken and nulled out by
    /// `fiber_start` on first entry.
    initial_closure: Cell<*mut u8>,
    /// Fiber to switch back to when the running fiber suspends or finishes.
    /// Null whenever `Fiber::resume` is not in progress.
    parent: Cell<*mut c_void>,
    /// Type-erased pointer to the resumer's `Cell<RunResult<A, B, C>>`, through
    /// which values are exchanged. Null whenever `Fiber::resume` is not in progress.
    result_location: Cell<*const u8>,
}
68 
69 const FIBER_FLAG_FLOAT_SWITCH: u32 = 1;
70 
71 unsafe extern "C" {
72     #[wasmtime_versioned_export_macros::versioned_link]
wasmtime_fiber_get_current() -> *mut c_void73     fn wasmtime_fiber_get_current() -> *mut c_void;
74 }
75 
fiber_start<F, A, B, C>(data: *mut c_void) where F: FnOnce(A, &mut super::Suspend<A, B, C>) -> C,76 unsafe extern "system" fn fiber_start<F, A, B, C>(data: *mut c_void)
77 where
78     F: FnOnce(A, &mut super::Suspend<A, B, C>) -> C,
79 {
80     unsafe {
81         // Set the stack guarantee to be consistent with what Rust expects for threads
82         // This value is taken from:
83         // https://github.com/rust-lang/rust/blob/0d97f7a96877a96015d70ece41ad08bb7af12377/library/std/src/sys/windows/stack_overflow.rs
84         if SetThreadStackGuarantee(&mut 0x5000) == 0 {
85             panic!("failed to set fiber stack guarantee");
86         }
87 
88         let state = data.cast::<StartState>();
89         let func = Box::from_raw((*state).initial_closure.get().cast::<F>());
90         (*state).initial_closure.set(ptr::null_mut());
91         let suspend = Suspend { state };
92         let initial = suspend.take_resume::<A, B, C>();
93         let suspend = super::Suspend::<A, B, C>::execute(suspend, initial, *func);
94 
95         let parent = (*suspend.state).parent.get();
96         debug_assert!(!parent.is_null());
97 
98         // Technically this means that this function `fiber_start` never
99         // returns, but that's how the fibers API works on Windows it seems.
100         SwitchToFiber(parent);
101     }
102 }
103 
104 impl Fiber {
new<F, A, B, C>(stack: &FiberStack, func: F) -> io::Result<Self> where F: FnOnce(A, &mut super::Suspend<A, B, C>) -> C,105     pub fn new<F, A, B, C>(stack: &FiberStack, func: F) -> io::Result<Self>
106     where
107         F: FnOnce(A, &mut super::Suspend<A, B, C>) -> C,
108     {
109         unsafe {
110             let state = Box::new(StartState {
111                 initial_closure: Cell::new(Box::into_raw(Box::new(func)).cast()),
112                 parent: Cell::new(ptr::null_mut()),
113                 result_location: Cell::new(ptr::null()),
114             });
115 
116             let fiber = CreateFiberEx(
117                 0,
118                 stack.0,
119                 FIBER_FLAG_FLOAT_SWITCH,
120                 Some(fiber_start::<F, A, B, C>),
121                 &*state as *const StartState as *mut _,
122             );
123 
124             if fiber.is_null() {
125                 drop(Box::from_raw(state.initial_closure.get().cast::<F>()));
126                 return Err(io::Error::last_os_error());
127             }
128 
129             Ok(Self { fiber, state })
130         }
131     }
132 
resume<A, B, C>(&self, _stack: &FiberStack, result: &Cell<RunResult<A, B, C>>)133     pub(crate) fn resume<A, B, C>(&self, _stack: &FiberStack, result: &Cell<RunResult<A, B, C>>) {
134         unsafe {
135             let is_fiber = IsThreadAFiber() != 0;
136             let parent_fiber = if is_fiber {
137                 wasmtime_fiber_get_current()
138             } else {
139                 // Newer Rust versions use fiber local storage to register an internal hook that
140                 // calls thread locals' destructors on thread exit.
141                 // This has a limitation: the hook only runs in a regular thread (not in a fiber).
142                 // We convert back into a thread once execution returns to this function,
143                 // but we must also ensure that the hook is registered before converting into a fiber.
144                 // Otherwise, a different fiber could be the first to register the hook,
145                 // causing the hook to be called (and skipped) prematurely when that fiber is deleted.
146                 struct Guard;
147 
148                 impl Drop for Guard {
149                     fn drop(&mut self) {}
150                 }
151                 assert!(needs_drop::<Guard>());
152                 thread_local!(static GUARD: Guard = Guard);
153                 GUARD.with(|_g| {});
154                 ConvertThreadToFiber(ptr::null_mut())
155             };
156             assert!(
157                 !parent_fiber.is_null(),
158                 "failed to make current thread a fiber"
159             );
160             self.state
161                 .result_location
162                 .set(result as *const _ as *const _);
163             self.state.parent.set(parent_fiber);
164             SwitchToFiber(self.fiber);
165             self.state.parent.set(ptr::null_mut());
166             self.state.result_location.set(ptr::null());
167             if !is_fiber {
168                 let res = ConvertFiberToThread();
169                 assert!(res != 0, "failed to convert main thread back");
170             }
171         }
172     }
173 
drop<A, B, C>(&mut self)174     pub(crate) unsafe fn drop<A, B, C>(&mut self) {}
175 }
176 
177 impl Drop for Fiber {
drop(&mut self)178     fn drop(&mut self) {
179         unsafe {
180             DeleteFiber(self.fiber);
181         }
182     }
183 }
184 
185 impl Suspend {
switch<A, B, C>(&self, result: RunResult<A, B, C>) -> A186     pub(crate) fn switch<A, B, C>(&self, result: RunResult<A, B, C>) -> A {
187         unsafe {
188             (*self.result_location::<A, B, C>()).set(result);
189             debug_assert!(IsThreadAFiber() != 0);
190             let parent = (*self.state).parent.get();
191             debug_assert!(!parent.is_null());
192             SwitchToFiber(parent);
193             self.take_resume::<A, B, C>()
194         }
195     }
196 
start_exit<A, B, C>(&mut self, result: RunResult<A, B, C>)197     pub(crate) fn start_exit<A, B, C>(&mut self, result: RunResult<A, B, C>) {
198         unsafe {
199             (*self.result_location::<A, B, C>()).set(result);
200         }
201     }
202 
take_resume<A, B, C>(&self) -> A203     unsafe fn take_resume<A, B, C>(&self) -> A {
204         unsafe {
205             match (*self.result_location::<A, B, C>()).replace(RunResult::Executing) {
206                 RunResult::Resuming(val) => val,
207                 _ => panic!("not in resuming state"),
208             }
209         }
210     }
211 
result_location<A, B, C>(&self) -> *const Cell<RunResult<A, B, C>>212     unsafe fn result_location<A, B, C>(&self) -> *const Cell<RunResult<A, B, C>> {
213         unsafe {
214             let ret = (*self.state)
215                 .result_location
216                 .get()
217                 .cast::<Cell<RunResult<A, B, C>>>();
218             assert!(!ret.is_null());
219             return ret;
220         }
221     }
222 }
223