1 //! This module generates test cases for the Wasmtime component model function APIs,
2 //! e.g. `wasmtime::component::func::Func` and `TypedFunc`.
3 //!
4 //! Each case includes a list of arbitrary interface types to use as parameters, plus another one to use as a
5 //! result, and a component which exports a function and imports a function.  The exported function forwards its
6 //! parameters to the imported one and forwards the result back to the caller.  This serves to exercise Wasmtime's
7 //! lifting and lowering code and verify the values remain intact during both processes.
8 
9 use arbitrary::{Arbitrary, Unstructured};
10 use indexmap::IndexSet;
11 use proc_macro2::{Ident, TokenStream};
12 use quote::{ToTokens, format_ident, quote};
13 use std::borrow::Cow;
14 use std::fmt::{self, Debug, Write};
15 use std::hash::{Hash, Hasher};
16 use std::iter;
17 use std::ops::Deref;
18 use wasmtime_component_util::{DiscriminantSize, FlagsSize, REALLOC_AND_FREE};
19 
/// Canonical ABI limit on the number of flattened core parameters
/// (`MAX_FLAT_PARAMS` in the canonical ABI specification).
const MAX_FLAT_PARAMS: usize = 16;
/// Canonical ABI limit on flattened parameters for async lowering
/// (`MAX_FLAT_ASYNC_PARAMS` in the canonical ABI specification).
const MAX_FLAT_ASYNC_PARAMS: usize = 4;
/// Canonical ABI limit on the number of flattened core results
/// (`MAX_FLAT_RESULTS` in the canonical ABI specification).
const MAX_FLAT_RESULTS: usize = 1;

/// Name under which the generated component imports the host function it
/// calls.
pub const IMPORT_FUNCTION: &str = "echo-import";

/// Name under which the generated component exports the guest function the
/// host is expected to call.
pub const EXPORT_FUNCTION: &str = "echo-export";

/// Wasmtime caps type nesting at 100 levels; stay one level below that.
pub const MAX_TYPE_DEPTH: u32 = 100 - 1;
32 
/// `writeln!` that panics instead of returning a `fmt::Result`; used for
/// writing into `String`s, where formatting cannot fail.
macro_rules! uwriteln {
    ($($tokens:tt)*) => {{
        writeln!($($tokens)*).unwrap()
    }};
}
38 
/// `write!` that panics instead of returning a `fmt::Result`; used for
/// writing into `String`s, where formatting cannot fail.
macro_rules! uwrite {
    ($($tokens:tt)*) => {{
        write!($($tokens)*).unwrap()
    }};
}
44 
/// A core WebAssembly value type that component-model values flatten to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CoreType {
    I32,
    I64,
    F32,
    F64,
}

impl CoreType {
    /// Combines two flat types that occupy the same slot in different cases
    /// of a variant, following the `join` operation in [the canonical
    /// ABI](https://github.com/WebAssembly/component-model/blob/main/design/mvp/CanonicalABI.md#flattening).
    ///
    /// Identical types join to themselves, `i32` with `f32` joins to `i32`,
    /// and every other pairing widens to `i64`.
    fn join(self, other: Self) -> Self {
        if self == other {
            return self;
        }
        match (self, other) {
            (Self::I32, Self::F32) | (Self::F32, Self::I32) => Self::I32,
            _ => Self::I64,
        }
    }
}

