xref: /wasmtime-44.0.1/tests/all/call_hook.rs (revision cc8d04f4)
1 #![cfg(not(miri))]
2 
3 use std::future::Future;
4 use std::pin::Pin;
5 use std::task::{self, Poll};
6 use wasmtime::bail;
7 use wasmtime::*;
8 
// Create a synchronous Func, call it directly:
#[test]
fn call_wrapped_func() -> Result<(), Error> {
    let mut store = Store::<State>::default();
    store.call_hook(sync_call_hook);

    // Checks made from inside each host function while it is running.
    fn verify(state: &State) {
        // Calling this func will switch context into wasm, then back to host:
        assert_eq!(state.context, vec![Context::Wasm, Context::Host]);

        assert_eq!(state.calls_into_host, state.returns_from_host + 1);
        assert_eq!(state.calls_into_wasm, state.returns_from_wasm + 1);
    }

    // Checks made after `n` top-level calls have completed: each call enters
    // wasm and then the host and returns from both, so every counter is `n`.
    #[track_caller]
    fn assert_all_counters(state: &State, n: usize) {
        assert_eq!(state.calls_into_host, n);
        assert_eq!(state.returns_from_host, n);
        assert_eq!(state.calls_into_wasm, n);
        assert_eq!(state.returns_from_wasm, n);
    }

    let mut funcs = Vec::new();
    funcs.push(Func::wrap(
        &mut store,
        |caller: Caller<State>, a: i32, b: i64, c: f32, d: f64| {
            verify(caller.data());

            assert_eq!(a, 1);
            assert_eq!(b, 2);
            assert_eq!(c, 3.0);
            assert_eq!(d, 4.0);
        },
    ));
    let func_ty = FuncType::new(
        store.engine(),
        [ValType::I32, ValType::I64, ValType::F32, ValType::F64],
        [],
    );
    funcs.push(Func::new(
        &mut store,
        func_ty,
        |caller: Caller<State>, params, results| {
            verify(caller.data());

            assert_eq!(params.len(), 4);
            assert_eq!(params[0].i32().unwrap(), 1);
            assert_eq!(params[1].i64().unwrap(), 2);
            assert_eq!(params[2].f32().unwrap(), 3.0);
            assert_eq!(params[3].f64().unwrap(), 4.0);
            assert_eq!(results.len(), 0);
            Ok(())
        },
    ));
    let func_ty = FuncType::new(
        store.engine(),
        [ValType::I32, ValType::I64, ValType::F32, ValType::F64],
        [],
    );
    // SAFETY: `func_ty` declares exactly four parameters (i32, i64, f32, f64)
    // and no results. The closure only reads `space[0..4]`, each with the
    // accessor matching that slot's declared type, and never writes a result.
    funcs.push(unsafe {
        Func::new_unchecked(&mut store, func_ty, |caller: Caller<State>, space| {
            verify(caller.data());

            assert_eq!(space[0].assume_init_ref().get_i32(), 1i32);
            assert_eq!(space[1].assume_init_ref().get_i64(), 2i64);
            assert_eq!(space[2].assume_init_ref().get_f32(), 3.0f32.to_bits());
            assert_eq!(space[3].assume_init_ref().get_f64(), 4.0f64.to_bits());
            Ok(())
        })
    });

    let mut n = 0;
    for f in &funcs {
        // Untyped call through `Val`s.
        f.call(
            &mut store,
            &[Val::I32(1), Val::I64(2), 3.0f32.into(), 4.0f64.into()],
            &mut [],
        )?;
        n += 1;
        assert_all_counters(store.data(), n);

        // Statically typed call.
        f.typed::<(i32, i64, f32, f64), ()>(&store)?
            .call(&mut store, (1, 2, 3.0, 4.0))?;
        n += 1;
        assert_all_counters(store.data(), n);

        // Raw call through `ValRaw`s.
        // SAFETY: every function in `funcs` takes (i32, i64, f32, f64) and
        // returns nothing. `args` holds one raw value per parameter, in that
        // order and with those types, so it is large enough for both the
        // parameters and the (empty) results.
        unsafe {
            let mut args = [
                Val::I32(1).to_raw(&mut store)?,
                Val::I64(2).to_raw(&mut store)?,
                Val::F32(3.0f32.to_bits()).to_raw(&mut store)?,
                Val::F64(4.0f64.to_bits()).to_raw(&mut store)?,
            ];
            f.call_unchecked(&mut store, &mut args)?;
        }
        n += 1;
        assert_all_counters(store.data(), n);
    }

    Ok(())
}
115 
116 // Create an async Func, call it directly:
117 #[tokio::test]
call_wrapped_async_func() -> Result<(), Error>118 async fn call_wrapped_async_func() -> Result<(), Error> {
119     let engine = Engine::default();
120     let mut store = Store::new(&engine, State::default());
121     store.call_hook(sync_call_hook);
122     let f = Func::wrap_async(
123         &mut store,
124         |caller: Caller<State>, (a, b, c, d): (i32, i64, f32, f64)| {
125             Box::new(async move {
126                 // Calling this func will switch context into wasm, then back to host:
127                 assert_eq!(caller.data().context, vec![Context::Wasm, Context::Host]);
128 
129                 assert_eq!(
130                     caller.data().calls_into_host,
131                     caller.data().returns_from_host + 1
132                 );
133                 assert_eq!(
134                     caller.data().calls_into_wasm,
135                     caller.data().returns_from_wasm + 1
136                 );
137 
138                 assert_eq!(a, 1);
139                 assert_eq!(b, 2);
140                 assert_eq!(c, 3.0);
141                 assert_eq!(d, 4.0);
142             })
143         },
144     );
145 
146     f.call_async(
147         &mut store,
148         &[Val::I32(1), Val::I64(2), 3.0f32.into(), 4.0f64.into()],
149         &mut [],
150     )
151     .await?;
152 
153     // One switch from vm to host to call f, another in return from f.
154     assert_eq!(store.data().calls_into_host, 1);
155     assert_eq!(store.data().returns_from_host, 1);
156     assert_eq!(store.data().calls_into_wasm, 1);
157     assert_eq!(store.data().returns_from_wasm, 1);
158 
159     f.typed::<(i32, i64, f32, f64), ()>(&store)?
160         .call_async(&mut store, (1, 2, 3.0, 4.0))
161         .await?;
162 
163     assert_eq!(store.data().calls_into_host, 2);
164     assert_eq!(store.data().returns_from_host, 2);
165     assert_eq!(store.data().calls_into_wasm, 2);
166     assert_eq!(store.data().returns_from_wasm, 2);
167 
168     Ok(())
169 }
170 
// Use the Linker to define a sync func, call it through WebAssembly:
#[test]
fn call_linked_func() -> Result<(), Error> {
    // Every counter must equal `expected` after `expected` calls of `export`.
    fn expect_counts(state: &State, expected: usize) {
        assert_eq!(state.calls_into_host, expected);
        assert_eq!(state.returns_from_host, expected);
        assert_eq!(state.calls_into_wasm, expected);
        assert_eq!(state.returns_from_wasm, expected);
    }

    let engine = Engine::default();
    let mut store = Store::new(&engine, State::default());
    store.call_hook(sync_call_hook);

    let mut linker = Linker::new(&engine);
    linker.func_wrap(
        "host",
        "f",
        |caller: Caller<State>, a: i32, b: i64, c: f32, d: f64| {
            let state = caller.data();
            // Calling this func will switch context into wasm, then back to host:
            assert_eq!(state.context, vec![Context::Wasm, Context::Host]);
            // Both crossings are entered but not yet returned from.
            assert_eq!(state.calls_into_host, state.returns_from_host + 1);
            assert_eq!(state.calls_into_wasm, state.returns_from_wasm + 1);

            assert_eq!((a, b, c, d), (1, 2, 3.0, 4.0));
        },
    )?;

    let wat = r#"
        (module
            (import "host" "f"
                (func $f (param i32) (param i64) (param f32) (param f64)))
            (func (export "export")
                (call $f (i32.const 1) (i64.const 2) (f32.const 3.0) (f64.const 4.0)))
        )
    "#;
    let module = Module::new(&engine, wat)?;

    let inst = linker.instantiate(&mut store, &module)?;
    let export = inst
        .get_export(&mut store, "export")
        .expect("get export");
    let export = export.into_func().expect("export is func");

    // Untyped call, then a statically typed call of the same export.
    export.call(&mut store, &[], &mut [])?;
    expect_counts(store.data(), 1);

    let typed = export.typed::<(), ()>(&store)?;
    typed.call(&mut store, ())?;
    expect_counts(store.data(), 2);

    Ok(())
}
236 
237 // Use the Linker to define an async func, call it through WebAssembly:
238 #[tokio::test]
call_linked_func_async() -> Result<(), Error>239 async fn call_linked_func_async() -> Result<(), Error> {
240     let engine = Engine::default();
241     let mut store = Store::new(&engine, State::default());
242     store.call_hook(sync_call_hook);
243 
244     let f = Func::wrap_async(
245         &mut store,
246         |caller: Caller<State>, (a, b, c, d): (i32, i64, f32, f64)| {
247             Box::new(async move {
248                 // Calling this func will switch context into wasm, then back to host:
249                 assert_eq!(caller.data().context, vec![Context::Wasm, Context::Host]);
250 
251                 assert_eq!(
252                     caller.data().calls_into_host,
253                     caller.data().returns_from_host + 1
254                 );
255                 assert_eq!(
256                     caller.data().calls_into_wasm,
257                     caller.data().returns_from_wasm + 1
258                 );
259                 assert_eq!(a, 1);
260                 assert_eq!(b, 2);
261                 assert_eq!(c, 3.0);
262                 assert_eq!(d, 4.0);
263             })
264         },
265     );
266 
267     let mut linker = Linker::new(&engine);
268 
269     linker.define(&mut store, "host", "f", f)?;
270 
271     let wat = r#"
272         (module
273             (import "host" "f"
274                 (func $f (param i32) (param i64) (param f32) (param f64)))
275             (func (export "export")
276                 (call $f (i32.const 1) (i64.const 2) (f32.const 3.0) (f64.const 4.0)))
277         )
278     "#;
279     let module = Module::new(&engine, wat)?;
280 
281     let inst = linker.instantiate_async(&mut store, &module).await?;
282     let export = inst
283         .get_export(&mut store, "export")
284         .expect("get export")
285         .into_func()
286         .expect("export is func");
287 
288     export.call_async(&mut store, &[], &mut []).await?;
289 
290     // One switch from vm to host to call f, another in return from f.
291     assert_eq!(store.data().calls_into_host, 1);
292     assert_eq!(store.data().returns_from_host, 1);
293     assert_eq!(store.data().calls_into_wasm, 1);
294     assert_eq!(store.data().returns_from_wasm, 1);
295 
296     export
297         .typed::<(), ()>(&store)?
298         .call_async(&mut store, ())
299         .await?;
300 
301     assert_eq!(store.data().calls_into_host, 2);
302     assert_eq!(store.data().returns_from_host, 2);
303     assert_eq!(store.data().calls_into_wasm, 2);
304     assert_eq!(store.data().returns_from_wasm, 2);
305 
306     Ok(())
307 }
308 
#[test]
fn instantiate() -> Result<(), Error> {
    let mut store = Store::<State>::default();
    store.call_hook(sync_call_hook);

    // Each case is (module text, cumulative calls into wasm afterwards). An
    // empty module runs no code, while a start function enters wasm once.
    // Neither instantiation should call into the host.
    let cases = [("(module)", 0), ("(module (func) (start 0))", 1)];
    for (text, wasm_calls) in cases {
        let module = Module::new(store.engine(), text)?;
        Instance::new(&mut store, &module, &[])?;
        assert_eq!(store.data().calls_into_wasm, wasm_calls);
        assert_eq!(store.data().calls_into_host, 0);
    }

    Ok(())
}
326 
327 #[tokio::test]
instantiate_async() -> Result<(), Error>328 async fn instantiate_async() -> Result<(), Error> {
329     let engine = Engine::default();
330     let mut store = Store::new(&engine, State::default());
331     store.call_hook(sync_call_hook);
332 
333     let m = Module::new(store.engine(), "(module)")?;
334     Instance::new_async(&mut store, &m, &[]).await?;
335     assert_eq!(store.data().calls_into_wasm, 0);
336     assert_eq!(store.data().calls_into_host, 0);
337 
338     let m = Module::new(store.engine(), "(module (func) (start 0))")?;
339     Instance::new_async(&mut store, &m, &[]).await?;
340     assert_eq!(store.data().calls_into_wasm, 1);
341     assert_eq!(store.data().calls_into_host, 0);
342 
343     Ok(())
344 }
345 
#[test]
fn recursion() -> Result<(), Error> {
    // Make sure call hook behaves reasonably when called recursively

    // Every counter must equal `expected`.
    fn expect_counts(state: &State, expected: usize) {
        assert_eq!(state.calls_into_host, expected);
        assert_eq!(state.returns_from_host, expected);
        assert_eq!(state.calls_into_wasm, expected);
        assert_eq!(state.returns_from_wasm, expected);
    }

    let engine = Engine::default();
    let mut store = Store::new(&engine, State::default());
    store.call_hook(sync_call_hook);
    let mut linker = Linker::new(&engine);

    // `f(n)` re-enters wasm through the instance's own export with `n - 1`
    // until `n` reaches 0, so `export(n)` makes `n + 1` wasm/host round trips.
    linker.func_wrap("host", "f", |mut caller: Caller<State>, n: i32| {
        let state = caller.data();
        assert_eq!(state.context.last(), Some(&Context::Host));
        assert_eq!(state.calls_into_host, state.calls_into_wasm);

        if n <= 0 {
            return;
        }
        let export = caller
            .get_export("export")
            .expect("caller exports \"export\"")
            .into_func()
            .expect("export is a func")
            .typed::<i32, ()>(&caller)
            .expect("export typing");
        export.call(&mut caller, n - 1).unwrap();
    })?;

    let wat = r#"
        (module
            (import "host" "f"
                (func $f (param i32)))
            (func (export "export") (param i32)
                (call $f (local.get 0)))
        )
    "#;
    let module = Module::new(&engine, wat)?;

    let inst = linker.instantiate(&mut store, &module)?;
    let export = inst
        .get_export(&mut store, "export")
        .expect("get export");
    let export = export.into_func().expect("export is func");

    // Recursion depth:
    let depth: usize = 10;
    let arg = depth as i32;
    // Recurse down to 0: depth + 1 calls
    let round_trips = depth + 1;

    export.call(&mut store, &[Val::I32(arg)], &mut [])?;
    expect_counts(store.data(), round_trips);

    let typed = export.typed::<i32, ()>(&store)?;
    typed.call(&mut store, arg)?;
    expect_counts(store.data(), 2 * round_trips);

    Ok(())
}
413 
#[test]
fn trapping() -> Result<(), Error> {
    // Values of the `action` argument handed to the host function `f`.
    const TRAP_IN_F: i32 = 0;
    const TRAP_NEXT_CALL_HOST: i32 = 1;
    const TRAP_NEXT_RETURN_HOST: i32 = 2;
    const TRAP_NEXT_CALL_WASM: i32 = 3;
    const TRAP_NEXT_RETURN_WASM: i32 = 4;

    // Asserts the transition counters of `s`, given in the order
    // (calls_into_host, returns_from_host, calls_into_wasm, returns_from_wasm).
    #[track_caller]
    fn assert_counts(s: &State, expected: (usize, usize, usize, usize)) {
        assert_eq!(
            (
                s.calls_into_host,
                s.returns_from_host,
                s.calls_into_wasm,
                s.returns_from_wasm,
            ),
            expected
        );
    }

    // Asserts that the call failed and that its error mentions `needle`.
    #[track_caller]
    fn assert_error_contains(e: Option<Error>, needle: &str) {
        let e = e.expect("expected the call to fail");
        let msg = format!("{:?}", e);
        assert!(
            msg.contains(needle),
            "error {:?} does not mention {:?}",
            msg,
            needle
        );
    }

    let engine = Engine::default();

    let mut linker = Linker::new(&engine);

    linker.func_wrap(
        "host",
        "f",
        |mut caller: Caller<State>, action: i32, recur: i32| -> Result<()> {
            assert_eq!(caller.data().context.last(), Some(&Context::Host));
            assert_eq!(caller.data().calls_into_host, caller.data().calls_into_wasm);

            match action {
                TRAP_IN_F => bail!("trapping in f"),
                TRAP_NEXT_CALL_HOST => caller.data_mut().trap_next_call_host = true,
                TRAP_NEXT_RETURN_HOST => caller.data_mut().trap_next_return_host = true,
                TRAP_NEXT_CALL_WASM => caller.data_mut().trap_next_call_wasm = true,
                TRAP_NEXT_RETURN_WASM => caller.data_mut().trap_next_return_wasm = true,
                _ => {} // Do nothing
            }

            // Recurse once (with `recur` = 0) so that a flag set above gets a
            // later transition to fire on; propagate its error, if any.
            if recur > 0 {
                caller
                    .get_export("export")
                    .expect("caller exports \"export\"")
                    .into_func()
                    .expect("export is a func")
                    .typed::<(i32, i32), ()>(&caller)
                    .expect("export typing")
                    .call(&mut caller, (action, 0))?;
            }

            Ok(())
        },
    )?;

    let wat = r#"
        (module
            (import "host" "f"
                (func $f (param i32) (param i32)))
            (func (export "export") (param i32) (param i32)
                (call $f (local.get 0) (local.get 1)))
        )
    "#;
    let module = Module::new(&engine, wat)?;

    // Calls `export(action, recur)` in a fresh store and returns the final
    // hook state together with the call's error, if it failed.
    let run = |action: i32, recur: bool| -> (State, Option<Error>) {
        let mut store = Store::new(&engine, State::default());
        store.call_hook(sync_call_hook);
        let inst = linker
            .instantiate(&mut store, &module)
            .expect("instantiate");
        let export = inst
            .get_export(&mut store, "export")
            .expect("get export")
            .into_func()
            .expect("export is func");

        let r = export.call(
            &mut store,
            &[Val::I32(action), Val::I32(i32::from(recur))],
            &mut [],
        );
        (store.into_data(), r.err())
    };

    let (s, e) = run(TRAP_IN_F, false);
    assert_error_contains(e, "trapping in f");
    assert_counts(&s, (1, 1, 1, 1));

    // trap in next call to host. No calls after the bit is set, so this trap shouldn't happen
    let (s, e) = run(TRAP_NEXT_CALL_HOST, false);
    assert!(e.is_none(), "unexpected error: {:?}", e);
    assert_counts(&s, (1, 1, 1, 1));

    // trap in next call to host. recur, so the second call into host traps:
    let (s, e) = run(TRAP_NEXT_CALL_HOST, true);
    assert_error_contains(e, "call_hook: trapping on CallingHost");
    assert_counts(&s, (2, 1, 2, 2));

    // trap in the return from host. should trap right away, without recursion
    let (s, e) = run(TRAP_NEXT_RETURN_HOST, false);
    assert_error_contains(e, "call_hook: trapping on ReturningFromHost");
    assert_counts(&s, (1, 1, 1, 1));

    // trap in next call to wasm. No calls after the bit is set, so this trap shouldn't happen:
    let (s, e) = run(TRAP_NEXT_CALL_WASM, false);
    assert!(e.is_none(), "unexpected error: {:?}", e);
    assert_counts(&s, (1, 1, 1, 1));

    // trap in next call to wasm. recur, so the second call into wasm traps:
    let (s, e) = run(TRAP_NEXT_CALL_WASM, true);
    assert_error_contains(e, "call_hook: trapping on CallingWasm");
    assert_counts(&s, (1, 1, 2, 1));

    // trap in the return from wasm. should trap right away, without recursion
    let (s, e) = run(TRAP_NEXT_RETURN_WASM, false);
    assert_error_contains(e, "call_hook: trapping on ReturningFromWasm");
    assert_counts(&s, (1, 1, 1, 1));

    Ok(())
}
546 
547 #[tokio::test]
basic_async_hook() -> Result<(), Error>548 async fn basic_async_hook() -> Result<(), Error> {
549     struct HandlerR;
550 
551     #[async_trait::async_trait]
552     impl CallHookHandler<State> for HandlerR {
553         async fn handle_call_event(
554             &self,
555             ctx: StoreContextMut<'_, State>,
556             ch: CallHook,
557         ) -> Result<()> {
558             sync_call_hook(ctx, ch)
559         }
560     }
561     let engine = Engine::default();
562     let mut store = Store::new(&engine, State::default());
563     store.call_hook_async(HandlerR {});
564 
565     assert_eq!(store.data().calls_into_host, 0);
566     assert_eq!(store.data().returns_from_host, 0);
567     assert_eq!(store.data().calls_into_wasm, 0);
568     assert_eq!(store.data().returns_from_wasm, 0);
569 
570     let mut linker = Linker::new(&engine);
571 
572     linker.func_wrap(
573         "host",
574         "f",
575         |caller: Caller<State>, a: i32, b: i64, c: f32, d: f64| {
576             // Calling this func will switch context into wasm, then back to host:
577             assert_eq!(caller.data().context, vec![Context::Wasm, Context::Host]);
578 
579             assert_eq!(
580                 caller.data().calls_into_host,
581                 caller.data().returns_from_host + 1
582             );
583             assert_eq!(
584                 caller.data().calls_into_wasm,
585                 caller.data().returns_from_wasm + 1
586             );
587 
588             assert_eq!(a, 1);
589             assert_eq!(b, 2);
590             assert_eq!(c, 3.0);
591             assert_eq!(d, 4.0);
592         },
593     )?;
594 
595     let wat = r#"
596         (module
597             (import "host" "f"
598                 (func $f (param i32) (param i64) (param f32) (param f64)))
599             (func (export "export")
600                 (call $f (i32.const 1) (i64.const 2) (f32.const 3.0) (f64.const 4.0)))
601         )
602     "#;
603     let module = Module::new(&engine, wat)?;
604 
605     let inst = linker.instantiate_async(&mut store, &module).await?;
606     let export = inst
607         .get_export(&mut store, "export")
608         .expect("get export")
609         .into_func()
610         .expect("export is func");
611 
612     export.call_async(&mut store, &[], &mut []).await?;
613 
614     // One switch from vm to host to call f, another in return from f.
615     assert_eq!(store.data().calls_into_host, 1);
616     assert_eq!(store.data().returns_from_host, 1);
617     assert_eq!(store.data().calls_into_wasm, 1);
618     assert_eq!(store.data().returns_from_wasm, 1);
619 
620     Ok(())
621 }
622 
623 #[tokio::test]
timeout_async_hook() -> Result<(), Error>624 async fn timeout_async_hook() -> Result<(), Error> {
625     struct HandlerR;
626 
627     #[async_trait::async_trait]
628     impl CallHookHandler<State> for HandlerR {
629         async fn handle_call_event(
630             &self,
631             mut ctx: StoreContextMut<'_, State>,
632             ch: CallHook,
633         ) -> Result<()> {
634             let obj = ctx.data_mut();
635             if obj.calls_into_host > 200 {
636                 bail!("timeout");
637             }
638 
639             match ch {
640                 CallHook::CallingHost => obj.calls_into_host += 1,
641                 CallHook::CallingWasm => obj.calls_into_wasm += 1,
642                 CallHook::ReturningFromHost => obj.returns_from_host += 1,
643                 CallHook::ReturningFromWasm => obj.returns_from_wasm += 1,
644             }
645 
646             Ok(())
647         }
648     }
649 
650     let engine = Engine::default();
651     let mut store = Store::new(&engine, State::default());
652     store.call_hook_async(HandlerR {});
653 
654     assert_eq!(store.data().calls_into_host, 0);
655     assert_eq!(store.data().returns_from_host, 0);
656     assert_eq!(store.data().calls_into_wasm, 0);
657     assert_eq!(store.data().returns_from_wasm, 0);
658 
659     let mut linker = Linker::new(&engine);
660 
661     linker.func_wrap(
662         "host",
663         "f",
664         |_caller: Caller<State>, a: i32, b: i64, c: f32, d: f64| {
665             assert_eq!(a, 1);
666             assert_eq!(b, 2);
667             assert_eq!(c, 3.0);
668             assert_eq!(d, 4.0);
669         },
670     )?;
671 
672     let wat = r#"
673         (module
674             (import "host" "f"
675                 (func $f (param i32) (param i64) (param f32) (param f64)))
676             (func (export "export")
677                 (loop $start
678                     (call $f (i32.const 1) (i64.const 2) (f32.const 3.0) (f64.const 4.0))
679                     (br $start)))
680         )
681     "#;
682     let module = Module::new(&engine, wat)?;
683 
684     let inst = linker.instantiate_async(&mut store, &module).await?;
685     let export = inst
686         .get_typed_func::<(), ()>(&mut store, "export")
687         .expect("export is func");
688 
689     store.set_epoch_deadline(1);
690     store.epoch_deadline_async_yield_and_update(1);
691     assert!(export.call_async(&mut store, ()).await.is_err());
692 
693     // One switch from vm to host to call f, another in return from f.
694     assert!(store.data().calls_into_host > 1);
695     assert!(store.data().returns_from_host > 1);
696     assert_eq!(store.data().calls_into_wasm, 1);
697     assert_eq!(store.data().returns_from_wasm, 0);
698 
699     Ok(())
700 }
701 
702 #[tokio::test]
drop_suspended_async_hook() -> Result<(), Error>703 async fn drop_suspended_async_hook() -> Result<(), Error> {
704     struct Handler;
705 
706     #[async_trait::async_trait]
707     impl CallHookHandler<u32> for Handler {
708         async fn handle_call_event(
709             &self,
710             mut ctx: StoreContextMut<'_, u32>,
711             _ch: CallHook,
712         ) -> Result<()> {
713             let state = ctx.data_mut();
714             assert_eq!(*state, 0);
715             *state += 1;
716             let _dec = Decrement(state);
717 
718             // Simulate some sort of event which takes a number of yields
719             for _ in 0..500 {
720                 tokio::task::yield_now().await;
721             }
722             Ok(())
723         }
724     }
725 
726     let engine = Engine::default();
727     let mut store = Store::new(&engine, 0);
728     store.call_hook_async(Handler);
729 
730     let mut linker = Linker::new(&engine);
731 
732     // Simulate a host function that has lots of yields with an infinite loop.
733     linker.func_wrap_async("host", "f", |mut cx, _: ()| {
734         Box::new(async move {
735             let state = cx.data_mut();
736             assert_eq!(*state, 0);
737             *state += 1;
738             let _dec = Decrement(state);
739             for _ in 0.. {
740                 tokio::task::yield_now().await;
741             }
742         })
743     })?;
744 
745     let wat = r#"
746         (module
747             (import "host" "f" (func $f))
748             (func (export "") call $f)
749         )
750     "#;
751     let module = Module::new(&engine, wat)?;
752 
753     let inst = linker.instantiate_async(&mut store, &module).await?;
754     assert_eq!(*store.data(), 0);
755     let export = inst
756         .get_typed_func::<(), ()>(&mut store, "")
757         .expect("export is func");
758 
759     // First test that if we drop in the middle of an async hook that everything
760     // is alright.
761     PollNTimes {
762         future: Box::pin(export.call_async(&mut store, ())),
763         times: 200,
764     }
765     .await;
766     assert_eq!(*store.data(), 0); // double-check user dtors ran
767 
768     // Next test that if we drop while in a host async function that everything
769     // is also alright.
770     PollNTimes {
771         future: Box::pin(export.call_async(&mut store, ())),
772         times: 1_000,
773     }
774     .await;
775     assert_eq!(*store.data(), 0); // double-check user dtors ran
776 
777     return Ok(());
778 
779     // A helper struct to poll an inner `future` N `times` and then resolve.
780     // This is used above to test that when futures are dropped while they're
781     // pending everything works and is cleaned up on the Wasmtime side of
782     // things.
783     struct PollNTimes<F> {
784         future: F,
785         times: u32,
786     }
787 
788     impl<F: Future + Unpin> Future for PollNTimes<F> {
789         type Output = ();
790         fn poll(mut self: Pin<&mut Self>, task: &mut task::Context<'_>) -> Poll<()> {
791             for _ in 0..self.times {
792                 match Pin::new(&mut self.future).poll(task) {
793                     Poll::Ready(_) => panic!("future should not be ready"),
794                     Poll::Pending => {}
795                 }
796             }
797 
798             Poll::Ready(())
799         }
800     }
801 
802     // helper struct to decrement a counter on drop
803     struct Decrement<'a>(&'a mut u32);
804 
805     impl Drop for Decrement<'_> {
806         fn drop(&mut self) {
807             *self.0 -= 1;
808         }
809     }
810 }
811 
/// Which side of the wasm/host boundary execution has entered, as pushed
/// onto and popped off `State::context` by the call hook.
#[derive(PartialEq, Eq, Debug)]
pub enum Context {
    /// A host function was entered.
    Host,
    /// WebAssembly code was entered.
    Wasm,
}
817 
818 pub struct State {
819     pub context: Vec<Context>,
820 
821     pub calls_into_host: usize,
822     pub returns_from_host: usize,
823     pub calls_into_wasm: usize,
824     pub returns_from_wasm: usize,
825 
826     pub trap_next_call_host: bool,
827     pub trap_next_return_host: bool,
828     pub trap_next_call_wasm: bool,
829     pub trap_next_return_wasm: bool,
830 }
831 
832 impl Default for State {
default() -> Self833     fn default() -> Self {
834         State {
835             context: Vec::new(),
836             calls_into_host: 0,
837             returns_from_host: 0,
838             calls_into_wasm: 0,
839             returns_from_wasm: 0,
840             trap_next_call_host: false,
841             trap_next_return_host: false,
842             trap_next_call_wasm: false,
843             trap_next_return_wasm: false,
844         }
845     }
846 }
847 
848 impl State {
849     // This implementation asserts that hooks are always called in a stack-like manner.
call_hook(&mut self, s: CallHook) -> Result<()>850     fn call_hook(&mut self, s: CallHook) -> Result<()> {
851         match s {
852             CallHook::CallingHost => {
853                 self.calls_into_host += 1;
854                 if self.trap_next_call_host {
855                     bail!("call_hook: trapping on CallingHost");
856                 } else {
857                     self.context.push(Context::Host);
858                 }
859             }
860             CallHook::ReturningFromHost => match self.context.pop() {
861                 Some(Context::Host) => {
862                     self.returns_from_host += 1;
863                     if self.trap_next_return_host {
864                         bail!("call_hook: trapping on ReturningFromHost");
865                     }
866                 }
867                 c => panic!(
868                     "illegal context: expected Some(Host), got {:?}. remaining: {:?}",
869                     c, self.context
870                 ),
871             },
872             CallHook::CallingWasm => {
873                 self.calls_into_wasm += 1;
874                 if self.trap_next_call_wasm {
875                     bail!("call_hook: trapping on CallingWasm");
876                 } else {
877                     self.context.push(Context::Wasm);
878                 }
879             }
880             CallHook::ReturningFromWasm => match self.context.pop() {
881                 Some(Context::Wasm) => {
882                     self.returns_from_wasm += 1;
883                     if self.trap_next_return_wasm {
884                         bail!("call_hook: trapping on ReturningFromWasm");
885                     }
886                 }
887                 c => panic!(
888                     "illegal context: expected Some(Wasm), got {:?}. remaining: {:?}",
889                     c, self.context
890                 ),
891             },
892         }
893         Ok(())
894     }
895 }
896 
sync_call_hook(mut ctx: StoreContextMut<'_, State>, transition: CallHook) -> Result<()>897 pub fn sync_call_hook(mut ctx: StoreContextMut<'_, State>, transition: CallHook) -> Result<()> {
898     ctx.data_mut().call_hook(transition)
899 }
900