1 //! This module generates test cases for the Wasmtime component model function APIs,
2 //! e.g. `wasmtime::component::func::Func` and `TypedFunc`.
3 //!
4 //! Each case includes a list of arbitrary interface types to use as parameters, plus another one to use as a
5 //! result, and a component which exports a function and imports a function.  The exported function forwards its
6 //! parameters to the imported one and forwards the result back to the caller.  This serves to exercise Wasmtime's
7 //! lifting and lowering code and verify the values remain intact during both processes.
8 
9 use arbitrary::{Arbitrary, Unstructured};
10 use indexmap::IndexSet;
11 use proc_macro2::{Ident, TokenStream};
12 use quote::{ToTokens, format_ident, quote};
13 use std::borrow::Cow;
14 use std::fmt::{self, Debug, Write};
15 use std::hash::{Hash, Hasher};
16 use std::iter;
17 use std::ops::Deref;
18 use wasmtime_component_util::{DiscriminantSize, FlagsSize, REALLOC_AND_FREE};
19 
/// Limit on the number of flattened core wasm parameters for an import that
/// is lowered synchronously; handed to `push_params` when building the import
/// type (mirrors the canonical ABI's `MAX_FLAT_PARAMS`).
const MAX_FLAT_PARAMS: usize = 16;
/// Same as `MAX_FLAT_PARAMS` but for an import that is lowered asynchronously
/// (mirrors the canonical ABI's `MAX_FLAT_ASYNC_PARAMS`).
const MAX_FLAT_ASYNC_PARAMS: usize = 4;
/// Limit on the number of flattened core wasm results; handed to
/// `push_result_or_retptr` for synchronous lowering (mirrors the canonical
/// ABI's `MAX_FLAT_RESULTS`).
const MAX_FLAT_RESULTS: usize = 1;

/// The name of the imported host function which the generated component will call
pub const IMPORT_FUNCTION: &str = "echo-import";

/// The name of the exported guest function which the host should call
pub const EXPORT_FUNCTION: &str = "echo-export";

/// Wasmtime allows up to 100 type depth so limit this to just under that.
pub const MAX_TYPE_DEPTH: u32 = 99;
32 
/// `writeln!` that unwraps the resulting `fmt::Result`.
///
/// Every call site in this module writes into a `String`, so this keeps the
/// code generators free of `?`/`unwrap` noise.
macro_rules! uwriteln {
    ($($tokens:tt)*) => {
        ::std::result::Result::unwrap(writeln!($($tokens)*))
    };
}

/// `write!` that unwraps the resulting `fmt::Result`; see `uwriteln!`.
macro_rules! uwrite {
    ($($tokens:tt)*) => {
        ::std::result::Result::unwrap(write!($($tokens)*))
    };
}
44 
/// A core wasm value type, as produced by flattening component types.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum CoreType {
    I32,
    I64,
    F32,
    F64,
}

impl CoreType {
    /// Combines two flat types that share a slot in a flattened variant.
    ///
    /// This is the `join` operation specified in [the canonical
    /// ABI](https://github.com/WebAssembly/component-model/blob/main/design/mvp/CanonicalABI.md#flattening):
    /// identical types stay as-is, `i32`/`f32` meet at `i32`, and every other
    /// mix widens to `i64`.
    fn join(self, other: Self) -> Self {
        if self == other {
            return self;
        }
        match (self, other) {
            (Self::I32, Self::F32) | (Self::F32, Self::I32) => Self::I32,
            _ => Self::I64,
        }
    }
}

impl fmt::Display for CoreType {
    /// Prints the type using its wasm text-format keyword.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            Self::I32 => "i32",
            Self::I64 => "i64",
            Self::F32 => "f32",
            Self::F64 => "f64",
        };
        f.write_str(keyword)
    }
}
76 
77 /// Wraps a `Box<[T]>` and provides an `Arbitrary` implementation that always generates slices of length less than
78 /// or equal to the longest tuple for which Wasmtime generates a `ComponentType` impl
79 #[derive(Debug, Clone)]
80 pub struct VecInRange<T, const L: u32, const H: u32>(Vec<T>);
81 
82 impl<T, const L: u32, const H: u32> VecInRange<T, L, H> {
83     fn new<'a>(
84         input: &mut Unstructured<'a>,
85         fuel: &mut u32,
86         generate: impl Fn(&mut Unstructured<'a>, &mut u32) -> arbitrary::Result<T>,
87     ) -> arbitrary::Result<Self> {
88         let mut ret = Vec::new();
89         input.arbitrary_loop(Some(L), Some(H), |input| {
90             if *fuel > 0 {
91                 *fuel = *fuel - 1;
92                 ret.push(generate(input, fuel)?);
93                 Ok(std::ops::ControlFlow::Continue(()))
94             } else {
95                 Ok(std::ops::ControlFlow::Break(()))
96             }
97         })?;
98         Ok(Self(ret))
99     }
100 }
101 
102 impl<T, const L: u32, const H: u32> Deref for VecInRange<T, L, H> {
103     type Target = [T];
104 
105     fn deref(&self) -> &[T] {
106         self.0.deref()
107     }
108 }
109 
/// Represents a component model interface type
///
/// Arbitrary instances are produced by `Type::generate`.
#[expect(missing_docs, reason = "self-describing")]
#[derive(Debug, Clone)]
pub enum Type {
    // Leaf types: primitives, `char`, and `string`.
    Bool,
    S8,
    U8,
    S16,
    U16,
    S32,
    U32,
    S64,
    U64,
    Float32,
    Float64,
    Char,
    String,
    // `list<T>` of the boxed element type.
    List(Box<Type>),

    // Give records the ability to generate a generous amount of fields but
    // don't let the fuzzer go too wild since `wasmparser`'s validator currently
    // has hard limits in the 1000-ish range on the number of fields a record
    // may contain.
    Record(VecInRange<Type, 1, 200>),

    // Tuples can only have up to 16 type parameters in wasmtime right now for
    // the static API, but the standard library only supports `Debug` up to 11
    // elements, so compromise at an even 10.
    Tuple(VecInRange<Type, 1, 10>),

    // Like records, allow a good number of variants, but variants require at
    // least one case. Each case optionally carries a payload type.
    Variant(VecInRange<Option<Type>, 1, 200>),
    // An enum with the given number of (payload-less) cases.
    Enum(u32),

    // `option<T>` of the boxed payload type.
    Option(Box<Type>),
    // `result<ok, err>`; either side may have no payload.
    Result {
        ok: Option<Box<Type>>,
        err: Option<Box<Type>>,
    },