impl fmt::Display for CoreType {
    /// Writes the type's WebAssembly text-format keyword (e.g. `i32`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match *self {
            CoreType::I32 => "i32",
            CoreType::I64 => "i64",
            CoreType::F32 => "f32",
            CoreType::F64 => "f64",
        };
        f.write_str(keyword)
    }
}
76 
77 /// Wraps a `Box<[T]>` and provides an `Arbitrary` implementation that always generates slices of length less than
78 /// or equal to the longest tuple for which Wasmtime generates a `ComponentType` impl
79 #[derive(Debug, Clone)]
80 pub struct VecInRange<T, const L: u32, const H: u32>(Vec<T>);
81 
82 impl<T, const L: u32, const H: u32> VecInRange<T, L, H> {
new<'a>( input: &mut Unstructured<'a>, fuel: &mut u32, generate: impl Fn(&mut Unstructured<'a>, &mut u32) -> arbitrary::Result<T>, ) -> arbitrary::Result<Self>83     fn new<'a>(
84         input: &mut Unstructured<'a>,
85         fuel: &mut u32,
86         generate: impl Fn(&mut Unstructured<'a>, &mut u32) -> arbitrary::Result<T>,
87     ) -> arbitrary::Result<Self> {
88         let mut ret = Vec::new();
89         input.arbitrary_loop(Some(L), Some(H), |input| {
90             if *fuel > 0 {
91                 *fuel = *fuel - 1;
92                 ret.push(generate(input, fuel)?);
93                 Ok(std::ops::ControlFlow::Continue(()))
94             } else {
95                 Ok(std::ops::ControlFlow::Break(()))
96             }
97         })?;
98         Ok(Self(ret))
99     }
100 }
101 
102 impl<T, const L: u32, const H: u32> Deref for VecInRange<T, L, H> {
103     type Target = [T];
104 
deref(&self) -> &[T]105     fn deref(&self) -> &[T] {
106         self.0.deref()
107     }
108 }
109 
110 /// Represents a component model interface type
111 #[expect(missing_docs, reason = "self-describing")]
112 #[derive(Debug, Clone)]
113 pub enum Type {
114     Bool,
115     S8,
116     U8,
117     S16,
118     U16,
119     S32,
120     U32,
121     S64,
122     U64,
123     Float32,
124     Float64,
125     Char,
126     String,
127     List(Box<Type>),
128     Map(Box<Type>, Box<Type>),
129 
130     // Give records the ability to generate a generous amount of fields but
131     // don't let the fuzzer go too wild since `wasmparser`'s validator currently
132     // has hard limits in the 1000-ish range on the number of fields a record
133     // may contain.
134     Record(VecInRange<Type, 1, 200>),
135 
136     // Tuples can only have up to 16 type parameters in wasmtime right now for
137     // the static API, but the standard library only supports `Debug` up to 11
138     // elements, so compromise at an even 10.
139     Tuple(VecInRange<Type, 1, 10>),
140 
141     // Like records, allow a good number of variants, but variants require at
142     // least one case.
143     Variant(VecInRange<Option<Type>, 1, 200>),
144     Enum(u32),
145 
146     Option(Box<Type>),
147     Result {
148         ok: Option<Box<Type>>,
149         err: Option<Box<Type>>,
150     },
151 
152     Flags(u32),
153 }
154 
155 impl Type {
generate( u: &mut Unstructured<'_>, depth: u32, fuel: &mut u32, ) -> arbitrary::Result<Type>156     pub fn generate(
157         u: &mut Unstructured<'_>,
158         depth: u32,
159         fuel: &mut u32,
160     ) -> arbitrary::Result<Type> {
161         *fuel = fuel.saturating_sub(1);
162         let max = if depth == 0 || *fuel == 0 { 12 } else { 21 };
163         Ok(match u.int_in_range(0..=max)? {
164             0 => Type::Bool,
165             1 => Type::S8,
166             2 => Type::U8,
167             3 => Type::S16,
168             4 => Type::U16,
169             5 => Type::S32,
170             6 => Type::U32,
171             7 => Type::S64,
172             8 => Type::U64,
173             9 => Type::Float32,
174             10 => Type::Float64,
175             11 => Type::Char,
176             12 => Type::String,
177             // ^-- if you add something here update the `depth == 0` case above
178             13 => Type::List(Box::new(Type::generate(u, depth - 1, fuel)?)),
179             14 => Type::Record(Type::generate_list(u, depth - 1, fuel)?),
180             15 => Type::Tuple(Type::generate_list(u, depth - 1, fuel)?),
181             16 => Type::Variant(VecInRange::new(u, fuel, |u, fuel| {
182                 Type::generate_opt(u, depth - 1, fuel)
183             })?),
184             17 => {
185                 let amt = u.int_in_range(1..=(*fuel).max(1).min(257))?;
186                 *fuel -= amt;
187                 Type::Enum(amt)
188             }
189             18 => Type::Option(Box::new(Type::generate(u, depth - 1, fuel)?)),
190             19 => Type::Result {
191                 ok: Type::generate_opt(u, depth - 1, fuel)?.map(Box::new),
192                 err: Type::generate_opt(u, depth - 1, fuel)?.map(Box::new),
193             },
194             20 => {
195                 let amt = u.int_in_range(1..=(*fuel).min(32))?;
196                 *fuel -= amt;
197                 Type::Flags(amt)
198             }
199             21 => Type::Map(
200                 Box::new(Type::generate_hashable_key(u, fuel)?),
201                 Box::new(Type::generate(u, depth - 1, fuel)?),
202             ),
203             // ^-- if you add something here update the `depth != 0` case above
204             _ => unreachable!(),
205         })
206     }
207 
208     /// Generate a type that can be used as a HashMap key (implements Hash + Eq).
209     /// This excludes floats and complex types that might contain floats.
generate_hashable_key(u: &mut Unstructured<'_>, fuel: &mut u32) -> arbitrary::Result<Type>210     fn generate_hashable_key(u: &mut Unstructured<'_>, fuel: &mut u32) -> arbitrary::Result<Type> {
211         *fuel = fuel.saturating_sub(1);
212         // Only generate types that implement Hash and Eq:
213         // - No Float32/Float64 (NaN comparison issues)
214         // - No complex types (Record, Tuple, Variant, etc.) as they might contain floats
215         // - String is allowed as it implements Hash + Eq
216         Ok(match u.int_in_range(0..=11)? {
217             0 => Type::Bool,
218             1 => Type::S8,
219             2 => Type::U8,
220             3 => Type::S16,
221             4 => Type::U16,
222             5 => Type::S32,
223             6 => Type::U32,
224             7 => Type::S64,
225             8 => Type::U64,
226             9 => Type::Char,
227             10 => Type::String,
228             11 => {
229                 let amt = u.int_in_range(1..=(*fuel).max(1).min(257))?;
230                 *fuel = fuel.saturating_sub(amt);
231                 Type::Enum(amt)
232             }
233             _ => unreachable!(),
234         })
235     }
236 
generate_opt( u: &mut Unstructured<'_>, depth: u32, fuel: &mut u32, ) -> arbitrary::Result<Option<Type>>237     fn generate_opt(
238         u: &mut Unstructured<'_>,
239         depth: u32,
240         fuel: &mut u32,
241     ) -> arbitrary::Result<Option<Type>> {
242         Ok(if u.arbitrary()? {
243             Some(Type::generate(u, depth, fuel)?)
244         } else {
245             None
246         })
247     }
248 
generate_list<const L: u32, const H: u32>( u: &mut Unstructured<'_>, depth: u32, fuel: &mut u32, ) -> arbitrary::Result<VecInRange<Type, L, H>>249     fn generate_list<const L: u32, const H: u32>(
250         u: &mut Unstructured<'_>,
251         depth: u32,
252         fuel: &mut u32,
253     ) -> arbitrary::Result<VecInRange<Type, L, H>> {
254         VecInRange::new(u, fuel, |u, fuel| Type::generate(u, depth, fuel))
255     }
256 
257     /// Generates text format wasm into `s` to store a value of this type, in
258     /// its flat representation stored in the `locals` provided, to the local
259     /// named `ptr` at the `offset` provided.
260     ///
261     /// This will register helper functions necessary in `helpers`. The
262     /// `locals` iterator will be advanced for all locals consumed by this
263     /// store operation.
store_flat<'a>( &'a self, s: &mut String, ptr: &str, offset: u32, locals: &mut dyn Iterator<Item = FlatSource>, helpers: &mut IndexSet<Helper<'a>>, )264     fn store_flat<'a>(
265         &'a self,
266         s: &mut String,
267         ptr: &str,
268         offset: u32,
269         locals: &mut dyn Iterator<Item = FlatSource>,
270         helpers: &mut IndexSet<Helper<'a>>,
271     ) {
272         enum Kind {
273             Primitive(&'static str),
274             PointerPair,
275             Helper,
276         }
277         let kind = match self {
278             Type::Bool | Type::S8 | Type::U8 => Kind::Primitive("i32.store8"),
279             Type::S16 | Type::U16 => Kind::Primitive("i32.store16"),
280             Type::S32 | Type::U32 | Type::Char => Kind::Primitive("i32.store"),
281             Type::S64 | Type::U64 => Kind::Primitive("i64.store"),
282             Type::Float32 => Kind::Primitive("f32.store"),
283             Type::Float64 => Kind::Primitive("f64.store"),
284             Type::String | Type::List(_) | Type::Map(_, _) => Kind::PointerPair,
285             Type::Enum(n) if *n <= (1 << 8) => Kind::Primitive("i32.store8"),
286             Type::Enum(n) if *n <= (1 << 16) => Kind::Primitive("i32.store16"),
287             Type::Enum(_) => Kind::Primitive("i32.store"),
288             Type::Flags(n) if *n <= 8 => Kind::Primitive("i32.store8"),
289             Type::Flags(n) if *n <= 16 => Kind::Primitive("i32.store16"),
290             Type::Flags(n) if *n <= 32 => Kind::Primitive("i32.store"),
291             Type::Flags(_) => unreachable!(),
292             Type::Record(_)
293             | Type::Tuple(_)
294             | Type::Variant(_)
295             | Type::Option(_)
296             | Type::Result { .. } => Kind::Helper,
297         };
298 
299         match kind {
300             Kind::Primitive(op) => uwriteln!(
301                 s,
302                 "({op} offset={offset} (local.get {ptr}) {})",
303                 locals.next().unwrap()
304             ),
305             Kind::PointerPair => {
306                 let abi_ptr = locals.next().unwrap();
307                 let abi_len = locals.next().unwrap();
308                 uwriteln!(s, "(i32.store offset={offset} (local.get {ptr}) {abi_ptr})",);
309                 let offset = offset + 4;
310                 uwriteln!(s, "(i32.store offset={offset} (local.get {ptr}) {abi_len})",);
311             }
312             Kind::Helper => {
313                 let (index, _) = helpers.insert_full(Helper(self));
314                 uwriteln!(s, "(i32.add (local.get {ptr}) (i32.const {offset}))");
315                 for _ in 0..self.lowered().len() {
316                     let i = locals.next().unwrap();
317                     uwriteln!(s, "{i}");
318                 }
319                 uwriteln!(s, "call $store_helper_{index}");
320             }
321         }
322     }
323 
324     /// Generates a text-format wasm function which takes a pointer and this
325     /// type's flat representation as arguments and then stores this value in
326     /// the first argument.
327     ///
328     /// This is used to store records/variants to cut down on the size of final
329     /// functions and make codegen here a bit easier.
store_flat_helper<'a>( &'a self, s: &mut String, i: usize, helpers: &mut IndexSet<Helper<'a>>, )330     fn store_flat_helper<'a>(
331         &'a self,
332         s: &mut String,
333         i: usize,
334         helpers: &mut IndexSet<Helper<'a>>,
335     ) {
336         uwrite!(s, "(func $store_helper_{i} (param i32)");
337         let lowered = self.lowered();
338         for ty in &lowered {
339             uwrite!(s, " (param {ty})");
340         }
341         s.push_str("\n");
342         let locals = (0..lowered.len() as u32).map(|i| i + 1).collect::<Vec<_>>();
343         let record = |s: &mut String, helpers: &mut IndexSet<Helper<'a>>, types: &'a [Type]| {
344             let mut locals = locals.iter().cloned().map(FlatSource::Local);
345             for (offset, ty) in record_field_offsets(types) {
346                 ty.store_flat(s, "0", offset, &mut locals, helpers);
347             }
348             assert!(locals.next().is_none());
349         };
350         let variant = |s: &mut String,
351                        helpers: &mut IndexSet<Helper<'a>>,
352                        types: &[Option<&'a Type>]| {
353             let (size, offset) = variant_memory_info(types.iter().cloned());
354             // One extra block for out-of-bounds discriminants.
355             for _ in 0..types.len() + 1 {
356                 s.push_str("block\n");
357             }
358 
359             // Store the discriminant in memory, then branch on it to figure
360             // out which case we're in.
361             let store = match size {
362                 DiscriminantSize::Size1 => "i32.store8",
363                 DiscriminantSize::Size2 => "i32.store16",
364                 DiscriminantSize::Size4 => "i32.store",
365             };
366             uwriteln!(s, "({store} (local.get 0) (local.get 1))");
367             s.push_str("local.get 1\n");
368             s.push_str("br_table");
369             for i in 0..types.len() + 1 {
370                 uwrite!(s, " {i}");
371             }
372             s.push_str("\nend\n");
373 
374             // Store each payload individually while converting locals from
375             // their source types to the precise type necessary for this
376             // variant.
377             for ty in types {
378                 if let Some(ty) = ty {
379                     let ty_lowered = ty.lowered();
380                     let mut locals = locals[1..].iter().zip(&lowered[1..]).zip(&ty_lowered).map(
381                         |((i, from), to)| FlatSource::LocalConvert {
382                             local: *i,
383                             from: *from,
384                             to: *to,
385                         },
386                     );
387                     ty.store_flat(s, "0", offset, &mut locals, helpers);
388                 }
389                 s.push_str("return\n");
390                 s.push_str("end\n");
391             }
392 
393             // Catch-all result which is for out-of-bounds discriminants.
394             s.push_str("unreachable\n");
395         };
396         match self {
397             Type::Bool
398             | Type::S8
399             | Type::U8
400             | Type::S16
401             | Type::U16
402             | Type::S32
403             | Type::U32
404             | Type::Char
405             | Type::S64
406             | Type::U64
407             | Type::Float32
408             | Type::Float64
409             | Type::String
410             | Type::List(_)
411             | Type::Map(_, _)
412             | Type::Flags(_)
413             | Type::Enum(_) => unreachable!(),
414 
415             Type::Record(r) => record(s, helpers, r),
416             Type::Tuple(t) => record(s, helpers, t),
417             Type::Variant(v) => variant(
418                 s,
419                 helpers,
420                 &v.iter().map(|t| t.as_ref()).collect::<Vec<_>>(),
421             ),
422             Type::Option(o) => variant(s, helpers, &[None, Some(&**o)]),
423             Type::Result { ok, err } => variant(s, helpers, &[ok.as_deref(), err.as_deref()]),
424         };
425         s.push_str(")\n");
426     }
427 
428     /// Same as `store_flat`, except loads the flat values from `ptr+offset`.
429     ///
430     /// Results are placed directly on the wasm stack.
load_flat<'a>( &'a self, s: &mut String, ptr: &str, offset: u32, helpers: &mut IndexSet<Helper<'a>>, )431     fn load_flat<'a>(
432         &'a self,
433         s: &mut String,
434         ptr: &str,
435         offset: u32,
436         helpers: &mut IndexSet<Helper<'a>>,
437     ) {
438         enum Kind {
439             Primitive(&'static str),
440             PointerPair,
441             Helper,
442         }
443         let kind = match self {
444             Type::Bool | Type::U8 => Kind::Primitive("i32.load8_u"),
445             Type::S8 => Kind::Primitive("i32.load8_s"),
446             Type::U16 => Kind::Primitive("i32.load16_u"),
447             Type::S16 => Kind::Primitive("i32.load16_s"),
448             Type::U32 | Type::S32 | Type::Char => Kind::Primitive("i32.load"),
449             Type::U64 | Type::S64 => Kind::Primitive("i64.load"),
450             Type::Float32 => Kind::Primitive("f32.load"),
451             Type::Float64 => Kind::Primitive("f64.load"),
452             Type::String | Type::List(_) | Type::Map(_, _) => Kind::PointerPair,
453             Type::Enum(n) if *n <= (1 << 8) => Kind::Primitive("i32.load8_u"),
454             Type::Enum(n) if *n <= (1 << 16) => Kind::Primitive("i32.load16_u"),
455             Type::Enum(_) => Kind::Primitive("i32.load"),
456             Type::Flags(n) if *n <= 8 => Kind::Primitive("i32.load8_u"),
457             Type::Flags(n) if *n <= 16 => Kind::Primitive("i32.load16_u"),
458             Type::Flags(n) if *n <= 32 => Kind::Primitive("i32.load"),
459             Type::Flags(_) => unreachable!(),
460 
461             Type::Record(_)
462             | Type::Tuple(_)
463             | Type::Variant(_)
464             | Type::Option(_)
465             | Type::Result { .. } => Kind::Helper,
466         };
467         match kind {
468             Kind::Primitive(op) => uwriteln!(s, "({op} offset={offset} (local.get {ptr}))"),
469             Kind::PointerPair => {
470                 uwriteln!(s, "(i32.load offset={offset} (local.get {ptr}))",);
471                 let offset = offset + 4;
472                 uwriteln!(s, "(i32.load offset={offset} (local.get {ptr}))",);
473             }
474             Kind::Helper => {
475                 let (index, _) = helpers.insert_full(Helper(self));
476                 uwriteln!(s, "(i32.add (local.get {ptr}) (i32.const {offset}))");
477                 uwriteln!(s, "call $load_helper_{index}");
478             }
479         }
480     }
481 
482     /// Same as `store_flat_helper` but for loading the flat representation.
load_flat_helper<'a>( &'a self, s: &mut String, i: usize, helpers: &mut IndexSet<Helper<'a>>, )483     fn load_flat_helper<'a>(
484         &'a self,
485         s: &mut String,
486         i: usize,
487         helpers: &mut IndexSet<Helper<'a>>,
488     ) {
489         uwrite!(s, "(func $load_helper_{i} (param i32)");
490         let lowered = self.lowered();
491         for ty in &lowered {
492             uwrite!(s, " (result {ty})");
493         }
494         s.push_str("\n");
495         let record = |s: &mut String, helpers: &mut IndexSet<Helper<'a>>, types: &'a [Type]| {
496             for (offset, ty) in record_field_offsets(types) {
497                 ty.load_flat(s, "0", offset, helpers);
498             }
499         };
500         let variant = |s: &mut String,
501                        helpers: &mut IndexSet<Helper<'a>>,
502                        types: &[Option<&'a Type>]| {
503             let (size, offset) = variant_memory_info(types.iter().cloned());
504 
505             // Destination locals where the flat representation will be stored.
506             // These are automatically zero which handles unused fields too.
507             for (i, ty) in lowered.iter().enumerate() {
508                 uwriteln!(s, " (local $r{i} {ty})");
509             }
510 
511             // Return block each case jumps to after setting all locals.
512             s.push_str("block $r\n");
513 
514             // One extra block for "out of bounds discriminant".
515             for _ in 0..types.len() + 1 {
516                 s.push_str("block\n");
517             }
518 
519             // Load the discriminant and branch on it, storing it in
520             // `$r0` as well which is the first flat local representation.
521             let load = match size {
522                 DiscriminantSize::Size1 => "i32.load8_u",
523                 DiscriminantSize::Size2 => "i32.load16",
524                 DiscriminantSize::Size4 => "i32.load",
525             };
526             uwriteln!(s, "({load} (local.get 0))");
527             s.push_str("local.tee $r0\n");
528             s.push_str("br_table");
529             for i in 0..types.len() + 1 {
530                 uwrite!(s, " {i}");
531             }
532             s.push_str("\nend\n");
533 
534             // For each payload, which is in its own block, load payloads from
535             // memory as necessary and convert them into the final locals.
536             for ty in types {
537                 if let Some(ty) = ty {
538                     let ty_lowered = ty.lowered();
539                     ty.load_flat(s, "0", offset, helpers);
540                     for (i, (from, to)) in ty_lowered.iter().zip(&lowered[1..]).enumerate().rev() {
541                         let i = i + 1;
542                         match (from, to) {
543                             (CoreType::F32, CoreType::I32) => {
544                                 s.push_str("i32.reinterpret_f32\n");
545                             }
546                             (CoreType::I32, CoreType::I64) => {
547                                 s.push_str("i64.extend_i32_u\n");
548                             }
549                             (CoreType::F32, CoreType::I64) => {
550                                 s.push_str("i32.reinterpret_f32\n");
551                                 s.push_str("i64.extend_i32_u\n");
552                             }
553                             (CoreType::F64, CoreType::I64) => {
554                                 s.push_str("i64.reinterpret_f64\n");
555                             }
556                             (a, b) if a == b => {}
557                             _ => unimplemented!("convert {from:?} to {to:?}"),
558                         }
559                         uwriteln!(s, "local.set $r{i}");
560                     }
561                 }
562                 s.push_str("br $r\n");
563                 s.push_str("end\n");
564             }
565 
566             // The catch-all block for out-of-bounds discriminants.
567             s.push_str("unreachable\n");
568             s.push_str("end\n");
569             for i in 0..lowered.len() {
570                 uwriteln!(s, " local.get $r{i}");
571             }
572         };
573 
574         match self {
575             Type::Bool
576             | Type::S8
577             | Type::U8
578             | Type::S16
579             | Type::U16
580             | Type::S32
581             | Type::U32
582             | Type::Char
583             | Type::S64
584             | Type::U64
585             | Type::Float32
586             | Type::Float64
587             | Type::String
588             | Type::List(_)
589             | Type::Map(_, _)
590             | Type::Flags(_)
591             | Type::Enum(_) => unreachable!(),
592 
593             Type::Record(r) => record(s, helpers, r),
594             Type::Tuple(t) => record(s, helpers, t),
595             Type::Variant(v) => variant(
596                 s,
597                 helpers,
598                 &v.iter().map(|t| t.as_ref()).collect::<Vec<_>>(),
599             ),
600             Type::Option(o) => variant(s, helpers, &[None, Some(&**o)]),
601             Type::Result { ok, err } => variant(s, helpers, &[ok.as_deref(), err.as_deref()]),
602         };
603         s.push_str(")\n");
604     }
605 }
606 
607 #[derive(Clone)]
608 enum FlatSource {
609     Local(u32),
610     LocalConvert {
611         local: u32,
612         from: CoreType,
613         to: CoreType,
614     },
615 }
616 
617 impl fmt::Display for FlatSource {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result618     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
619         match self {
620             FlatSource::Local(i) => write!(f, "(local.get {i})"),
621             FlatSource::LocalConvert { local, from, to } => {
622                 match (from, to) {
623                     (a, b) if a == b => write!(f, "(local.get {local})"),
624                     (CoreType::I32, CoreType::F32) => {
625                         write!(f, "(f32.reinterpret_i32 (local.get {local}))")
626                     }
627                     (CoreType::I64, CoreType::I32) => {
628                         write!(f, "(i32.wrap_i64 (local.get {local}))")
629                     }
630                     (CoreType::I64, CoreType::F64) => {
631                         write!(f, "(f64.reinterpret_i64 (local.get {local}))")
632                     }
633                     (CoreType::I64, CoreType::F32) => {
634                         write!(
635                             f,
636                             "(f32.reinterpret_i32 (i32.wrap_i64 (local.get {local})))"
637                         )
638                     }
639                     _ => unimplemented!("convert {from:?} to {to:?}"),
640                 }
641                 // ..
642             }
643         }
644     }
645 }
646 
lower_record<'a>(types: impl Iterator<Item = &'a Type>, vec: &mut Vec<CoreType>)647 fn lower_record<'a>(types: impl Iterator<Item = &'a Type>, vec: &mut Vec<CoreType>) {
648     for ty in types {
649         ty.lower(vec);
650     }
651 }
652 
lower_variant<'a>(types: impl Iterator<Item = Option<&'a Type>>, vec: &mut Vec<CoreType>)653 fn lower_variant<'a>(types: impl Iterator<Item = Option<&'a Type>>, vec: &mut Vec<CoreType>) {
654     vec.push(CoreType::I32);
655     let offset = vec.len();
656     for ty in types {
657         let ty = match ty {
658             Some(ty) => ty,
659             None => continue,
660         };
661         for (index, ty) in ty.lowered().iter().enumerate() {
662             let index = offset + index;
663             if index < vec.len() {
664                 vec[index] = vec[index].join(*ty);
665             } else {
666                 vec.push(*ty)
667             }
668         }
669     }
670 }
671 
u32_count_from_flag_count(count: usize) -> usize672 fn u32_count_from_flag_count(count: usize) -> usize {
673     match FlagsSize::from_count(count) {
674         FlagsSize::Size0 => 0,
675         FlagsSize::Size1 | FlagsSize::Size2 => 1,
676         FlagsSize::Size4Plus(n) => n.into(),
677     }
678 }
679 
/// Byte size and alignment of a type's representation in linear memory.
struct SizeAndAlignment {
    /// Alignment in bytes. `align_to` assumes this is a non-zero power of
    /// two.
    alignment: u32,
    /// Size in bytes.
    size: usize,
}
684 
685 impl Type {
lowered(&self) -> Vec<CoreType>686     fn lowered(&self) -> Vec<CoreType> {
687         let mut vec = Vec::new();
688         self.lower(&mut vec);
689         vec
690     }
691 
lower(&self, vec: &mut Vec<CoreType>)692     fn lower(&self, vec: &mut Vec<CoreType>) {
693         match self {
694             Type::Bool
695             | Type::U8
696             | Type::S8
697             | Type::S16
698             | Type::U16
699             | Type::S32
700             | Type::U32
701             | Type::Char
702             | Type::Enum(_) => vec.push(CoreType::I32),
703             Type::S64 | Type::U64 => vec.push(CoreType::I64),
704             Type::Float32 => vec.push(CoreType::F32),
705             Type::Float64 => vec.push(CoreType::F64),
706             Type::String | Type::List(_) | Type::Map(_, _) => {
707                 vec.push(CoreType::I32);
708                 vec.push(CoreType::I32);
709             }
710             Type::Record(types) => lower_record(types.iter(), vec),
711             Type::Tuple(types) => lower_record(types.0.iter(), vec),
712             Type::Variant(types) => lower_variant(types.0.iter().map(|t| t.as_ref()), vec),
713             Type::Option(ty) => lower_variant([None, Some(&**ty)].into_iter(), vec),
714             Type::Result { ok, err } => {
715                 lower_variant([ok.as_deref(), err.as_deref()].into_iter(), vec)
716             }
717             Type::Flags(count) => vec.extend(
718                 iter::repeat(CoreType::I32).take(u32_count_from_flag_count(*count as usize)),
719             ),
720         }
721     }
722 
size_and_alignment(&self) -> SizeAndAlignment723     fn size_and_alignment(&self) -> SizeAndAlignment {
724         match self {
725             Type::Bool | Type::S8 | Type::U8 => SizeAndAlignment {
726                 size: 1,
727                 alignment: 1,
728             },
729 
730             Type::S16 | Type::U16 => SizeAndAlignment {
731                 size: 2,
732                 alignment: 2,
733             },
734 
735             Type::S32 | Type::U32 | Type::Char | Type::Float32 => SizeAndAlignment {
736                 size: 4,
737                 alignment: 4,
738             },
739 
740             Type::S64 | Type::U64 | Type::Float64 => SizeAndAlignment {
741                 size: 8,
742                 alignment: 8,
743             },
744 
745             Type::String | Type::List(_) | Type::Map(_, _) => SizeAndAlignment {
746                 size: 8,
747                 alignment: 4,
748             },
749 
750             Type::Record(types) => record_size_and_alignment(types.iter()),
751 
752             Type::Tuple(types) => record_size_and_alignment(types.0.iter()),
753 
754             Type::Variant(types) => variant_size_and_alignment(types.0.iter().map(|t| t.as_ref())),
755 
756             Type::Enum(count) => variant_size_and_alignment((0..*count).map(|_| None)),
757 
758             Type::Option(ty) => variant_size_and_alignment([None, Some(&**ty)].into_iter()),
759 
760             Type::Result { ok, err } => {
761                 variant_size_and_alignment([ok.as_deref(), err.as_deref()].into_iter())
762             }
763 
764             Type::Flags(count) => match FlagsSize::from_count(*count as usize) {
765                 FlagsSize::Size0 => SizeAndAlignment {
766                     size: 0,
767                     alignment: 1,
768                 },
769                 FlagsSize::Size1 => SizeAndAlignment {
770                     size: 1,
771                     alignment: 1,
772                 },
773                 FlagsSize::Size2 => SizeAndAlignment {
774                     size: 2,
775                     alignment: 2,
776                 },
777                 FlagsSize::Size4Plus(n) => SizeAndAlignment {
778                     size: usize::from(n) * 4,
779                     alignment: 4,
780                 },
781             },
782         }
783     }
784 }
785 
align_to(a: usize, align: u32) -> usize786 fn align_to(a: usize, align: u32) -> usize {
787     let align = align as usize;
788     (a + (align - 1)) & !(align - 1)
789 }
790 
record_field_offsets<'a>( types: impl IntoIterator<Item = &'a Type>, ) -> impl Iterator<Item = (u32, &'a Type)>791 fn record_field_offsets<'a>(
792     types: impl IntoIterator<Item = &'a Type>,
793 ) -> impl Iterator<Item = (u32, &'a Type)> {
794     let mut offset = 0;
795     types.into_iter().map(move |ty| {
796         let SizeAndAlignment { size, alignment } = ty.size_and_alignment();
797         let ret = align_to(offset, alignment);
798         offset = ret + size;
799         (ret as u32, ty)
800     })
801 }
802 
record_size_and_alignment<'a>(types: impl IntoIterator<Item = &'a Type>) -> SizeAndAlignment803 fn record_size_and_alignment<'a>(types: impl IntoIterator<Item = &'a Type>) -> SizeAndAlignment {
804     let mut offset = 0;
805     let mut align = 1;
806     for ty in types {
807         let SizeAndAlignment { size, alignment } = ty.size_and_alignment();
808         offset = align_to(offset, alignment) + size;
809         align = align.max(alignment);
810     }
811 
812     SizeAndAlignment {
813         size: align_to(offset, align),
814         alignment: align,
815     }
816 }
817 
variant_size_and_alignment<'a>( types: impl ExactSizeIterator<Item = Option<&'a Type>>, ) -> SizeAndAlignment818 fn variant_size_and_alignment<'a>(
819     types: impl ExactSizeIterator<Item = Option<&'a Type>>,
820 ) -> SizeAndAlignment {
821     let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap();
822     let mut alignment = u32::from(discriminant_size);
823     let mut size = 0;
824     for ty in types {
825         if let Some(ty) = ty {
826             let size_and_alignment = ty.size_and_alignment();
827             alignment = alignment.max(size_and_alignment.alignment);
828             size = size.max(size_and_alignment.size);
829         }
830     }
831 
832     SizeAndAlignment {
833         size: align_to(
834             align_to(usize::from(discriminant_size), alignment) + size,
835             alignment,
836         ),
837         alignment,
838     }
839 }
840 
variant_memory_info<'a>( types: impl ExactSizeIterator<Item = Option<&'a Type>>, ) -> (DiscriminantSize, u32)841 fn variant_memory_info<'a>(
842     types: impl ExactSizeIterator<Item = Option<&'a Type>>,
843 ) -> (DiscriminantSize, u32) {
844     let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap();
845     let mut alignment = u32::from(discriminant_size);
846     for ty in types {
847         if let Some(ty) = ty {
848             let size_and_alignment = ty.size_and_alignment();
849             alignment = alignment.max(size_and_alignment.alignment);
850         }
851     }
852 
853     (
854         discriminant_size,
855         align_to(usize::from(discriminant_size), alignment) as u32,
856     )
857 }
858 
859 /// Generates the internals of a core wasm module which imports a single
860 /// component function `IMPORT_FUNCTION` and exports a single component
861 /// function `EXPORT_FUNCTION`.
862 ///
863 /// The component function takes `params` as arguments and optionally returns
864 /// `result`. The `lift_abi` and `lower_abi` fields indicate the ABI in-use for
865 /// this operation.
make_import_and_export( params: &[&Type], result: Option<&Type>, lift_abi: LiftAbi, lower_abi: LowerAbi, ) -> String866 fn make_import_and_export(
867     params: &[&Type],
868     result: Option<&Type>,
869     lift_abi: LiftAbi,
870     lower_abi: LowerAbi,
871 ) -> String {
872     let params_lowered = params
873         .iter()
874         .flat_map(|ty| ty.lowered())
875         .collect::<Box<[_]>>();
876     let result_lowered = result.map(|t| t.lowered()).unwrap_or(Vec::new());
877 
878     let mut wat = String::new();
879 
880     enum Location {
881         Flat,
882         Indirect(u32),
883     }
884 
885     // Generate the core wasm type corresponding to the imported function being
886     // lowered with `lower_abi`.
887     wat.push_str(&format!("(type $import (func"));
888     let max_import_params = match lower_abi {
889         LowerAbi::Sync => MAX_FLAT_PARAMS,
890         LowerAbi::Async => MAX_FLAT_ASYNC_PARAMS,
891     };
892     let (import_params_loc, nparams) = push_params(&mut wat, &params_lowered, max_import_params);
893     let import_results_loc = match lower_abi {
894         LowerAbi::Sync => {
895             push_result_or_retptr(&mut wat, &result_lowered, nparams, MAX_FLAT_RESULTS)
896         }
897         LowerAbi::Async => {
898             let loc = if result.is_none() {
899                 Location::Flat
900             } else {
901                 wat.push_str(" (param i32)"); // result pointer
902                 Location::Indirect(nparams)
903             };
904             wat.push_str(" (result i32)"); // status code
905             loc
906         }
907     };
908     wat.push_str("))\n");
909 
910     // Generate the import function.
911     wat.push_str(&format!(
912         r#"(import "host" "{IMPORT_FUNCTION}" (func $host (type $import)))"#
913     ));
914 
915     // Do the same as above for the exported function's type which is lifted
916     // with `lift_abi`.
917     //
918     // Note that `export_results_loc` being `None` means that `task.return` is
919     // used to communicate results.
920     wat.push_str(&format!("(type $export (func"));
921     let (export_params_loc, _nparams) = push_params(&mut wat, &params_lowered, MAX_FLAT_PARAMS);
922     let export_results_loc = match lift_abi {
923         LiftAbi::Sync => Some(push_group(&mut wat, "result", &result_lowered, MAX_FLAT_RESULTS).0),
924         LiftAbi::AsyncCallback => {
925             wat.push_str(" (result i32)"); // status code
926             None
927         }
928         LiftAbi::AsyncStackful => None,
929     };
930     wat.push_str("))\n");
931 
932     // If the export is async, generate `task.return` as an import as well
933     // which is necessary to communicate the results.
934     if export_results_loc.is_none() {
935         wat.push_str(&format!("(type $task.return (func"));
936         push_params(&mut wat, &result_lowered, MAX_FLAT_PARAMS);
937         wat.push_str("))\n");
938         wat.push_str(&format!(
939             r#"(import "" "task.return" (func $task.return (type $task.return)))"#
940         ));
941     }
942 
943     wat.push_str(&format!(
944         r#"
945 (func (export "{EXPORT_FUNCTION}") (type $export)
946     (local $retptr i32)
947     (local $argptr i32)
948         "#
949     ));
950     let mut store_helpers = IndexSet::new();
951     let mut load_helpers = IndexSet::new();
952 
953     match (export_params_loc, import_params_loc) {
954         // flat => flat is just moving locals around
955         (Location::Flat, Location::Flat) => {
956             for (index, _) in params_lowered.iter().enumerate() {
957                 uwrite!(wat, "local.get {index}\n");
958             }
959         }
960 
961         // indirect => indirect is just moving locals around
962         (Location::Indirect(i), Location::Indirect(j)) => {
963             assert_eq!(j, 0);
964             uwrite!(wat, "local.get {i}\n");
965         }
966 
967         // flat => indirect means that all parameters are stored in memory as
968         // if it was a record of all the parameters.
969         (Location::Flat, Location::Indirect(_)) => {
970             let SizeAndAlignment { size, alignment } =
971                 record_size_and_alignment(params.iter().cloned());
972             wat.push_str(&format!(
973                 r#"
974                     (local.set $argptr
975                         (call $realloc
976                             (i32.const 0)
977                             (i32.const 0)
978                             (i32.const {alignment})
979                             (i32.const {size})))
980                     local.get $argptr
981                 "#
982             ));
983             let mut locals = (0..params_lowered.len() as u32).map(FlatSource::Local);
984             for (offset, ty) in record_field_offsets(params.iter().cloned()) {
985                 ty.store_flat(&mut wat, "$argptr", offset, &mut locals, &mut store_helpers);
986             }
987             assert!(locals.next().is_none());
988         }
989 
990         (Location::Indirect(_), Location::Flat) => unreachable!(),
991     }
992 
993     // Pass a return-pointer if necessary.
994     match import_results_loc {
995         Location::Flat => {}
996         Location::Indirect(_) => {
997             let SizeAndAlignment { size, alignment } = result.unwrap().size_and_alignment();
998 
999             wat.push_str(&format!(
1000                 r#"
1001                     (local.set $retptr
1002                         (call $realloc
1003                             (i32.const 0)
1004                             (i32.const 0)
1005                             (i32.const {alignment})
1006                             (i32.const {size})))
1007                     local.get $retptr
1008                 "#
1009             ));
1010         }
1011     }
1012 
1013     wat.push_str("call $host\n");
1014 
1015     // Assert the lowered call is ready if an async code was returned.
1016     //
1017     // TODO: handle when the import isn't ready yet
1018     if let LowerAbi::Async = lower_abi {
1019         wat.push_str("i32.const 2\n");
1020         wat.push_str("i32.ne\n");
1021         wat.push_str("if unreachable end\n");
1022     }
1023 
1024     // TODO: conditionally inject a yield here
1025 
1026     match (import_results_loc, export_results_loc) {
1027         // flat => flat results involves nothing, the results are already on
1028         // the stack.
1029         (Location::Flat, Some(Location::Flat)) => {}
1030 
1031         // indirect => indirect results requires returning the `$retptr` the
1032         // host call filled in.
1033         (Location::Indirect(_), Some(Location::Indirect(_))) => {
1034             wat.push_str("local.get $retptr\n");
1035         }
1036 
1037         // indirect => flat requires loading the result from the return pointer
1038         (Location::Indirect(_), Some(Location::Flat)) => {
1039             result
1040                 .unwrap()
1041                 .load_flat(&mut wat, "$retptr", 0, &mut load_helpers);
1042         }
1043 
1044         // flat => task.return is easy, the results are already there so just
1045         // call the function.
1046         (Location::Flat, None) => {
1047             wat.push_str("call $task.return\n");
1048         }
1049 
1050         // indirect => task.return needs to forward `$retptr` if the results
1051         // are indirect, or otherwise it must be loaded from memory to a flat
1052         // representation.
1053         (Location::Indirect(_), None) => {
1054             if result_lowered.len() <= MAX_FLAT_PARAMS {
1055                 result
1056                     .unwrap()
1057                     .load_flat(&mut wat, "$retptr", 0, &mut load_helpers);
1058             } else {
1059                 wat.push_str("local.get $retptr\n");
1060             }
1061             wat.push_str("call $task.return\n");
1062         }
1063 
1064         (Location::Flat, Some(Location::Indirect(_))) => unreachable!(),
1065     }
1066 
1067     if let LiftAbi::AsyncCallback = lift_abi {
1068         wat.push_str("i32.const 0\n"); // completed status code
1069     }
1070 
1071     wat.push_str(")\n");
1072 
1073     // Generate a `callback` function for the callback ABI.
1074     //
1075     // TODO: fill this in
1076     if let LiftAbi::AsyncCallback = lift_abi {
1077         wat.push_str(
1078             r#"
1079 (func (export "callback") (param i32 i32 i32) (result i32) unreachable)
1080             "#,
1081         );
1082     }
1083 
1084     // Fill out all store/load helpers that were needed during generation
1085     // above. This is a fix-point-loop since each helper may end up requiring
1086     // more helpers.
1087     let mut i = 0;
1088     while i < store_helpers.len() {
1089         let ty = store_helpers[i].0;
1090         ty.store_flat_helper(&mut wat, i, &mut store_helpers);
1091         i += 1;
1092     }
1093     i = 0;
1094     while i < load_helpers.len() {
1095         let ty = load_helpers[i].0;
1096         ty.load_flat_helper(&mut wat, i, &mut load_helpers);
1097         i += 1;
1098     }
1099 
1100     return wat;
1101 
1102     fn push_params(wat: &mut String, params: &[CoreType], max_flat: usize) -> (Location, u32) {
1103         push_group(wat, "param", params, max_flat)
1104     }
1105 
1106     fn push_group(
1107         wat: &mut String,
1108         name: &str,
1109         params: &[CoreType],
1110         max_flat: usize,
1111     ) -> (Location, u32) {
1112         let mut nparams = 0;
1113         let loc = if params.is_empty() {
1114             // nothing to emit...
1115             Location::Flat
1116         } else if params.len() <= max_flat {
1117             wat.push_str(&format!(" ({name}"));
1118             for ty in params {
1119                 wat.push_str(&format!(" {ty}"));
1120                 nparams += 1;
1121             }
1122             wat.push_str(")");
1123             Location::Flat
1124         } else {
1125             wat.push_str(&format!(" ({name} i32)"));
1126             nparams += 1;
1127             Location::Indirect(0)
1128         };
1129         (loc, nparams)
1130     }
1131 
1132     fn push_result_or_retptr(
1133         wat: &mut String,
1134         results: &[CoreType],
1135         nparams: u32,
1136         max_flat: usize,
1137     ) -> Location {
1138         if results.is_empty() {
1139             // nothing to emit...
1140             Location::Flat
1141         } else if results.len() <= max_flat {
1142             wat.push_str(" (result");
1143             for ty in results {
1144                 wat.push_str(&format!(" {ty}"));
1145             }
1146             wat.push_str(")");
1147             Location::Flat
1148         } else {
1149             wat.push_str(" (param i32)");
1150             Location::Indirect(nparams)
1151         }
1152     }
1153 }
1154 
1155 struct Helper<'a>(&'a Type);
1156 
1157 impl Hash for Helper<'_> {
hash<H: Hasher>(&self, h: &mut H)1158     fn hash<H: Hasher>(&self, h: &mut H) {
1159         std::ptr::hash(self.0, h);
1160     }
1161 }
1162 
1163 impl PartialEq for Helper<'_> {
eq(&self, other: &Self) -> bool1164     fn eq(&self, other: &Self) -> bool {
1165         std::ptr::eq(self.0, other.0)
1166     }
1167 }
1168 
1169 impl Eq for Helper<'_> {}
1170 
make_rust_name(name_counter: &mut u32) -> Ident1171 fn make_rust_name(name_counter: &mut u32) -> Ident {
1172     let name = format_ident!("Foo{name_counter}");
1173     *name_counter += 1;
1174     name
1175 }
1176 
/// Generate a [`TokenStream`] containing the rust type name for a type.
///
/// The `name_counter` parameter is used to generate names for each recursively visited type.  The `declarations`
/// parameter is used to accumulate declarations for each recursively visited type.
///
/// Records, variants, enums and flags become named Rust items. Every occurrence gets a fresh `Foo{N}` name from
/// `make_rust_name`, and the matching item (plus derived or hand-written trait impls) is appended to
/// `declarations`; identical types are not deduplicated. All other types map directly onto Rust types
/// (primitives, `Box<str>`, `Vec`, `HashMap`, tuples, `Option`, `Result`) and declare nothing themselves.
///
/// Because names come from a shared counter, the order of recursion determines which `Foo{N}` each type gets.
pub fn rust_type(ty: &Type, name_counter: &mut u32, declarations: &mut TokenStream) -> TokenStream {
    match ty {
        Type::Bool => quote!(bool),
        Type::S8 => quote!(i8),
        Type::U8 => quote!(u8),
        Type::S16 => quote!(i16),
        Type::U16 => quote!(u16),
        Type::S32 => quote!(i32),
        Type::U32 => quote!(u32),
        Type::S64 => quote!(i64),
        Type::U64 => quote!(u64),
        // NOTE(review): `Float32`/`Float64` are assumed to be wrapper types in scope wherever the generated code
        // is compiled — confirm against the consumer of this output.
        Type::Float32 => quote!(Float32),
        Type::Float64 => quote!(Float64),
        Type::Char => quote!(char),
        Type::String => quote!(Box<str>),
        Type::List(ty) => {
            let ty = rust_type(ty, name_counter, declarations);
            quote!(Vec<#ty>)
        }
        Type::Map(key_ty, value_ty) => {
            // The key is generated before the value, so any named types inside the key get lower numbers.
            let key_ty = rust_type(key_ty, name_counter, declarations);
            let value_ty = rust_type(value_ty, name_counter, declarations);
            quote!(std::collections::HashMap<#key_ty, #value_ty>)
        }
        Type::Record(types) => {
            // Field types are generated first, so nested named types are numbered and declared before this
            // record's own name is allocated below.
            let fields = types
                .iter()
                .enumerate()
                .map(|(index, ty)| {
                    let name = format_ident!("f{index}");
                    let ty = rust_type(ty, name_counter, declarations);
                    quote!(#name: #ty,)
                })
                .collect::<TokenStream>();

            let name = make_rust_name(name_counter);

            declarations.extend(quote! {
                #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)]
                #[component(record)]
                struct #name {
                    #fields
                }
            });

            quote!(#name)
        }
        Type::Tuple(types) => {
            // Tuples map onto Rust tuples; each element is followed by a comma, so a one-element tuple is
            // spelled `(T,)`.
            let fields = types
                .0
                .iter()
                .map(|ty| {
                    let ty = rust_type(ty, name_counter, declarations);
                    quote!(#ty,)
                })
                .collect::<TokenStream>();

            quote!((#fields))
        }
        Type::Variant(types) => {
            // As with records, payload types are generated before this variant's name is allocated.
            // Cases without a payload become unit variants.
            let cases = types
                .0
                .iter()
                .enumerate()
                .map(|(index, ty)| {
                    let name = format_ident!("C{index}");
                    let ty = match ty {
                        Some(ty) => {
                            let ty = rust_type(ty, name_counter, declarations);
                            quote!((#ty))
                        }
                        None => quote!(),
                    };
                    quote!(#name #ty,)
                })
                .collect::<TokenStream>();

            let name = make_rust_name(name_counter);
            declarations.extend(quote! {
                #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)]
                #[component(variant)]
                enum #name {
                    #cases
                }
            });

            quote!(#name)
        }
        Type::Enum(count) => {
            let cases = (0..*count)
                .map(|index| {
                    let name = format_ident!("E{index}");
                    quote!(#name,)
                })
                .collect::<TokenStream>();

            let name = make_rust_name(name_counter);
            // `#[repr]` uses the same integer width as the discriminant chosen by `DiscriminantSize`. This
            // panics if `from_count` rejects `count`.
            let repr = match DiscriminantSize::from_count(*count as usize).unwrap() {
                DiscriminantSize::Size1 => quote!(u8),
                DiscriminantSize::Size2 => quote!(u16),
                DiscriminantSize::Size4 => quote!(u32),
            };

            declarations.extend(quote! {
                #[derive(ComponentType, Lift, Lower, PartialEq, Eq, Hash, Debug, Copy, Clone, Arbitrary)]
                #[component(enum)]
                #[repr(#repr)]
                enum #name {
                    #cases
                }
            });

            quote!(#name)
        }
        Type::Option(ty) => {
            let ty = rust_type(ty, name_counter, declarations);
            quote!(Option<#ty>)
        }
        Type::Result { ok, err } => {
            // A missing `ok` or `err` payload is represented as `()`.
            let ok = match ok {
                Some(ok) => rust_type(ok, name_counter, declarations),
                None => quote!(()),
            };
            let err = match err {
                Some(err) => rust_type(err, name_counter, declarations),
                None => quote!(()),
            };
            quote!(Result<#ok, #err>)
        }
        Type::Flags(count) => {
            // Flags contain no nested types, so the name is allocated up front.
            let type_name = make_rust_name(name_counter);

            // `flags` collects the `const F{i};` entries for the `flags!` macro, and `names` collects the
            // `Type::F{i},` paths used by the hand-written `Arbitrary` impl below.
            let mut flags = TokenStream::new();
            let mut names = TokenStream::new();

            for index in 0..*count {
                let name = format_ident!("F{index}");
                flags.extend(quote!(const #name;));
                names.extend(quote!(#type_name::#name,))
            }

            // The generated `Arbitrary` impl starts from `default()` and ORs in each flag for which an
            // arbitrary `bool` is true.
            declarations.extend(quote! {
                wasmtime::component::flags! {
                    #type_name {
                        #flags
                    }
                }

                impl<'a> arbitrary::Arbitrary<'a> for #type_name {
                    fn arbitrary(input: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
                        let mut flags = #type_name::default();
                        for flag in [#names] {
                            if input.arbitrary()? {
                                flags |= flag;
                            }
                        }
                        Ok(flags)
                    }
                }
            });

            quote!(#type_name)
        }
    }
}
1346 
1347 #[derive(Default)]
1348 struct TypesBuilder<'a> {
1349     next: u32,
1350     worklist: Vec<(u32, &'a Type)>,
1351 }
1352 
1353 impl<'a> TypesBuilder<'a> {
write_ref(&mut self, ty: &'a Type, dst: &mut String)1354     fn write_ref(&mut self, ty: &'a Type, dst: &mut String) {
1355         match ty {
1356             // Primitive types can be referenced directly
1357             Type::Bool => dst.push_str("bool"),
1358             Type::S8 => dst.push_str("s8"),
1359             Type::U8 => dst.push_str("u8"),
1360             Type::S16 => dst.push_str("s16"),
1361             Type::U16 => dst.push_str("u16"),
1362             Type::S32 => dst.push_str("s32"),
1363             Type::U32 => dst.push_str("u32"),
1364             Type::S64 => dst.push_str("s64"),
1365             Type::U64 => dst.push_str("u64"),
1366             Type::Float32 => dst.push_str("float32"),
1367             Type::Float64 => dst.push_str("float64"),
1368             Type::Char => dst.push_str("char"),
1369             Type::String => dst.push_str("string"),
1370 
1371             // Otherwise emit a reference to the type and remember to generate
1372             // the corresponding type alias later.
1373             Type::List(_)
1374             | Type::Map(_, _)
1375             | Type::Record(_)
1376             | Type::Tuple(_)
1377             | Type::Variant(_)
1378             | Type::Enum(_)
1379             | Type::Option(_)
1380             | Type::Result { .. }
1381             | Type::Flags(_) => {
1382                 let idx = self.next;
1383                 self.next += 1;
1384                 uwrite!(dst, "$t{idx}");
1385                 self.worklist.push((idx, ty));
1386             }
1387         }
1388     }
1389 
write_decl(&mut self, idx: u32, ty: &'a Type) -> String1390     fn write_decl(&mut self, idx: u32, ty: &'a Type) -> String {
1391         let mut decl = format!("(type $t{idx}' ");
1392         match ty {
1393             Type::Bool
1394             | Type::S8
1395             | Type::U8
1396             | Type::S16
1397             | Type::U16
1398             | Type::S32
1399             | Type::U32
1400             | Type::S64
1401             | Type::U64
1402             | Type::Float32
1403             | Type::Float64
1404             | Type::Char
1405             | Type::String => unreachable!(),
1406 
1407             Type::List(ty) => {
1408                 decl.push_str("(list ");
1409                 self.write_ref(ty, &mut decl);
1410                 decl.push_str(")");
1411             }
1412             Type::Map(key_ty, value_ty) => {
1413                 decl.push_str("(map ");
1414                 self.write_ref(key_ty, &mut decl);
1415                 decl.push_str(" ");
1416                 self.write_ref(value_ty, &mut decl);
1417                 decl.push_str(")");
1418             }
1419             Type::Record(types) => {
1420                 decl.push_str("(record");
1421                 for (index, ty) in types.iter().enumerate() {
1422                     uwrite!(decl, r#" (field "f{index}" "#);
1423                     self.write_ref(ty, &mut decl);
1424                     decl.push_str(")");
1425                 }
1426                 decl.push_str(")");
1427             }
1428             Type::Tuple(types) => {
1429                 decl.push_str("(tuple");
1430                 for ty in types.iter() {
1431                     decl.push_str(" ");
1432                     self.write_ref(ty, &mut decl);
1433                 }
1434                 decl.push_str(")");
1435             }
1436             Type::Variant(types) => {
1437                 decl.push_str("(variant");
1438                 for (index, ty) in types.iter().enumerate() {
1439                     uwrite!(decl, r#" (case "C{index}""#);
1440                     if let Some(ty) = ty {
1441                         decl.push_str(" ");
1442                         self.write_ref(ty, &mut decl);
1443                     }
1444                     decl.push_str(")");
1445                 }
1446                 decl.push_str(")");
1447             }
1448             Type::Enum(count) => {
1449                 decl.push_str("(enum");
1450                 for index in 0..*count {
1451                     uwrite!(decl, r#" "E{index}""#);
1452                 }
1453                 decl.push_str(")");
1454             }
1455             Type::Option(ty) => {
1456                 decl.push_str("(option ");
1457                 self.write_ref(ty, &mut decl);
1458                 decl.push_str(")");
1459             }
1460             Type::Result { ok, err } => {
1461                 decl.push_str("(result");
1462                 if let Some(ok) = ok {
1463                     decl.push_str(" ");
1464                     self.write_ref(ok, &mut decl);
1465                 }
1466                 if let Some(err) = err {
1467                     decl.push_str(" (error ");
1468                     self.write_ref(err, &mut decl);
1469                     decl.push_str(")");
1470                 }
1471                 decl.push_str(")");
1472             }
1473             Type::Flags(count) => {
1474                 decl.push_str("(flags");
1475                 for index in 0..*count {
1476                     uwrite!(decl, r#" "F{index}""#);
1477                 }
1478                 decl.push_str(")");
1479             }
1480         }
1481         decl.push_str(")\n");
1482         uwriteln!(decl, "(import \"t{idx}\" (type $t{idx} (eq $t{idx}')))");
1483         decl
1484     }
1485 }
1486 
1487 /// Represents custom fragments of a WAT file which may be used to create a component for exercising [`TestCase`]s
1488 #[derive(Debug)]
1489 pub struct Declarations {
1490     /// Type declarations (if any) referenced by `params` and/or `result`
1491     pub types: Cow<'static, str>,
1492     /// Types to thread through when instantiating sub-components.
1493     pub type_instantiation_args: Cow<'static, str>,
1494     /// Parameter declarations used for the imported and exported functions
1495     pub params: Cow<'static, str>,
1496     /// Result declaration used for the imported and exported functions
1497     pub results: Cow<'static, str>,
1498     /// Implementation of the "caller" component, which invokes the `callee`
1499     /// composed component.
1500     pub caller_module: Cow<'static, str>,
1501     /// Implementation of the "callee" component, which invokes the host.
1502     pub callee_module: Cow<'static, str>,
1503     /// Options used for caller/calle ABI/etc.
1504     pub options: TestCaseOptions,
1505 }
1506 
1507 impl Declarations {
1508     /// Generate a complete WAT file based on the specified fragments.
make_component(&self) -> Box<str>1509     pub fn make_component(&self) -> Box<str> {
1510         let Self {
1511             types,
1512             type_instantiation_args,
1513             params,
1514             results,
1515             caller_module,
1516             callee_module,
1517             options,
1518         } = self;
1519         let mk_component = |name: &str,
1520                             module: &str,
1521                             import_async: bool,
1522                             export_async: bool,
1523                             encoding: StringEncoding,
1524                             lift_abi: LiftAbi,
1525                             lower_abi: LowerAbi| {
1526             let import_async = if import_async { "async" } else { "" };
1527             let export_async = if export_async { "async" } else { "" };
1528             let lower_async_option = match lower_abi {
1529                 LowerAbi::Sync => "",
1530                 LowerAbi::Async => "async",
1531             };
1532             let lift_async_option = match lift_abi {
1533                 LiftAbi::Sync => "",
1534                 LiftAbi::AsyncStackful => "async",
1535                 LiftAbi::AsyncCallback => "async (callback (func $i \"callback\"))",
1536             };
1537 
1538             let mut intrinsic_defs = String::new();
1539             let mut intrinsic_imports = String::new();
1540 
1541             match lift_abi {
1542                 LiftAbi::Sync => {}
1543                 LiftAbi::AsyncCallback | LiftAbi::AsyncStackful => {
1544                     intrinsic_defs.push_str(&format!(
1545                         r#"
1546 (core func $task.return (canon task.return {results}
1547     (memory $libc "memory") string-encoding={encoding}))
1548                         "#,
1549                     ));
1550                     intrinsic_imports.push_str(
1551                         r#"
1552 (with "" (instance (export "task.return" (func $task.return))))
1553                         "#,
1554                     );
1555                 }
1556             }
1557 
1558             format!(
1559                 r#"
1560 (component ${name}
1561     {types}
1562     (type $import_sig (func {import_async} {params} {results}))
1563     (type $export_sig (func {export_async} {params} {results}))
1564     (import "{IMPORT_FUNCTION}" (func $f (type $import_sig)))
1565 
1566     (core instance $libc (instantiate $libc))
1567 
1568     (core func $f_lower (canon lower
1569         (func $f)
1570         (memory $libc "memory")
1571         (realloc (func $libc "realloc"))
1572         string-encoding={encoding}
1573         {lower_async_option}
1574     ))
1575 
1576     {intrinsic_defs}
1577 
1578     (core module $m
1579         (memory (import "libc" "memory") 1)
1580         (func $realloc (import "libc" "realloc") (param i32 i32 i32 i32) (result i32))
1581 
1582         {module}
1583     )
1584 
1585     (core instance $i (instantiate $m
1586         (with "libc" (instance $libc))
1587         (with "host" (instance (export "{IMPORT_FUNCTION}" (func $f_lower))))
1588         {intrinsic_imports}
1589     ))
1590 
1591     (func (export "{EXPORT_FUNCTION}") (type $export_sig)
1592         (canon lift
1593             (core func $i "{EXPORT_FUNCTION}")
1594             (memory $libc "memory")
1595             (realloc (func $libc "realloc"))
1596             string-encoding={encoding}
1597             {lift_async_option}
1598         )
1599     )
1600 )
1601             "#
1602             )
1603         };
1604 
1605         let c1 = mk_component(
1606             "callee",
1607             &callee_module,
1608             options.host_async,
1609             options.guest_callee_async,
1610             options.callee_encoding,
1611             options.callee_lift_abi,
1612             options.callee_lower_abi,
1613         );
1614         let c2 = mk_component(
1615             "caller",
1616             &caller_module,
1617             options.guest_callee_async,
1618             options.guest_caller_async,
1619             options.caller_encoding,
1620             options.caller_lift_abi,
1621             options.caller_lower_abi,
1622         );
1623         let host_async = if options.host_async { "async" } else { "" };
1624 
1625         format!(
1626             r#"
1627             (component
1628                 (core module $libc
1629                     (memory (export "memory") 1)
1630                     {REALLOC_AND_FREE}
1631                 )
1632 
1633 
1634                 {types}
1635 
1636                 (type $host_sig (func {host_async} {params} {results}))
1637                 (import "{IMPORT_FUNCTION}" (func $f (type $host_sig)))
1638 
1639                 {c1}
1640                 {c2}
1641                 (instance $c1 (instantiate $callee
1642                     {type_instantiation_args}
1643                     (with "{IMPORT_FUNCTION}" (func $f))
1644                 ))
1645                 (instance $c2 (instantiate $caller
1646                     {type_instantiation_args}
1647                     (with "{IMPORT_FUNCTION}" (func $c1 "{EXPORT_FUNCTION}"))
1648                 ))
1649                 (export "{EXPORT_FUNCTION}" (func $c2 "{EXPORT_FUNCTION}"))
1650             )"#,
1651         )
1652         .into()
1653     }
1654 }
1655 
/// Represents a test case for calling a component function.
///
/// The parameter and result types are borrowed from the type list passed to
/// [`TestCase::generate`], so a `TestCase` cannot outlive that list.
#[derive(Debug)]
pub struct TestCase<'a> {
    /// The types of parameters to pass to the function, in order. They are
    /// declared in the generated component as `p0`, `p1`, ….
    pub params: Vec<&'a Type>,
    /// The result type of the function, or `None` if it returns nothing.
    pub result: Option<&'a Type>,
    /// ABI options to use for this test case.
    pub options: TestCaseOptions,
}
1666 
/// Collection of options which configure how the caller/callee/etc ABIs are
/// all configured.
///
/// When produced by `TestCase::generate`, the async flags are adjusted so that
/// async-ness cascades toward the entrypoint. An async `host_async` forces
/// `guest_callee_async`, which in turn forces `guest_caller_async`. Values
/// built by hand are not checked for this.
#[derive(Debug, Arbitrary, Copy, Clone)]
pub struct TestCaseOptions {
    /// Whether or not the guest caller component (the entrypoint) is using an
    /// `async` function type.
    pub guest_caller_async: bool,
    /// Whether or not the guest callee component (what the entrypoint calls)
    /// is using an `async` function type.
    pub guest_callee_async: bool,
    /// Whether or not the host is using an async function type (what the
    /// guest callee calls).
    pub host_async: bool,
    /// The string encoding of the caller component.
    pub caller_encoding: StringEncoding,
    /// The string encoding of the callee component.
    pub callee_encoding: StringEncoding,
    /// The ABI that the caller component is using to lift its export (the main
    /// entrypoint).
    pub caller_lift_abi: LiftAbi,
    /// The ABI that the callee component is using to lift its export (called
    /// by the caller).
    pub callee_lift_abi: LiftAbi,
    /// The ABI that the caller component is using to lower its import (the
    /// callee's export).
    pub caller_lower_abi: LowerAbi,
    /// The ABI that the callee component is using to lower its import (the
    /// host function).
    pub callee_lower_abi: LowerAbi,
}
1697 
/// Which canonical ABI a generated component uses when lifting (`canon lift`)
/// its exported function.
///
/// Both async variants also cause the generated component to define a
/// `task.return` intrinsic and provide it to the core module.
#[derive(Debug, Arbitrary, Copy, Clone)]
pub enum LiftAbi {
    /// Lift with no `async` canonical option.
    Sync,
    /// Lift with the `async` canonical option and no callback.
    AsyncStackful,
    /// Lift with the `async` canonical option plus a `callback` function taken
    /// from the core instance's `"callback"` export.
    AsyncCallback,
}
1704 
/// Which canonical ABI a generated component uses when lowering
/// (`canon lower`) its imported function into a core function.
#[derive(Debug, Arbitrary, Copy, Clone)]
pub enum LowerAbi {
    /// Lower with no `async` canonical option.
    Sync,
    /// Lower with the `async` canonical option.
    Async,
}
1710 
1711 impl<'a> TestCase<'a> {
generate(types: &'a [Type], u: &mut Unstructured<'_>) -> arbitrary::Result<Self>1712     pub fn generate(types: &'a [Type], u: &mut Unstructured<'_>) -> arbitrary::Result<Self> {
1713         let max_params = if types.len() > 0 { 5 } else { 0 };
1714         let params = (0..u.int_in_range(0..=max_params)?)
1715             .map(|_| u.choose(&types))
1716             .collect::<arbitrary::Result<Vec<_>>>()?;
1717         let result = if types.len() > 0 && u.arbitrary()? {
1718             Some(u.choose(&types)?)
1719         } else {
1720             None
1721         };
1722 
1723         let mut options = u.arbitrary::<TestCaseOptions>()?;
1724 
1725         // Sync tasks cannot call async functions via a sync lower, nor can they
1726         // block in other ways (e.g. by calling `waitable-set.wait`, returning
1727         // `CALLBACK_CODE_WAIT`, etc.) prior to returning.  Therefore,
1728         // async-ness cascades to the callers:
1729         if options.host_async {
1730             options.guest_callee_async = true;
1731         }
1732         if options.guest_callee_async {
1733             options.guest_caller_async = true;
1734         }
1735 
1736         Ok(Self {
1737             params,
1738             result,
1739             options,
1740         })
1741     }
1742 
1743     /// Generate a `Declarations` for this `TestCase` which may be used to build a component to execute the case.
declarations(&self) -> Declarations1744     pub fn declarations(&self) -> Declarations {
1745         let mut builder = TypesBuilder::default();
1746 
1747         let mut params = String::new();
1748         for (i, ty) in self.params.iter().enumerate() {
1749             params.push_str(&format!(" (param \"p{i}\" "));
1750             builder.write_ref(ty, &mut params);
1751             params.push_str(")");
1752         }
1753 
1754         let mut results = String::new();
1755         if let Some(ty) = self.result {
1756             results.push_str(&format!(" (result "));
1757             builder.write_ref(ty, &mut results);
1758             results.push_str(")");
1759         }
1760 
1761         let caller_module = make_import_and_export(
1762             &self.params,
1763             self.result,
1764             self.options.caller_lift_abi,
1765             self.options.caller_lower_abi,
1766         );
1767         let callee_module = make_import_and_export(
1768             &self.params,
1769             self.result,
1770             self.options.callee_lift_abi,
1771             self.options.callee_lower_abi,
1772         );
1773 
1774         let mut type_decls = Vec::new();
1775         let mut type_instantiation_args = String::new();
1776         while let Some((idx, ty)) = builder.worklist.pop() {
1777             type_decls.push(builder.write_decl(idx, ty));
1778             uwriteln!(type_instantiation_args, "(with \"t{idx}\" (type $t{idx}))");
1779         }
1780 
1781         // Note that types are printed here in reverse order since they were
1782         // pushed onto `type_decls` as they were referenced meaning the last one
1783         // is the "base" one.
1784         let mut types = String::new();
1785         for decl in type_decls.into_iter().rev() {
1786             types.push_str(&decl);
1787             types.push_str("\n");
1788         }
1789 
1790         Declarations {
1791             types: types.into(),
1792             type_instantiation_args: type_instantiation_args.into(),
1793             params: params.into(),
1794             results: results.into(),
1795             caller_module: caller_module.into(),
1796             callee_module: callee_module.into(),
1797             options: self.options,
1798         }
1799     }
1800 }
1801 
/// String encoding selected for a component's canonical `lift`/`lower`
/// options.
///
/// Its `Display` impl renders the spelling used in `string-encoding=...`.
#[derive(Copy, Clone, Debug, Arbitrary)]
pub enum StringEncoding {
    /// Rendered as `utf8`.
    Utf8,
    /// Rendered as `utf16`.
    Utf16,
    /// Rendered as `latin1+utf16`.
    Latin1OrUtf16,
}
1808 
1809 impl fmt::Display for StringEncoding {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result1810     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1811         match self {
1812             StringEncoding::Utf8 => fmt::Display::fmt(&"utf8", f),
1813             StringEncoding::Utf16 => fmt::Display::fmt(&"utf16", f),
1814             StringEncoding::Latin1OrUtf16 => fmt::Display::fmt(&"latin1+utf16", f),
1815         }
1816     }
1817 }
1818 
1819 impl ToTokens for TestCaseOptions {
to_tokens(&self, tokens: &mut TokenStream)1820     fn to_tokens(&self, tokens: &mut TokenStream) {
1821         let TestCaseOptions {
1822             guest_caller_async,
1823             guest_callee_async,
1824             host_async,
1825             caller_encoding,
1826             callee_encoding,
1827             caller_lift_abi,
1828             callee_lift_abi,
1829             caller_lower_abi,
1830             callee_lower_abi,
1831         } = self;
1832         tokens.extend(quote!(wasmtime_test_util::component_fuzz::TestCaseOptions {
1833             guest_caller_async: #guest_caller_async,
1834             guest_callee_async: #guest_callee_async,
1835             host_async: #host_async,
1836             caller_encoding: #caller_encoding,
1837             callee_encoding: #callee_encoding,
1838             caller_lift_abi: #caller_lift_abi,
1839             callee_lift_abi: #callee_lift_abi,
1840             caller_lower_abi: #caller_lower_abi,
1841             callee_lower_abi: #callee_lower_abi,
1842         }));
1843     }
1844 }
1845 
1846 impl ToTokens for LowerAbi {
to_tokens(&self, tokens: &mut TokenStream)1847     fn to_tokens(&self, tokens: &mut TokenStream) {
1848         let me = match self {
1849             LowerAbi::Sync => quote!(Sync),
1850             LowerAbi::Async => quote!(Async),
1851         };
1852         tokens.extend(quote!(wasmtime_test_util::component_fuzz::LowerAbi::#me));
1853     }
1854 }
1855 
1856 impl ToTokens for LiftAbi {
to_tokens(&self, tokens: &mut TokenStream)1857     fn to_tokens(&self, tokens: &mut TokenStream) {
1858         let me = match self {
1859             LiftAbi::Sync => quote!(Sync),
1860             LiftAbi::AsyncCallback => quote!(AsyncCallback),
1861             LiftAbi::AsyncStackful => quote!(AsyncStackful),
1862         };
1863         tokens.extend(quote!(wasmtime_test_util::component_fuzz::LiftAbi::#me));
1864     }
1865 }
1866 
1867 impl ToTokens for StringEncoding {
to_tokens(&self, tokens: &mut TokenStream)1868     fn to_tokens(&self, tokens: &mut TokenStream) {
1869         let me = match self {
1870             StringEncoding::Utf8 => quote!(Utf8),
1871             StringEncoding::Utf16 => quote!(Utf16),
1872             StringEncoding::Latin1OrUtf16 => quote!(Latin1OrUtf16),
1873         };
1874         tokens.extend(quote!(wasmtime_test_util::component_fuzz::StringEncoding::#me));
1875     }
1876 }
1877 
#[cfg(test)]
mod tests {
    use super::*;

    /// Property test: every generated `TestCase` must yield a component whose
    /// text parses with `wat` and whose binary validates with all wasm
    /// features enabled.
    #[test]
    fn arbtest() {
        arbtest::arbtest(|u| {
            // `Type::generate` is given `3` and a `fuel` counter shared across
            // all five types; presumably a depth limit and a total size budget
            // (TODO confirm against `Type::generate`).
            let mut fuel = 100;
            let types = (0..5)
                .map(|_| Type::generate(u, 3, &mut fuel))
                .collect::<arbitrary::Result<Vec<_>>>()?;
            let case = TestCase::generate(&types, u)?;
            let decls = case.declarations();
            let component = decls.make_component();
            let wasm = wat::parse_str(&component).unwrap_or_else(|e| {
                panic!("failed to parse generated component as wat: {e}\n\n{component}");
            });
            wasmparser::Validator::new_with_features(wasmparser::WasmFeatures::all())
                .validate_all(&wasm)
                .unwrap_or_else(|e| {
                    // For diagnostics, prefer re-printing the binary with byte
                    // offsets and operand-stack annotations. Fall back to the
                    // generated text if printing itself fails.
                    let mut wat = String::new();
                    let mut dst = wasmprinter::PrintFmtWrite(&mut wat);
                    let to_print = if wasmprinter::Config::new()
                        .print_offsets(true)
                        .print_operand_stack(true)
                        .print(&wasm, &mut dst)
                        .is_ok()
                    {
                        &wat[..]
                    } else {
                        &component[..]
                    };
                    panic!("generated component is not valid wasm: {e}\n\n{to_print}");
                });
            Ok(())
        })
        // Time budget for arbtest, in milliseconds.
        .budget_ms(1_000)
        // To replay a specific failing input, uncomment and substitute the
        // seed reported by the failing run.
        // .seed(0x3c9050d4000000e9)
        ;
    }
}
1919