1 //! Implement [`wasi-threads`].
2 //!
3 //! [`wasi-threads`]: https://github.com/WebAssembly/wasi-threads
4 
5 use std::panic::{AssertUnwindSafe, catch_unwind};
6 use std::sync::Arc;
7 use std::sync::atomic::{AtomicI32, Ordering};
8 use std::thread;
9 use wasmtime::{
10     Caller, ExternType, InstancePre, Linker, Module, Result, SharedMemory, Store, format_err,
11 };
12 
// Name of the function export that the wasi-threads specification designates
// as the entry point of every spawned thread; it is called with
// `(thread_id, start_arg)`. See:
// https://github.com/WebAssembly/wasi-threads/#detailed-design-discussion
const WASI_ENTRY_POINT: &str = "wasi_thread_start";
16 
17 pub struct WasiThreadsCtx<T> {
18     instance_pre: Arc<InstancePre<T>>,
19     tid: AtomicI32,
20     use_async: bool,
21 }
22 
23 impl<T: Clone + Send + 'static> WasiThreadsCtx<T> {
new(module: Module, linker: Arc<Linker<T>>, use_async: bool) -> Result<Self>24     pub fn new(module: Module, linker: Arc<Linker<T>>, use_async: bool) -> Result<Self> {
25         let instance_pre = Arc::new(linker.instantiate_pre(&module)?);
26         let tid = AtomicI32::new(0);
27         Ok(Self {
28             instance_pre,
29             tid,
30             use_async,
31         })
32     }
33 
spawn(&self, host: T, thread_start_arg: i32) -> Result<i32>34     pub fn spawn(&self, host: T, thread_start_arg: i32) -> Result<i32> {
35         let instance_pre = self.instance_pre.clone();
36 
37         // Check that the thread entry point is present. Why here? If we check
38         // for this too early, then we cannot accept modules that do not have an
39         // entry point but never spawn a thread. As pointed out in
40         // https://github.com/bytecodealliance/wasmtime/issues/6153, checking
41         // the entry point here allows wasi-threads to be compatible with more
42         // modules.
43         //
44         // As defined in the wasi-threads specification, returning a negative
45         // result here indicates to the guest module that the spawn failed.
46         if !has_entry_point(instance_pre.module()) {
47             log::error!(
48                 "failed to find a wasi-threads entry point function; expected an export with name: {WASI_ENTRY_POINT}"
49             );
50             return Ok(-1);
51         }
52         if !has_correct_signature(instance_pre.module()) {
53             log::error!(
54                 "the exported entry point function has an incorrect signature: expected `(i32, i32) -> ()`"
55             );
56             return Ok(-1);
57         }
58 
59         let wasi_thread_id = self.next_thread_id();
60         if wasi_thread_id.is_none() {
61             log::error!("ran out of valid thread IDs");
62             return Ok(-1);
63         }
64         let wasi_thread_id = wasi_thread_id.unwrap();
65 
66         // Start a Rust thread running a new instance of the current module.
67         let builder = thread::Builder::new().name(format!("wasi-thread-{wasi_thread_id}"));
68         let use_async = self.use_async;
69         builder.spawn(move || {
70             // Catch any panic failures in host code; e.g., if a WASI module
71             // were to crash, we want all threads to exit, not just this one.
72             let result = catch_unwind(AssertUnwindSafe(|| {
73                 // Each new instance is created in its own store.
74                 let mut store = Store::new(&instance_pre.module().engine(), host);
75 
76                 let instance = if use_async {
77                     wasmtime_wasi::runtime::in_tokio(instance_pre.instantiate_async(&mut store))
78                 } else {
79                     instance_pre.instantiate(&mut store)
80                 }
81                 .unwrap();
82 
83                 let thread_entry_point = instance
84                     .get_typed_func::<(i32, i32), ()>(&mut store, WASI_ENTRY_POINT)
85                     .unwrap();
86 
87                 // Start the thread's entry point. Any traps or calls to
88                 // `proc_exit`, by specification, should end execution for all
89                 // threads. This code uses `process::exit` to do so, which is
90                 // what the user expects from the CLI but probably not in a
91                 // Wasmtime embedding.
92                 log::trace!(
93                     "spawned thread id = {wasi_thread_id}; calling start function `{WASI_ENTRY_POINT}` with: {thread_start_arg}"
94                 );
95                 let res = if use_async {
96                     wasmtime_wasi::runtime::in_tokio(
97                         thread_entry_point
98                             .call_async(&mut store, (wasi_thread_id, thread_start_arg)),
99                     )
100                 } else {
101                     thread_entry_point.call(&mut store, (wasi_thread_id, thread_start_arg))
102                 };
103                 match res {
104                     Ok(_) => log::trace!("exiting thread id = {wasi_thread_id} normally"),
105                     Err(e) => {
106                         log::trace!("exiting thread id = {wasi_thread_id} due to error");
107                         let e = wasi_common::maybe_exit_on_error(e);
108                         eprintln!("Error: {e:?}");
109                         std::process::exit(1);
110                     }
111                 }
112             }));
113 
114             if let Err(e) = result {
115                 eprintln!("wasi-thread-{wasi_thread_id} panicked: {e:?}");
116                 std::process::exit(1);
117             }
118         })?;
119 
120         Ok(wasi_thread_id)
121     }
122 
123     /// Helper for generating valid WASI thread IDs (TID).
124     ///
125     /// Callers of `wasi_thread_spawn` expect a TID in range of 0 < TID <= 0x1FFFFFFF
126     /// to indicate a successful spawning of the thread whereas a negative
127     /// return value indicates an failure to spawn.
next_thread_id(&self) -> Option<i32>128     fn next_thread_id(&self) -> Option<i32> {
129         match self
130             .tid
131             .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| match v {
132                 ..=0x1ffffffe => Some(v + 1),
133                 _ => None,
134             }) {
135             Ok(v) => Some(v + 1),
136             Err(_) => None,
137         }
138     }
139 }
140 
141 /// Manually add the WASI `thread_spawn` function to the linker.
142 ///
143 /// It is unclear what namespace the `wasi-threads` proposal should live under:
144 /// it is not clear if it should be included in any of the `preview*` releases
145 /// so for the time being its module namespace is simply `"wasi"` (TODO).
add_to_linker<T: Clone + Send + 'static>( linker: &mut wasmtime::Linker<T>, store: &wasmtime::Store<T>, module: &Module, get_cx: impl Fn(&mut T) -> &WasiThreadsCtx<T> + Send + Sync + Copy + 'static, ) -> wasmtime::Result<()>146 pub fn add_to_linker<T: Clone + Send + 'static>(
147     linker: &mut wasmtime::Linker<T>,
148     store: &wasmtime::Store<T>,
149     module: &Module,
150     get_cx: impl Fn(&mut T) -> &WasiThreadsCtx<T> + Send + Sync + Copy + 'static,
151 ) -> wasmtime::Result<()> {
152     linker.func_wrap(
153         "wasi",
154         "thread-spawn",
155         move |mut caller: Caller<'_, T>, start_arg: i32| -> i32 {
156             log::trace!("new thread requested via `wasi::thread_spawn` call");
157             let host = caller.data().clone();
158             let ctx = get_cx(caller.data_mut());
159             match ctx.spawn(host, start_arg) {
160                 Ok(thread_id) => {
161                     assert!(thread_id >= 0, "thread_id = {thread_id}");
162                     thread_id
163                 }
164                 Err(e) => {
165                     log::error!("failed to spawn thread: {e}");
166                     -1
167                 }
168             }
169         },
170     )?;
171 
172     // Find the shared memory import and satisfy it with a newly-created shared
173     // memory import.
174     for import in module.imports() {
175         if let Some(m) = import.ty().memory() {
176             if m.is_shared() {
177                 let mem = SharedMemory::new(module.engine(), m.clone())?;
178                 linker.define(store, import.module(), import.name(), mem.clone())?;
179             } else {
180                 return Err(format_err!(
181                     "memory was not shared; a `wasi-threads` must import \
182                      a shared memory as \"memory\""
183                 ));
184             }
185         }
186     }
187     Ok(())
188 }
189 
190 /// Check if wasi-threads' `wasi_thread_start` export is present.
has_entry_point(module: &Module) -> bool191 fn has_entry_point(module: &Module) -> bool {
192     module.get_export(WASI_ENTRY_POINT).is_some()
193 }
194 
195 /// Check if the entry function has the correct signature `(i32, i32) -> ()`.
has_correct_signature(module: &Module) -> bool196 fn has_correct_signature(module: &Module) -> bool {
197     match module.get_export(WASI_ENTRY_POINT) {
198         Some(ExternType::Func(ty)) => {
199             ty.params().len() == 2
200                 && ty.params().nth(0).unwrap().is_i32()
201                 && ty.params().nth(1).unwrap().is_i32()
202                 && ty.results().len() == 0
203         }
204         _ => false,
205     }
206 }
207