    // A flags type with the given number of flags.
    Flags(u32),
}
153 
154 impl Type {
155     pub fn generate(
156         u: &mut Unstructured<'_>,
157         depth: u32,
158         fuel: &mut u32,
159     ) -> arbitrary::Result<Type> {
160         *fuel = fuel.saturating_sub(1);
161         let max = if depth == 0 || *fuel == 0 { 12 } else { 20 };
162         Ok(match u.int_in_range(0..=max)? {
163             0 => Type::Bool,
164             1 => Type::S8,
165             2 => Type::U8,
166             3 => Type::S16,
167             4 => Type::U16,
168             5 => Type::S32,
169             6 => Type::U32,
170             7 => Type::S64,
171             8 => Type::U64,
172             9 => Type::Float32,
173             10 => Type::Float64,
174             11 => Type::Char,
175             12 => Type::String,
176             // ^-- if you add something here update the `depth == 0` case above
177             13 => Type::List(Box::new(Type::generate(u, depth - 1, fuel)?)),
178             14 => Type::Record(Type::generate_list(u, depth - 1, fuel)?),
179             15 => Type::Tuple(Type::generate_list(u, depth - 1, fuel)?),
180             16 => Type::Variant(VecInRange::new(u, fuel, |u, fuel| {
181                 Type::generate_opt(u, depth - 1, fuel)
182             })?),
183             17 => {
184                 let amt = u.int_in_range(1..=(*fuel).max(1).min(257))?;
185                 *fuel -= amt;
186                 Type::Enum(amt)
187             }
188             18 => Type::Option(Box::new(Type::generate(u, depth - 1, fuel)?)),
189             19 => Type::Result {
190                 ok: Type::generate_opt(u, depth - 1, fuel)?.map(Box::new),
191                 err: Type::generate_opt(u, depth - 1, fuel)?.map(Box::new),
192             },
193             20 => {
194                 let amt = u.int_in_range(1..=(*fuel).min(32))?;
195                 *fuel -= amt;
196                 Type::Flags(amt)
197             }
198             // ^-- if you add something here update the `depth != 0` case above
199             _ => unreachable!(),
200         })
201     }
202 
203     fn generate_opt(
204         u: &mut Unstructured<'_>,
205         depth: u32,
206         fuel: &mut u32,
207     ) -> arbitrary::Result<Option<Type>> {
208         Ok(if u.arbitrary()? {
209             Some(Type::generate(u, depth, fuel)?)
210         } else {
211             None
212         })
213     }
214 
215     fn generate_list<const L: u32, const H: u32>(
216         u: &mut Unstructured<'_>,
217         depth: u32,
218         fuel: &mut u32,
219     ) -> arbitrary::Result<VecInRange<Type, L, H>> {
220         VecInRange::new(u, fuel, |u, fuel| Type::generate(u, depth, fuel))
221     }
222 
223     /// Generates text format wasm into `s` to store a value of this type, in
224     /// its flat representation stored in the `locals` provided, to the local
225     /// named `ptr` at the `offset` provided.
226     ///
227     /// This will register helper functions necessary in `helpers`. The
228     /// `locals` iterator will be advanced for all locals consumed by this
229     /// store operation.
230     fn store_flat<'a>(
231         &'a self,
232         s: &mut String,
233         ptr: &str,
234         offset: u32,
235         locals: &mut dyn Iterator<Item = FlatSource>,
236         helpers: &mut IndexSet<Helper<'a>>,
237     ) {
238         enum Kind {
239             Primitive(&'static str),
240             PointerPair,
241             Helper,
242         }
243         let kind = match self {
244             Type::Bool | Type::S8 | Type::U8 => Kind::Primitive("i32.store8"),
245             Type::S16 | Type::U16 => Kind::Primitive("i32.store16"),
246             Type::S32 | Type::U32 | Type::Char => Kind::Primitive("i32.store"),
247             Type::S64 | Type::U64 => Kind::Primitive("i64.store"),
248             Type::Float32 => Kind::Primitive("f32.store"),
249             Type::Float64 => Kind::Primitive("f64.store"),
250             Type::String | Type::List(_) => Kind::PointerPair,
251             Type::Enum(n) if *n <= (1 << 8) => Kind::Primitive("i32.store8"),
252             Type::Enum(n) if *n <= (1 << 16) => Kind::Primitive("i32.store16"),
253             Type::Enum(_) => Kind::Primitive("i32.store"),
254             Type::Flags(n) if *n <= 8 => Kind::Primitive("i32.store8"),
255             Type::Flags(n) if *n <= 16 => Kind::Primitive("i32.store16"),
256             Type::Flags(n) if *n <= 32 => Kind::Primitive("i32.store"),
257             Type::Flags(_) => unreachable!(),
258             Type::Record(_)
259             | Type::Tuple(_)
260             | Type::Variant(_)
261             | Type::Option(_)
262             | Type::Result { .. } => Kind::Helper,
263         };
264 
265         match kind {
266             Kind::Primitive(op) => uwriteln!(
267                 s,
268                 "({op} offset={offset} (local.get {ptr}) {})",
269                 locals.next().unwrap()
270             ),
271             Kind::PointerPair => {
272                 let abi_ptr = locals.next().unwrap();
273                 let abi_len = locals.next().unwrap();
274                 uwriteln!(s, "(i32.store offset={offset} (local.get {ptr}) {abi_ptr})",);
275                 let offset = offset + 4;
276                 uwriteln!(s, "(i32.store offset={offset} (local.get {ptr}) {abi_len})",);
277             }
278             Kind::Helper => {
279                 let (index, _) = helpers.insert_full(Helper(self));
280                 uwriteln!(s, "(i32.add (local.get {ptr}) (i32.const {offset}))");
281                 for _ in 0..self.lowered().len() {
282                     let i = locals.next().unwrap();
283                     uwriteln!(s, "{i}");
284                 }
285                 uwriteln!(s, "call $store_helper_{index}");
286             }
287         }
288     }
289 
290     /// Generates a text-format wasm function which takes a pointer and this
291     /// type's flat representation as arguments and then stores this value in
292     /// the first argument.
293     ///
294     /// This is used to store records/variants to cut down on the size of final
295     /// functions and make codegen here a bit easier.
296     fn store_flat_helper<'a>(
297         &'a self,
298         s: &mut String,
299         i: usize,
300         helpers: &mut IndexSet<Helper<'a>>,
301     ) {
302         uwrite!(s, "(func $store_helper_{i} (param i32)");
303         let lowered = self.lowered();
304         for ty in &lowered {
305             uwrite!(s, " (param {ty})");
306         }
307         s.push_str("\n");
308         let locals = (0..lowered.len() as u32).map(|i| i + 1).collect::<Vec<_>>();
309         let record = |s: &mut String, helpers: &mut IndexSet<Helper<'a>>, types: &'a [Type]| {
310             let mut locals = locals.iter().cloned().map(FlatSource::Local);
311             for (offset, ty) in record_field_offsets(types) {
312                 ty.store_flat(s, "0", offset, &mut locals, helpers);
313             }
314             assert!(locals.next().is_none());
315         };
316         let variant = |s: &mut String,
317                        helpers: &mut IndexSet<Helper<'a>>,
318                        types: &[Option<&'a Type>]| {
319             let (size, offset) = variant_memory_info(types.iter().cloned());
320             // One extra block for out-of-bounds discriminants.
321             for _ in 0..types.len() + 1 {
322                 s.push_str("block\n");
323             }
324 
325             // Store the discriminant in memory, then branch on it to figure
326             // out which case we're in.
327             let store = match size {
328                 DiscriminantSize::Size1 => "i32.store8",
329                 DiscriminantSize::Size2 => "i32.store16",
330                 DiscriminantSize::Size4 => "i32.store",
331             };
332             uwriteln!(s, "({store} (local.get 0) (local.get 1))");
333             s.push_str("local.get 1\n");
334             s.push_str("br_table");
335             for i in 0..types.len() + 1 {
336                 uwrite!(s, " {i}");
337             }
338             s.push_str("\nend\n");
339 
340             // Store each payload individually while converting locals from
341             // their source types to the precise type necessary for this
342             // variant.
343             for ty in types {
344                 if let Some(ty) = ty {
345                     let ty_lowered = ty.lowered();
346                     let mut locals = locals[1..].iter().zip(&lowered[1..]).zip(&ty_lowered).map(
347                         |((i, from), to)| FlatSource::LocalConvert {
348                             local: *i,
349                             from: *from,
350                             to: *to,
351                         },
352                     );
353                     ty.store_flat(s, "0", offset, &mut locals, helpers);
354                 }
355                 s.push_str("return\n");
356                 s.push_str("end\n");
357             }
358 
359             // Catch-all result which is for out-of-bounds discriminants.
360             s.push_str("unreachable\n");
361         };
362         match self {
363             Type::Bool
364             | Type::S8
365             | Type::U8
366             | Type::S16
367             | Type::U16
368             | Type::S32
369             | Type::U32
370             | Type::Char
371             | Type::S64
372             | Type::U64
373             | Type::Float32
374             | Type::Float64
375             | Type::String
376             | Type::List(_)
377             | Type::Flags(_)
378             | Type::Enum(_) => unreachable!(),
379 
380             Type::Record(r) => record(s, helpers, r),
381             Type::Tuple(t) => record(s, helpers, t),
382             Type::Variant(v) => variant(
383                 s,
384                 helpers,
385                 &v.iter().map(|t| t.as_ref()).collect::<Vec<_>>(),
386             ),
387             Type::Option(o) => variant(s, helpers, &[None, Some(&**o)]),
388             Type::Result { ok, err } => variant(s, helpers, &[ok.as_deref(), err.as_deref()]),
389         };
390         s.push_str(")\n");
391     }
392 
393     /// Same as `store_flat`, except loads the flat values from `ptr+offset`.
394     ///
395     /// Results are placed directly on the wasm stack.
396     fn load_flat<'a>(
397         &'a self,
398         s: &mut String,
399         ptr: &str,
400         offset: u32,
401         helpers: &mut IndexSet<Helper<'a>>,
402     ) {
403         enum Kind {
404             Primitive(&'static str),
405             PointerPair,
406             Helper,
407         }
408         let kind = match self {
409             Type::Bool | Type::U8 => Kind::Primitive("i32.load8_u"),
410             Type::S8 => Kind::Primitive("i32.load8_s"),
411             Type::U16 => Kind::Primitive("i32.load16_u"),
412             Type::S16 => Kind::Primitive("i32.load16_s"),
413             Type::U32 | Type::S32 | Type::Char => Kind::Primitive("i32.load"),
414             Type::U64 | Type::S64 => Kind::Primitive("i64.load"),
415             Type::Float32 => Kind::Primitive("f32.load"),
416             Type::Float64 => Kind::Primitive("f64.load"),
417             Type::String | Type::List(_) => Kind::PointerPair,
418             Type::Enum(n) if *n <= (1 << 8) => Kind::Primitive("i32.load8_u"),
419             Type::Enum(n) if *n <= (1 << 16) => Kind::Primitive("i32.load16_u"),
420             Type::Enum(_) => Kind::Primitive("i32.load"),
421             Type::Flags(n) if *n <= 8 => Kind::Primitive("i32.load8_u"),
422             Type::Flags(n) if *n <= 16 => Kind::Primitive("i32.load16_u"),
423             Type::Flags(n) if *n <= 32 => Kind::Primitive("i32.load"),
424             Type::Flags(_) => unreachable!(),
425 
426             Type::Record(_)
427             | Type::Tuple(_)
428             | Type::Variant(_)
429             | Type::Option(_)
430             | Type::Result { .. } => Kind::Helper,
431         };
432         match kind {
433             Kind::Primitive(op) => uwriteln!(s, "({op} offset={offset} (local.get {ptr}))"),
434             Kind::PointerPair => {
435                 uwriteln!(s, "(i32.load offset={offset} (local.get {ptr}))",);
436                 let offset = offset + 4;
437                 uwriteln!(s, "(i32.load offset={offset} (local.get {ptr}))",);
438             }
439             Kind::Helper => {
440                 let (index, _) = helpers.insert_full(Helper(self));
441                 uwriteln!(s, "(i32.add (local.get {ptr}) (i32.const {offset}))");
442                 uwriteln!(s, "call $load_helper_{index}");
443             }
444         }
445     }
446 
447     /// Same as `store_flat_helper` but for loading the flat representation.
448     fn load_flat_helper<'a>(
449         &'a self,
450         s: &mut String,
451         i: usize,
452         helpers: &mut IndexSet<Helper<'a>>,
453     ) {
454         uwrite!(s, "(func $load_helper_{i} (param i32)");
455         let lowered = self.lowered();
456         for ty in &lowered {
457             uwrite!(s, " (result {ty})");
458         }
459         s.push_str("\n");
460         let record = |s: &mut String, helpers: &mut IndexSet<Helper<'a>>, types: &'a [Type]| {
461             for (offset, ty) in record_field_offsets(types) {
462                 ty.load_flat(s, "0", offset, helpers);
463             }
464         };
465         let variant = |s: &mut String,
466                        helpers: &mut IndexSet<Helper<'a>>,
467                        types: &[Option<&'a Type>]| {
468             let (size, offset) = variant_memory_info(types.iter().cloned());
469 
470             // Destination locals where the flat representation will be stored.
471             // These are automatically zero which handles unused fields too.
472             for (i, ty) in lowered.iter().enumerate() {
473                 uwriteln!(s, " (local $r{i} {ty})");
474             }
475 
476             // Return block each case jumps to after setting all locals.
477             s.push_str("block $r\n");
478 
479             // One extra block for "out of bounds discriminant".
480             for _ in 0..types.len() + 1 {
481                 s.push_str("block\n");
482             }
483 
484             // Load the discriminant and branch on it, storing it in
485             // `$r0` as well which is the first flat local representation.
486             let load = match size {
487                 DiscriminantSize::Size1 => "i32.load8_u",
488                 DiscriminantSize::Size2 => "i32.load16",
489                 DiscriminantSize::Size4 => "i32.load",
490             };
491             uwriteln!(s, "({load} (local.get 0))");
492             s.push_str("local.tee $r0\n");
493             s.push_str("br_table");
494             for i in 0..types.len() + 1 {
495                 uwrite!(s, " {i}");
496             }
497             s.push_str("\nend\n");
498 
499             // For each payload, which is in its own block, load payloads from
500             // memory as necessary and convert them into the final locals.
501             for ty in types {
502                 if let Some(ty) = ty {
503                     let ty_lowered = ty.lowered();
504                     ty.load_flat(s, "0", offset, helpers);
505                     for (i, (from, to)) in ty_lowered.iter().zip(&lowered[1..]).enumerate().rev() {
506                         let i = i + 1;
507                         match (from, to) {
508                             (CoreType::F32, CoreType::I32) => {
509                                 s.push_str("i32.reinterpret_f32\n");
510                             }
511                             (CoreType::I32, CoreType::I64) => {
512                                 s.push_str("i64.extend_i32_u\n");
513                             }
514                             (CoreType::F32, CoreType::I64) => {
515                                 s.push_str("i32.reinterpret_f32\n");
516                                 s.push_str("i64.extend_i32_u\n");
517                             }
518                             (CoreType::F64, CoreType::I64) => {
519                                 s.push_str("i64.reinterpret_f64\n");
520                             }
521                             (a, b) if a == b => {}
522                             _ => unimplemented!("convert {from:?} to {to:?}"),
523                         }
524                         uwriteln!(s, "local.set $r{i}");
525                     }
526                 }
527                 s.push_str("br $r\n");
528                 s.push_str("end\n");
529             }
530 
531             // The catch-all block for out-of-bounds discriminants.
532             s.push_str("unreachable\n");
533             s.push_str("end\n");
534             for i in 0..lowered.len() {
535                 uwriteln!(s, " local.get $r{i}");
536             }
537         };
538 
539         match self {
540             Type::Bool
541             | Type::S8
542             | Type::U8
543             | Type::S16
544             | Type::U16
545             | Type::S32
546             | Type::U32
547             | Type::Char
548             | Type::S64
549             | Type::U64
550             | Type::Float32
551             | Type::Float64
552             | Type::String
553             | Type::List(_)
554             | Type::Flags(_)
555             | Type::Enum(_) => unreachable!(),
556 
557             Type::Record(r) => record(s, helpers, r),
558             Type::Tuple(t) => record(s, helpers, t),
559             Type::Variant(v) => variant(
560                 s,
561                 helpers,
562                 &v.iter().map(|t| t.as_ref()).collect::<Vec<_>>(),
563             ),
564             Type::Option(o) => variant(s, helpers, &[None, Some(&**o)]),
565             Type::Result { ok, err } => variant(s, helpers, &[ok.as_deref(), err.as_deref()]),
566         };
567         s.push_str(")\n");
568     }
569 }
570 
571 #[derive(Clone)]
572 enum FlatSource {
573     Local(u32),
574     LocalConvert {
575         local: u32,
576         from: CoreType,
577         to: CoreType,
578     },
579 }
580 
581 impl fmt::Display for FlatSource {
582     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
583         match self {
584             FlatSource::Local(i) => write!(f, "(local.get {i})"),
585             FlatSource::LocalConvert { local, from, to } => {
586                 match (from, to) {
587                     (a, b) if a == b => write!(f, "(local.get {local})"),
588                     (CoreType::I32, CoreType::F32) => {
589                         write!(f, "(f32.reinterpret_i32 (local.get {local}))")
590                     }
591                     (CoreType::I64, CoreType::I32) => {
592                         write!(f, "(i32.wrap_i64 (local.get {local}))")
593                     }
594                     (CoreType::I64, CoreType::F64) => {
595                         write!(f, "(f64.reinterpret_i64 (local.get {local}))")
596                     }
597                     (CoreType::I64, CoreType::F32) => {
598                         write!(
599                             f,
600                             "(f32.reinterpret_i32 (i32.wrap_i64 (local.get {local})))"
601                         )
602                     }
603                     _ => unimplemented!("convert {from:?} to {to:?}"),
604                 }
605                 // ..
606             }
607         }
608     }
609 }
610 
611 fn lower_record<'a>(types: impl Iterator<Item = &'a Type>, vec: &mut Vec<CoreType>) {
612     for ty in types {
613         ty.lower(vec);
614     }
615 }
616 
617 fn lower_variant<'a>(types: impl Iterator<Item = Option<&'a Type>>, vec: &mut Vec<CoreType>) {
618     vec.push(CoreType::I32);
619     let offset = vec.len();
620     for ty in types {
621         let ty = match ty {
622             Some(ty) => ty,
623             None => continue,
624         };
625         for (index, ty) in ty.lowered().iter().enumerate() {
626             let index = offset + index;
627             if index < vec.len() {
628                 vec[index] = vec[index].join(*ty);
629             } else {
630                 vec.push(*ty)
631             }
632         }
633     }
634 }
635 
636 fn u32_count_from_flag_count(count: usize) -> usize {
637     match FlagsSize::from_count(count) {
638         FlagsSize::Size0 => 0,
639         FlagsSize::Size1 | FlagsSize::Size2 => 1,
640         FlagsSize::Size4Plus(n) => n.into(),
641     }
642 }
643 
/// The in-memory size and alignment of a component type, as computed by
/// `Type::size_and_alignment`.
struct SizeAndAlignment {
    /// Size in bytes.
    size: usize,
    /// Required alignment in bytes.
    alignment: u32,
}
648 
649 impl Type {
650     fn lowered(&self) -> Vec<CoreType> {
651         let mut vec = Vec::new();
652         self.lower(&mut vec);
653         vec
654     }
655 
656     fn lower(&self, vec: &mut Vec<CoreType>) {
657         match self {
658             Type::Bool
659             | Type::U8
660             | Type::S8
661             | Type::S16
662             | Type::U16
663             | Type::S32
664             | Type::U32
665             | Type::Char
666             | Type::Enum(_) => vec.push(CoreType::I32),
667             Type::S64 | Type::U64 => vec.push(CoreType::I64),
668             Type::Float32 => vec.push(CoreType::F32),
669             Type::Float64 => vec.push(CoreType::F64),
670             Type::String | Type::List(_) => {
671                 vec.push(CoreType::I32);
672                 vec.push(CoreType::I32);
673             }
674             Type::Record(types) => lower_record(types.iter(), vec),
675             Type::Tuple(types) => lower_record(types.0.iter(), vec),
676             Type::Variant(types) => lower_variant(types.0.iter().map(|t| t.as_ref()), vec),
677             Type::Option(ty) => lower_variant([None, Some(&**ty)].into_iter(), vec),
678             Type::Result { ok, err } => {
679                 lower_variant([ok.as_deref(), err.as_deref()].into_iter(), vec)
680             }
681             Type::Flags(count) => vec.extend(
682                 iter::repeat(CoreType::I32).take(u32_count_from_flag_count(*count as usize)),
683             ),
684         }
685     }
686 
687     fn size_and_alignment(&self) -> SizeAndAlignment {
688         match self {
689             Type::Bool | Type::S8 | Type::U8 => SizeAndAlignment {
690                 size: 1,
691                 alignment: 1,
692             },
693 
694             Type::S16 | Type::U16 => SizeAndAlignment {
695                 size: 2,
696                 alignment: 2,
697             },
698 
699             Type::S32 | Type::U32 | Type::Char | Type::Float32 => SizeAndAlignment {
700                 size: 4,
701                 alignment: 4,
702             },
703 
704             Type::S64 | Type::U64 | Type::Float64 => SizeAndAlignment {
705                 size: 8,
706                 alignment: 8,
707             },
708 
709             Type::String | Type::List(_) => SizeAndAlignment {
710                 size: 8,
711                 alignment: 4,
712             },
713 
714             Type::Record(types) => record_size_and_alignment(types.iter()),
715 
716             Type::Tuple(types) => record_size_and_alignment(types.0.iter()),
717 
718             Type::Variant(types) => variant_size_and_alignment(types.0.iter().map(|t| t.as_ref())),
719 
720             Type::Enum(count) => variant_size_and_alignment((0..*count).map(|_| None)),
721 
722             Type::Option(ty) => variant_size_and_alignment([None, Some(&**ty)].into_iter()),
723 
724             Type::Result { ok, err } => {
725                 variant_size_and_alignment([ok.as_deref(), err.as_deref()].into_iter())
726             }
727 
728             Type::Flags(count) => match FlagsSize::from_count(*count as usize) {
729                 FlagsSize::Size0 => SizeAndAlignment {
730                     size: 0,
731                     alignment: 1,
732                 },
733                 FlagsSize::Size1 => SizeAndAlignment {
734                     size: 1,
735                     alignment: 1,
736                 },
737                 FlagsSize::Size2 => SizeAndAlignment {
738                     size: 2,
739                     alignment: 2,
740                 },
741                 FlagsSize::Size4Plus(n) => SizeAndAlignment {
742                     size: usize::from(n) * 4,
743                     alignment: 4,
744                 },
745             },
746         }
747     }
748 }
749 
/// Rounds `a` up to the next multiple of `align` (a power of two).
fn align_to(a: usize, align: u32) -> usize {
    let mask = align as usize - 1;
    (a + mask) & !mask
}
754 
755 fn record_field_offsets<'a>(
756     types: impl IntoIterator<Item = &'a Type>,
757 ) -> impl Iterator<Item = (u32, &'a Type)> {
758     let mut offset = 0;
759     types.into_iter().map(move |ty| {
760         let SizeAndAlignment { size, alignment } = ty.size_and_alignment();
761         let ret = align_to(offset, alignment);
762         offset = ret + size;
763         (ret as u32, ty)
764     })
765 }
766 
767 fn record_size_and_alignment<'a>(types: impl IntoIterator<Item = &'a Type>) -> SizeAndAlignment {
768     let mut offset = 0;
769     let mut align = 1;
770     for ty in types {
771         let SizeAndAlignment { size, alignment } = ty.size_and_alignment();
772         offset = align_to(offset, alignment) + size;
773         align = align.max(alignment);
774     }
775 
776     SizeAndAlignment {
777         size: align_to(offset, align),
778         alignment: align,
779     }
780 }
781 
782 fn variant_size_and_alignment<'a>(
783     types: impl ExactSizeIterator<Item = Option<&'a Type>>,
784 ) -> SizeAndAlignment {
785     let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap();
786     let mut alignment = u32::from(discriminant_size);
787     let mut size = 0;
788     for ty in types {
789         if let Some(ty) = ty {
790             let size_and_alignment = ty.size_and_alignment();
791             alignment = alignment.max(size_and_alignment.alignment);
792             size = size.max(size_and_alignment.size);
793         }
794     }
795 
796     SizeAndAlignment {
797         size: align_to(
798             align_to(usize::from(discriminant_size), alignment) + size,
799             alignment,
800         ),
801         alignment,
802     }
803 }
804 
805 fn variant_memory_info<'a>(
806     types: impl ExactSizeIterator<Item = Option<&'a Type>>,
807 ) -> (DiscriminantSize, u32) {
808     let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap();
809     let mut alignment = u32::from(discriminant_size);
810     for ty in types {
811         if let Some(ty) = ty {
812             let size_and_alignment = ty.size_and_alignment();
813             alignment = alignment.max(size_and_alignment.alignment);
814         }
815     }
816 
817     (
818         discriminant_size,
819         align_to(usize::from(discriminant_size), alignment) as u32,
820     )
821 }
822 
823 /// Generates the internals of a core wasm module which imports a single
824 /// component function `IMPORT_FUNCTION` and exports a single component
825 /// function `EXPORT_FUNCTION`.
826 ///
827 /// The component function takes `params` as arguments and optionally returns
828 /// `result`. The `lift_abi` and `lower_abi` fields indicate the ABI in-use for
829 /// this operation.
830 fn make_import_and_export(
831     params: &[&Type],
832     result: Option<&Type>,
833     lift_abi: LiftAbi,
834     lower_abi: LowerAbi,
835 ) -> String {
836     let params_lowered = params
837         .iter()
838         .flat_map(|ty| ty.lowered())
839         .collect::<Box<[_]>>();
840     let result_lowered = result.map(|t| t.lowered()).unwrap_or(Vec::new());
841 
842     let mut wat = String::new();
843 
844     enum Location {
845         Flat,
846         Indirect(u32),
847     }
848 
849     // Generate the core wasm type corresponding to the imported function being
850     // lowered with `lower_abi`.
851     wat.push_str(&format!("(type $import (func"));
852     let max_import_params = match lower_abi {
853         LowerAbi::Sync => MAX_FLAT_PARAMS,
854         LowerAbi::Async => MAX_FLAT_ASYNC_PARAMS,
855     };
856     let (import_params_loc, nparams) = push_params(&mut wat, &params_lowered, max_import_params);
857     let import_results_loc = match lower_abi {
858         LowerAbi::Sync => {
859             push_result_or_retptr(&mut wat, &result_lowered, nparams, MAX_FLAT_RESULTS)
860         }
861         LowerAbi::Async => {
862             let loc = if result.is_none() {
863                 Location::Flat
864             } else {
865                 wat.push_str(" (param i32)"); // result pointer
866                 Location::Indirect(nparams)
867             };
868             wat.push_str(" (result i32)"); // status code
869             loc
870         }
871     };
872     wat.push_str("))\n");
873 
874     // Generate the import function.
875     wat.push_str(&format!(
876         r#"(import "host" "{IMPORT_FUNCTION}" (func $host (type $import)))"#
877     ));
878 
879     // Do the same as above for the exported function's type which is lifted
880     // with `lift_abi`.
881     //
882     // Note that `export_results_loc` being `None` means that `task.return` is
883     // used to communicate results.
884     wat.push_str(&format!("(type $export (func"));
885     let (export_params_loc, _nparams) = push_params(&mut wat, &params_lowered, MAX_FLAT_PARAMS);
886     let export_results_loc = match lift_abi {
887         LiftAbi::Sync => Some(push_group(&mut wat, "result", &result_lowered, MAX_FLAT_RESULTS).0),
888         LiftAbi::AsyncCallback => {
889             wat.push_str(" (result i32)"); // status code
890             None
891         }
892         LiftAbi::AsyncStackful => None,
893     };
894     wat.push_str("))\n");
895 
896     // If the export is async, generate `task.return` as an import as well
897     // which is necesary to communicate the results.
898     if export_results_loc.is_none() {
899         wat.push_str(&format!("(type $task.return (func"));
900         push_params(&mut wat, &result_lowered, MAX_FLAT_PARAMS);
901         wat.push_str("))\n");
902         wat.push_str(&format!(
903             r#"(import "" "task.return" (func $task.return (type $task.return)))"#
904         ));
905     }
906 
907     wat.push_str(&format!(
908         r#"
909 (func (export "{EXPORT_FUNCTION}") (type $export)
910     (local $retptr i32)
911     (local $argptr i32)
912         "#
913     ));
914     let mut store_helpers = IndexSet::new();
915     let mut load_helpers = IndexSet::new();
916 
917     match (export_params_loc, import_params_loc) {
918         // flat => flat is just moving locals around
919         (Location::Flat, Location::Flat) => {
920             for (index, _) in params_lowered.iter().enumerate() {
921                 uwrite!(wat, "local.get {index}\n");
922             }
923         }
924 
925         // indirect => indirect is just moving locals around
926         (Location::Indirect(i), Location::Indirect(j)) => {
927             assert_eq!(j, 0);
928             uwrite!(wat, "local.get {i}\n");
929         }
930 
931         // flat => indirect means that all parameters are stored in memory as
932         // if it was a record of all the parameters.
933         (Location::Flat, Location::Indirect(_)) => {
934             let SizeAndAlignment { size, alignment } =
935                 record_size_and_alignment(params.iter().cloned());
936             wat.push_str(&format!(
937                 r#"
938                     (local.set $argptr
939                         (call $realloc
940                             (i32.const 0)
941                             (i32.const 0)
942                             (i32.const {alignment})
943                             (i32.const {size})))
944                     local.get $argptr
945                 "#
946             ));
947             let mut locals = (0..params_lowered.len() as u32).map(FlatSource::Local);
948             for (offset, ty) in record_field_offsets(params.iter().cloned()) {
949                 ty.store_flat(&mut wat, "$argptr", offset, &mut locals, &mut store_helpers);
950             }
951             assert!(locals.next().is_none());
952         }
953 
954         (Location::Indirect(_), Location::Flat) => unreachable!(),
955     }
956 
957     // Pass a return-pointer if necessary.
958     match import_results_loc {
959         Location::Flat => {}
960         Location::Indirect(_) => {
961             let SizeAndAlignment { size, alignment } = result.unwrap().size_and_alignment();
962 
963             wat.push_str(&format!(
964                 r#"
965                     (local.set $retptr
966                         (call $realloc
967                             (i32.const 0)
968                             (i32.const 0)
969                             (i32.const {alignment})
970                             (i32.const {size})))
971                     local.get $retptr
972                 "#
973             ));
974         }
975     }
976 
977     wat.push_str("call $host\n");
978 
979     // Assert the lowered call is ready if an async code was returned.
980     //
981     // TODO: handle when the import isn't ready yet
982     if let LowerAbi::Async = lower_abi {
983         wat.push_str("i32.const 2\n");
984         wat.push_str("i32.ne\n");
985         wat.push_str("if unreachable end\n");
986     }
987 
988     // TODO: conditionally inject a yield here
989 
990     match (import_results_loc, export_results_loc) {
991         // flat => flat results involves nothing, the results are already on
992         // the stack.
993         (Location::Flat, Some(Location::Flat)) => {}
994 
995         // indirect => indirect results requires returning the `$retptr` the
996         // host call filled in.
997         (Location::Indirect(_), Some(Location::Indirect(_))) => {
998             wat.push_str("local.get $retptr\n");
999         }
1000 
1001         // indirect => flat requires loading the result from the return pointer
1002         (Location::Indirect(_), Some(Location::Flat)) => {
1003             result
1004                 .unwrap()
1005                 .load_flat(&mut wat, "$retptr", 0, &mut load_helpers);
1006         }
1007 
1008         // flat => task.return is easy, the results are already there so just
1009         // call the function.
1010         (Location::Flat, None) => {
1011             wat.push_str("call $task.return\n");
1012         }
1013 
1014         // indirect => task.return needs to forward `$retptr` if the results
1015         // are indirect, or otherwise it must be loaded from memory to a flat
1016         // representation.
1017         (Location::Indirect(_), None) => {
1018             if result_lowered.len() <= MAX_FLAT_PARAMS {
1019                 result
1020                     .unwrap()
1021                     .load_flat(&mut wat, "$retptr", 0, &mut load_helpers);
1022             } else {
1023                 wat.push_str("local.get $retptr\n");
1024             }
1025             wat.push_str("call $task.return\n");
1026         }
1027 
1028         (Location::Flat, Some(Location::Indirect(_))) => unreachable!(),
1029     }
1030 
1031     if let LiftAbi::AsyncCallback = lift_abi {
1032         wat.push_str("i32.const 0\n"); // completed status code
1033     }
1034 
1035     wat.push_str(")\n");
1036 
1037     // Generate a `callback` function for the callback ABI.
1038     //
1039     // TODO: fill this in
1040     if let LiftAbi::AsyncCallback = lift_abi {
1041         wat.push_str(
1042             r#"
1043 (func (export "callback") (param i32 i32 i32) (result i32) unreachable)
1044             "#,
1045         );
1046     }
1047 
1048     // Fill out all store/load helpers that were needed during generation
1049     // above. This is a fix-point-loop since each helper may end up requiring
1050     // more helpers.
1051     let mut i = 0;
1052     while i < store_helpers.len() {
1053         let ty = store_helpers[i].0;
1054         ty.store_flat_helper(&mut wat, i, &mut store_helpers);
1055         i += 1;
1056     }
1057     i = 0;
1058     while i < load_helpers.len() {
1059         let ty = load_helpers[i].0;
1060         ty.load_flat_helper(&mut wat, i, &mut load_helpers);
1061         i += 1;
1062     }
1063 
1064     return wat;
1065 
1066     fn push_params(wat: &mut String, params: &[CoreType], max_flat: usize) -> (Location, u32) {
1067         push_group(wat, "param", params, max_flat)
1068     }
1069 
1070     fn push_group(
1071         wat: &mut String,
1072         name: &str,
1073         params: &[CoreType],
1074         max_flat: usize,
1075     ) -> (Location, u32) {
1076         let mut nparams = 0;
1077         let loc = if params.is_empty() {
1078             // nothing to emit...
1079             Location::Flat
1080         } else if params.len() <= max_flat {
1081             wat.push_str(&format!(" ({name}"));
1082             for ty in params {
1083                 wat.push_str(&format!(" {ty}"));
1084                 nparams += 1;
1085             }
1086             wat.push_str(")");
1087             Location::Flat
1088         } else {
1089             wat.push_str(&format!(" ({name} i32)"));
1090             nparams += 1;
1091             Location::Indirect(0)
1092         };
1093         (loc, nparams)
1094     }
1095 
1096     fn push_result_or_retptr(
1097         wat: &mut String,
1098         results: &[CoreType],
1099         nparams: u32,
1100         max_flat: usize,
1101     ) -> Location {
1102         if results.is_empty() {
1103             // nothing to emit...
1104             Location::Flat
1105         } else if results.len() <= max_flat {
1106             wat.push_str(" (result");
1107             for ty in results {
1108                 wat.push_str(&format!(" {ty}"));
1109             }
1110             wat.push_str(")");
1111             Location::Flat
1112         } else {
1113             wat.push_str(" (param i32)");
1114             Location::Indirect(nparams)
1115         }
1116     }
1117 }
1118 
1119 struct Helper<'a>(&'a Type);
1120 
1121 impl Hash for Helper<'_> {
1122     fn hash<H: Hasher>(&self, h: &mut H) {
1123         std::ptr::hash(self.0, h);
1124     }
1125 }
1126 
1127 impl PartialEq for Helper<'_> {
1128     fn eq(&self, other: &Self) -> bool {
1129         std::ptr::eq(self.0, other.0)
1130     }
1131 }
1132 
1133 impl Eq for Helper<'_> {}
1134 
1135 fn make_rust_name(name_counter: &mut u32) -> Ident {
1136     let name = format_ident!("Foo{name_counter}");
1137     *name_counter += 1;
1138     name
1139 }
1140 
/// Generate a [`TokenStream`] containing the rust type name for a type.
///
/// Primitive types map directly onto Rust types. Lists, tuples, options and
/// results map onto their structural std equivalents. Records, variants,
/// enums and flags instead get a fresh nominal type (named via
/// `make_rust_name`), whose definition is appended to `declarations`; the
/// returned tokens are then just that type's name. `Float32`/`Float64` are
/// emitted as bare identifiers, so they must resolve wherever the generated
/// code is compiled.
///
/// The `name_counter` parameter is used to generate names for each recursively visited type.  The `declarations`
/// parameter is used to accumulate declarations for each recursively visited type.
pub fn rust_type(ty: &Type, name_counter: &mut u32, declarations: &mut TokenStream) -> TokenStream {
    match ty {
        Type::Bool => quote!(bool),
        Type::S8 => quote!(i8),
        Type::U8 => quote!(u8),
        Type::S16 => quote!(i16),
        Type::U16 => quote!(u16),
        Type::S32 => quote!(i32),
        Type::U32 => quote!(u32),
        Type::S64 => quote!(i64),
        Type::U64 => quote!(u64),
        Type::Float32 => quote!(Float32),
        Type::Float64 => quote!(Float64),
        Type::Char => quote!(char),
        Type::String => quote!(Box<str>),
        Type::List(ty) => {
            let ty = rust_type(ty, name_counter, declarations);
            quote!(Vec<#ty>)
        }
        Type::Record(types) => {
            // Field types are converted first, so any nested declarations
            // (and the names they consume) precede this record's own.
            let fields = types
                .iter()
                .enumerate()
                .map(|(index, ty)| {
                    let name = format_ident!("f{index}");
                    let ty = rust_type(ty, name_counter, declarations);
                    quote!(#name: #ty,)
                })
                .collect::<TokenStream>();

            let name = make_rust_name(name_counter);

            declarations.extend(quote! {
                #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)]
                #[component(record)]
                struct #name {
                    #fields
                }
            });

            quote!(#name)
        }
        Type::Tuple(types) => {
            // Each element is followed by a comma, so a 1-tuple is rendered
            // as `(T,)` rather than a parenthesized type.
            let fields = types
                .0
                .iter()
                .map(|ty| {
                    let ty = rust_type(ty, name_counter, declarations);
                    quote!(#ty,)
                })
                .collect::<TokenStream>();

            quote!((#fields))
        }
        Type::Variant(types) => {
            // Payload types are converted before this variant's name is
            // allocated, mirroring the record case.
            let cases = types
                .0
                .iter()
                .enumerate()
                .map(|(index, ty)| {
                    let name = format_ident!("C{index}");
                    let ty = match ty {
                        Some(ty) => {
                            let ty = rust_type(ty, name_counter, declarations);
                            quote!((#ty))
                        }
                        None => quote!(),
                    };
                    quote!(#name #ty,)
                })
                .collect::<TokenStream>();

            let name = make_rust_name(name_counter);
            declarations.extend(quote! {
                #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)]
                #[component(variant)]
                enum #name {
                    #cases
                }
            });

            quote!(#name)
        }
        Type::Enum(count) => {
            let cases = (0..*count)
                .map(|index| {
                    let name = format_ident!("E{index}");
                    quote!(#name,)
                })
                .collect::<TokenStream>();

            let name = make_rust_name(name_counter);
            // The Rust `repr` width follows the discriminant size chosen for
            // `count` cases by `DiscriminantSize::from_count`.
            let repr = match DiscriminantSize::from_count(*count as usize).unwrap() {
                DiscriminantSize::Size1 => quote!(u8),
                DiscriminantSize::Size2 => quote!(u16),
                DiscriminantSize::Size4 => quote!(u32),
            };

            declarations.extend(quote! {
                #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Copy, Clone, Arbitrary)]
                #[component(enum)]
                #[repr(#repr)]
                enum #name {
                    #cases
                }
            });

            quote!(#name)
        }
        Type::Option(ty) => {
            let ty = rust_type(ty, name_counter, declarations);
            quote!(Option<#ty>)
        }
        Type::Result { ok, err } => {
            // Absent payloads become `()` on the Rust side.
            let ok = match ok {
                Some(ok) => rust_type(ok, name_counter, declarations),
                None => quote!(()),
            };
            let err = match err {
                Some(err) => rust_type(err, name_counter, declarations),
                None => quote!(()),
            };
            quote!(Result<#ok, #err>)
        }
        Type::Flags(count) => {
            // Unlike the other nominal types, the name is allocated first
            // (flags have no nested types to visit).
            let type_name = make_rust_name(name_counter);

            let mut flags = TokenStream::new();
            let mut names = TokenStream::new();

            for index in 0..*count {
                let name = format_ident!("F{index}");
                flags.extend(quote!(const #name;));
                names.extend(quote!(#type_name::#name,))
            }

            // The `Arbitrary` impl is written by hand: each flag is
            // independently included based on an arbitrary `bool`.
            declarations.extend(quote! {
                wasmtime::component::flags! {
                    #type_name {
                        #flags
                    }
                }

                impl<'a> arbitrary::Arbitrary<'a> for #type_name {
                    fn arbitrary(input: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
                        let mut flags = #type_name::default();
                        for flag in [#names] {
                            if input.arbitrary()? {
                                flags |= flag;
                            }
                        }
                        Ok(flags)
                    }
                }
            });

            quote!(#type_name)
        }
    }
}
1305 
/// Accumulates the component-model type text needed to spell out [`Type`]s
/// in WAT.
///
/// `write_ref` hands out sequential `$t{N}` indices for compound types and
/// queues them on `worklist`. `write_decl` renders the declaration for one
/// queued entry.
#[derive(Default)]
struct TypesBuilder<'a> {
    /// Next unused `$t{N}` index.
    next: u32,
    /// Compound types referenced through `write_ref`, paired with the index
    /// each was assigned. Draining happens outside this type; presumably
    /// each entry is eventually passed to `write_decl`.
    worklist: Vec<(u32, &'a Type)>,
}
1311 
1312 impl<'a> TypesBuilder<'a> {
1313     fn write_ref(&mut self, ty: &'a Type, dst: &mut String) {
1314         match ty {
1315             // Primitive types can be referenced directly
1316             Type::Bool => dst.push_str("bool"),
1317             Type::S8 => dst.push_str("s8"),
1318             Type::U8 => dst.push_str("u8"),
1319             Type::S16 => dst.push_str("s16"),
1320             Type::U16 => dst.push_str("u16"),
1321             Type::S32 => dst.push_str("s32"),
1322             Type::U32 => dst.push_str("u32"),
1323             Type::S64 => dst.push_str("s64"),
1324             Type::U64 => dst.push_str("u64"),
1325             Type::Float32 => dst.push_str("float32"),
1326             Type::Float64 => dst.push_str("float64"),
1327             Type::Char => dst.push_str("char"),
1328             Type::String => dst.push_str("string"),
1329 
1330             // Otherwise emit a reference to the type and remember to generate
1331             // the corresponding type alias later.
1332             Type::List(_)
1333             | Type::Record(_)
1334             | Type::Tuple(_)
1335             | Type::Variant(_)
1336             | Type::Enum(_)
1337             | Type::Option(_)
1338             | Type::Result { .. }
1339             | Type::Flags(_) => {
1340                 let idx = self.next;
1341                 self.next += 1;
1342                 uwrite!(dst, "$t{idx}");
1343                 self.worklist.push((idx, ty));
1344             }
1345         }
1346     }
1347 
1348     fn write_decl(&mut self, idx: u32, ty: &'a Type) -> String {
1349         let mut decl = format!("(type $t{idx}' ");
1350         match ty {
1351             Type::Bool
1352             | Type::S8
1353             | Type::U8
1354             | Type::S16
1355             | Type::U16
1356             | Type::S32
1357             | Type::U32
1358             | Type::S64
1359             | Type::U64
1360             | Type::Float32
1361             | Type::Float64
1362             | Type::Char
1363             | Type::String => unreachable!(),
1364 
1365             Type::List(ty) => {
1366                 decl.push_str("(list ");
1367                 self.write_ref(ty, &mut decl);
1368                 decl.push_str(")");
1369             }
1370             Type::Record(types) => {
1371                 decl.push_str("(record");
1372                 for (index, ty) in types.iter().enumerate() {
1373                     uwrite!(decl, r#" (field "f{index}" "#);
1374                     self.write_ref(ty, &mut decl);
1375                     decl.push_str(")");
1376                 }
1377                 decl.push_str(")");
1378             }
1379             Type::Tuple(types) => {
1380                 decl.push_str("(tuple");
1381                 for ty in types.iter() {
1382                     decl.push_str(" ");
1383                     self.write_ref(ty, &mut decl);
1384                 }
1385                 decl.push_str(")");
1386             }
1387             Type::Variant(types) => {
1388                 decl.push_str("(variant");
1389                 for (index, ty) in types.iter().enumerate() {
1390                     uwrite!(decl, r#" (case "C{index}""#);
1391                     if let Some(ty) = ty {
1392                         decl.push_str(" ");
1393                         self.write_ref(ty, &mut decl);
1394                     }
1395                     decl.push_str(")");
1396                 }
1397                 decl.push_str(")");
1398             }
1399             Type::Enum(count) => {
1400                 decl.push_str("(enum");
1401                 for index in 0..*count {
1402                     uwrite!(decl, r#" "E{index}""#);
1403                 }
1404                 decl.push_str(")");
1405             }
1406             Type::Option(ty) => {
1407                 decl.push_str("(option ");
1408                 self.write_ref(ty, &mut decl);
1409                 decl.push_str(")");
1410             }
1411             Type::Result { ok, err } => {
1412                 decl.push_str("(result");
1413                 if let Some(ok) = ok {
1414                     decl.push_str(" ");
1415                     self.write_ref(ok, &mut decl);
1416                 }
1417                 if let Some(err) = err {
1418                     decl.push_str(" (error ");
1419                     self.write_ref(err, &mut decl);
1420                     decl.push_str(")");
1421                 }
1422                 decl.push_str(")");
1423             }
1424             Type::Flags(count) => {
1425                 decl.push_str("(flags");
1426                 for index in 0..*count {
1427                     uwrite!(decl, r#" "F{index}""#);
1428                 }
1429                 decl.push_str(")");
1430             }
1431         }
1432         decl.push_str(")\n");
1433         uwriteln!(decl, "(import \"t{idx}\" (type $t{idx} (eq $t{idx}')))");
1434         decl
1435     }
1436 }
1437 
/// Represents custom fragments of a WAT file which may be used to create a component for exercising [`TestCase`]s
///
/// Use [`Declarations::make_component`] to assemble the fragments into a
/// complete component.
#[derive(Debug)]
pub struct Declarations {
    /// Type declarations (if any) referenced by `params` and/or `result`
    pub types: Cow<'static, str>,
    /// Types to thread through when instantiating sub-components.
    pub type_instantiation_args: Cow<'static, str>,
    /// Parameter declarations used for the imported and exported functions
    pub params: Cow<'static, str>,
    /// Result declaration used for the imported and exported functions
    pub results: Cow<'static, str>,
    /// Implementation of the "caller" component, which invokes the `callee`
    /// composed component.
    pub caller_module: Cow<'static, str>,
    /// Implementation of the "callee" component, which invokes the host.
    pub callee_module: Cow<'static, str>,
    /// Options used for the caller/callee ABIs, string encodings, etc.
    pub options: TestCaseOptions,
}
1457 
impl Declarations {
    /// Generate a complete WAT file based on the specified fragments.
    ///
    /// The resulting top-level component has these pieces:
    /// - it defines a `$libc` core module that exports `memory` plus the
    ///   `REALLOC_AND_FREE` functions;
    /// - it imports the host function as `IMPORT_FUNCTION`;
    /// - it nests two components, `$callee` and `$caller`;
    /// - it instantiates `$callee` with the host import, and `$caller` with
    ///   the callee's export;
    /// - it re-exports the caller's `EXPORT_FUNCTION`.
    pub fn make_component(&self) -> Box<str> {
        let Self {
            types,
            type_instantiation_args,
            params,
            results,
            caller_module,
            callee_module,
            options,
        } = self;
        // Builds one nested component named `$name`. The component imports
        // `IMPORT_FUNCTION` and lowers it with `lower_abi` into the core
        // module `module`. It then lifts that module's `EXPORT_FUNCTION` with
        // `lift_abi`. The `*_async` flags only control whether the
        // component-level function types carry the `async` keyword.
        let mk_component = |name: &str,
                            module: &str,
                            import_async: bool,
                            export_async: bool,
                            encoding: StringEncoding,
                            lift_abi: LiftAbi,
                            lower_abi: LowerAbi| {
            let import_async = if import_async { "async" } else { "" };
            let export_async = if export_async { "async" } else { "" };
            let lower_async_option = match lower_abi {
                LowerAbi::Sync => "",
                LowerAbi::Async => "async",
            };
            let lift_async_option = match lift_abi {
                LiftAbi::Sync => "",
                LiftAbi::AsyncStackful => "async",
                LiftAbi::AsyncCallback => "async (callback (func $i \"callback\"))",
            };

            let mut intrinsic_defs = String::new();
            let mut intrinsic_imports = String::new();

            // Async lifts deliver results through `task.return`. For those,
            // define that intrinsic and provide it to the core module as
            // import `"" "task.return"`.
            match lift_abi {
                LiftAbi::Sync => {}
                LiftAbi::AsyncCallback | LiftAbi::AsyncStackful => {
                    intrinsic_defs.push_str(&format!(
                        r#"
(core func $task.return (canon task.return {results}
    (memory $libc "memory") string-encoding={encoding}))
                        "#,
                    ));
                    intrinsic_imports.push_str(
                        r#"
(with "" (instance (export "task.return" (func $task.return))))
                        "#,
                    );
                }
            }

            format!(
                r#"
(component ${name}
    {types}
    (type $import_sig (func {import_async} {params} {results}))
    (type $export_sig (func {export_async} {params} {results}))
    (import "{IMPORT_FUNCTION}" (func $f (type $import_sig)))

    (core instance $libc (instantiate $libc))

    (core func $f_lower (canon lower
        (func $f)
        (memory $libc "memory")
        (realloc (func $libc "realloc"))
        string-encoding={encoding}
        {lower_async_option}
    ))

    {intrinsic_defs}

    (core module $m
        (memory (import "libc" "memory") 1)
        (func $realloc (import "libc" "realloc") (param i32 i32 i32 i32) (result i32))

        {module}
    )

    (core instance $i (instantiate $m
        (with "libc" (instance $libc))
        (with "host" (instance (export "{IMPORT_FUNCTION}" (func $f_lower))))
        {intrinsic_imports}
    ))

    (func (export "{EXPORT_FUNCTION}") (type $export_sig)
        (canon lift
            (core func $i "{EXPORT_FUNCTION}")
            (memory $libc "memory")
            (realloc (func $libc "realloc"))
            string-encoding={encoding}
            {lift_async_option}
        )
    )
)
            "#
            )
        };

        // The callee imports the host function; its import signature uses
        // `host_async`, matching the top-level `$host_sig`.
        let c1 = mk_component(
            "callee",
            &callee_module,
            options.host_async,
            options.guest_callee_async,
            options.callee_encoding,
            options.callee_lift_abi,
            options.callee_lower_abi,
        );
        // The caller imports the callee's export, so its import signature
        // uses the callee's export asyncness.
        let c2 = mk_component(
            "caller",
            &caller_module,
            options.guest_callee_async,
            options.guest_caller_async,
            options.caller_encoding,
            options.caller_lift_abi,
            options.caller_lower_abi,
        );
        let host_async = if options.host_async { "async" } else { "" };

        format!(
            r#"
            (component
                (core module $libc
                    (memory (export "memory") 1)
                    {REALLOC_AND_FREE}
                )


                {types}

                (type $host_sig (func {host_async} {params} {results}))
                (import "{IMPORT_FUNCTION}" (func $f (type $host_sig)))

                {c1}
                {c2}
                (instance $c1 (instantiate $callee
                    {type_instantiation_args}
                    (with "{IMPORT_FUNCTION}" (func $f))
                ))
                (instance $c2 (instantiate $caller
                    {type_instantiation_args}
                    (with "{IMPORT_FUNCTION}" (func $c1 "{EXPORT_FUNCTION}"))
                ))
                (export "{EXPORT_FUNCTION}" (func $c2 "{EXPORT_FUNCTION}"))
            )"#,
        )
        .into()
    }
}
1606 
/// Represents a test case for calling a component function
#[derive(Debug)]
pub struct TestCase<'a> {
    /// The types of parameters to pass to the function
    pub params: Vec<&'a Type>,
    /// The result type of the function, or `None` if it returns nothing.
    pub result: Option<&'a Type>,
    /// ABI options to use for this test case.
    pub options: TestCaseOptions,
}
1617 
/// Collection of options which configure how the caller/callee/etc ABIs are
/// all configured.
///
/// Field order matters: it determines the order in which the derived
/// `Arbitrary` impl consumes input bytes. Note that `TestCase::generate`
/// post-processes an arbitrary value to cascade async-ness toward the
/// caller. It forces `guest_callee_async` when `host_async` is set, and
/// `guest_caller_async` when `guest_callee_async` is set.
#[derive(Debug, Arbitrary, Copy, Clone)]
pub struct TestCaseOptions {
    /// Whether or not the guest caller component (the entrypoint) is using an
    /// `async` function type.
    pub guest_caller_async: bool,
    /// Whether or not the guest callee component (what the entrypoint calls)
    /// is using an `async` function type.
    pub guest_callee_async: bool,
    /// Whether or not the host is using an async function type (what the
    /// guest callee calls).
    pub host_async: bool,
    /// The string encoding of the caller component.
    pub caller_encoding: StringEncoding,
    /// The string encoding of the callee component.
    pub callee_encoding: StringEncoding,
    /// The ABI that the caller component is using to lift its export (the main
    /// entrypoint).
    pub caller_lift_abi: LiftAbi,
    /// The ABI that the callee component is using to lift its export (called
    /// by the caller).
    pub callee_lift_abi: LiftAbi,
    /// The ABI that the caller component is using to lower its import (the
    /// callee's export).
    pub caller_lower_abi: LowerAbi,
    /// The ABI that the callee component is using to lower its import (the
    /// host function).
    pub callee_lower_abi: LowerAbi,
}
1648 
/// The canonical-ABI variant used to lift a component's exported core
/// function.
///
/// This selects the options passed to `canon lift` and the shape of the
/// generated core export (see `make_import_and_export`). Variant order
/// matters for the derived `Arbitrary` impl.
#[derive(Debug, Arbitrary, Copy, Clone)]
pub enum LiftAbi {
    /// Synchronous lift. Results are returned directly from the core export,
    /// either flat or through a return pointer when they don't fit.
    Sync,
    /// `async` lift without a callback. The core export returns nothing and
    /// delivers its results via `task.return`.
    AsyncStackful,
    /// `async` lift with a `callback` export. The core export returns an
    /// `i32` status code and delivers its results via `task.return`.
    AsyncCallback,
}
1655 
/// The ABI a component uses to lower (import) the function it calls.
///
/// NOTE(review): presumably maps to a synchronous vs. `async` `canon lower`
/// — confirm against `make_import_and_export`. Variant order likely affects
/// the derived `Arbitrary` decoding.
#[derive(Debug, Arbitrary, Copy, Clone)]
pub enum LowerAbi {
    /// Synchronous lower.
    Sync,
    /// Async lower.
    Async,
}
1661 
1662 impl<'a> TestCase<'a> {
1663     pub fn generate(types: &'a [Type], u: &mut Unstructured<'_>) -> arbitrary::Result<Self> {
1664         let max_params = if types.len() > 0 { 5 } else { 0 };
1665         let params = (0..u.int_in_range(0..=max_params)?)
1666             .map(|_| u.choose(&types))
1667             .collect::<arbitrary::Result<Vec<_>>>()?;
1668         let result = if types.len() > 0 && u.arbitrary()? {
1669             Some(u.choose(&types)?)
1670         } else {
1671             None
1672         };
1673 
1674         let mut options = u.arbitrary::<TestCaseOptions>()?;
1675 
1676         // Sync tasks cannot call async functions via a sync lower, nor can they
1677         // block in other ways (e.g. by calling `waitable-set.wait`, returning
1678         // `CALLBACK_CODE_WAIT`, etc.) prior to returning.  Therefore,
1679         // async-ness cascades to the callers:
1680         if options.host_async {
1681             options.guest_callee_async = true;
1682         }
1683         if options.guest_callee_async {
1684             options.guest_caller_async = true;
1685         }
1686 
1687         Ok(Self {
1688             params,
1689             result,
1690             options,
1691         })
1692     }
1693 
1694     /// Generate a `Declarations` for this `TestCase` which may be used to build a component to execute the case.
1695     pub fn declarations(&self) -> Declarations {
1696         let mut builder = TypesBuilder::default();
1697 
1698         let mut params = String::new();
1699         for (i, ty) in self.params.iter().enumerate() {
1700             params.push_str(&format!(" (param \"p{i}\" "));
1701             builder.write_ref(ty, &mut params);
1702             params.push_str(")");
1703         }
1704 
1705         let mut results = String::new();
1706         if let Some(ty) = self.result {
1707             results.push_str(&format!(" (result "));
1708             builder.write_ref(ty, &mut results);
1709             results.push_str(")");
1710         }
1711 
1712         let caller_module = make_import_and_export(
1713             &self.params,
1714             self.result,
1715             self.options.caller_lift_abi,
1716             self.options.caller_lower_abi,
1717         );
1718         let callee_module = make_import_and_export(
1719             &self.params,
1720             self.result,
1721             self.options.callee_lift_abi,
1722             self.options.callee_lower_abi,
1723         );
1724 
1725         let mut type_decls = Vec::new();
1726         let mut type_instantiation_args = String::new();
1727         while let Some((idx, ty)) = builder.worklist.pop() {
1728             type_decls.push(builder.write_decl(idx, ty));
1729             uwriteln!(type_instantiation_args, "(with \"t{idx}\" (type $t{idx}))");
1730         }
1731 
1732         // Note that types are printed here in reverse order since they were
1733         // pushed onto `type_decls` as they were referenced meaning the last one
1734         // is the "base" one.
1735         let mut types = String::new();
1736         for decl in type_decls.into_iter().rev() {
1737             types.push_str(&decl);
1738             types.push_str("\n");
1739         }
1740 
1741         Declarations {
1742             types: types.into(),
1743             type_instantiation_args: type_instantiation_args.into(),
1744             params: params.into(),
1745             results: results.into(),
1746             caller_module: caller_module.into(),
1747             callee_module: callee_module.into(),
1748             options: self.options,
1749         }
1750     }
1751 }
1752 
/// String encoding used by a component when passing strings across the
/// component boundary (see `caller_encoding` / `callee_encoding` on
/// `TestCaseOptions`).
///
/// The `Display` impl renders these as `utf8`, `utf16`, and `latin1+utf16`.
/// Variant order likely affects the derived `Arbitrary` decoding.
#[derive(Copy, Clone, Debug, Arbitrary)]
pub enum StringEncoding {
    /// Rendered as `utf8`.
    Utf8,
    /// Rendered as `utf16`.
    Utf16,
    /// Rendered as `latin1+utf16`.
    Latin1OrUtf16,
}
1759 
1760 impl fmt::Display for StringEncoding {
1761     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1762         match self {
1763             StringEncoding::Utf8 => fmt::Display::fmt(&"utf8", f),
1764             StringEncoding::Utf16 => fmt::Display::fmt(&"utf16", f),
1765             StringEncoding::Latin1OrUtf16 => fmt::Display::fmt(&"latin1+utf16", f),
1766         }
1767     }
1768 }
1769 
1770 impl ToTokens for TestCaseOptions {
1771     fn to_tokens(&self, tokens: &mut TokenStream) {
1772         let TestCaseOptions {
1773             guest_caller_async,
1774             guest_callee_async,
1775             host_async,
1776             caller_encoding,
1777             callee_encoding,
1778             caller_lift_abi,
1779             callee_lift_abi,
1780             caller_lower_abi,
1781             callee_lower_abi,
1782         } = self;
1783         tokens.extend(quote!(wasmtime_test_util::component_fuzz::TestCaseOptions {
1784             guest_caller_async: #guest_caller_async,
1785             guest_callee_async: #guest_callee_async,
1786             host_async: #host_async,
1787             caller_encoding: #caller_encoding,
1788             callee_encoding: #callee_encoding,
1789             caller_lift_abi: #caller_lift_abi,
1790             callee_lift_abi: #callee_lift_abi,
1791             caller_lower_abi: #caller_lower_abi,
1792             callee_lower_abi: #callee_lower_abi,
1793         }));
1794     }
1795 }
1796 
1797 impl ToTokens for LowerAbi {
1798     fn to_tokens(&self, tokens: &mut TokenStream) {
1799         let me = match self {
1800             LowerAbi::Sync => quote!(Sync),
1801             LowerAbi::Async => quote!(Async),
1802         };
1803         tokens.extend(quote!(wasmtime_test_util::component_fuzz::LowerAbi::#me));
1804     }
1805 }
1806 
1807 impl ToTokens for LiftAbi {
1808     fn to_tokens(&self, tokens: &mut TokenStream) {
1809         let me = match self {
1810             LiftAbi::Sync => quote!(Sync),
1811             LiftAbi::AsyncCallback => quote!(AsyncCallback),
1812             LiftAbi::AsyncStackful => quote!(AsyncStackful),
1813         };
1814         tokens.extend(quote!(wasmtime_test_util::component_fuzz::LiftAbi::#me));
1815     }
1816 }
1817 
1818 impl ToTokens for StringEncoding {
1819     fn to_tokens(&self, tokens: &mut TokenStream) {
1820         let me = match self {
1821             StringEncoding::Utf8 => quote!(Utf8),
1822             StringEncoding::Utf16 => quote!(Utf16),
1823             StringEncoding::Latin1OrUtf16 => quote!(Latin1OrUtf16),
1824         };
1825         tokens.extend(quote!(wasmtime_test_util::component_fuzz::StringEncoding::#me));
1826     }
1827 }
1828 
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the listing shown when a generated component fails validation:
    /// the binary re-printed with byte offsets and operand stacks, or, if that
    /// printing itself fails, the generated component source text.
    fn failure_listing(wasm: &[u8], component: &str) -> String {
        let mut listing = String::new();
        let printed_ok = wasmprinter::Config::new()
            .print_offsets(true)
            .print_operand_stack(true)
            .print(wasm, &mut wasmprinter::PrintFmtWrite(&mut listing))
            .is_ok();
        if printed_ok {
            listing
        } else {
            component.to_string()
        }
    }

    /// Generates random types and test cases, then checks that each resulting
    /// component parses as WAT and validates with all wasm features enabled.
    ///
    /// To replay a particular failure, chain `.seed(<seed>)` (for example
    /// `.seed(0x3c9050d4000000e9)`) onto the builder before `.budget_ms(...)`.
    #[test]
    fn arbtest() {
        arbtest::arbtest(|u| {
            let mut fuel = 100;
            let types = (0..5)
                .map(|_| Type::generate(u, 3, &mut fuel))
                .collect::<arbitrary::Result<Vec<_>>>()?;
            let case = TestCase::generate(&types, u)?;
            let decls = case.declarations();
            let component = decls.make_component();

            let wasm = wat::parse_str(&component).unwrap_or_else(|e| {
                panic!("failed to parse generated component as wat: {e}\n\n{component}");
            });

            let features = wasmparser::WasmFeatures::all();
            if let Err(e) = wasmparser::Validator::new_with_features(features).validate_all(&wasm)
            {
                let listing = failure_listing(&wasm, &component);
                panic!("generated component is not valid wasm: {e}\n\n{listing}");
            }
            Ok(())
        })
        .budget_ms(1_000);
    }
}
1870