1 //! This module generates test cases for the Wasmtime component model function APIs, 2 //! e.g. `wasmtime::component::func::Func` and `TypedFunc`. 3 //! 4 //! Each case includes a list of arbitrary interface types to use as parameters, plus another one to use as a 5 //! result, and a component which exports a function and imports a function. The exported function forwards its 6 //! parameters to the imported one and forwards the result back to the caller. This serves to exercise Wasmtime's 7 //! lifting and lowering code and verify the values remain intact during both processes. 8 9 use arbitrary::{Arbitrary, Unstructured}; 10 use indexmap::IndexSet; 11 use proc_macro2::{Ident, TokenStream}; 12 use quote::{ToTokens, format_ident, quote}; 13 use std::borrow::Cow; 14 use std::fmt::{self, Debug, Write}; 15 use std::hash::{Hash, Hasher}; 16 use std::iter; 17 use std::ops::Deref; 18 use wasmtime_component_util::{DiscriminantSize, FlagsSize, REALLOC_AND_FREE}; 19 20 const MAX_FLAT_PARAMS: usize = 16; 21 const MAX_FLAT_ASYNC_PARAMS: usize = 4; 22 const MAX_FLAT_RESULTS: usize = 1; 23 24 /// The name of the imported host function which the generated component will call 25 pub const IMPORT_FUNCTION: &str = "echo-import"; 26 27 /// The name of the exported guest function which the host should call 28 pub const EXPORT_FUNCTION: &str = "echo-export"; 29 30 /// Wasmtime allows up to 100 type depth so limit this to just under that. 31 pub const MAX_TYPE_DEPTH: u32 = 99; 32 33 macro_rules! uwriteln { 34 ($($arg:tt)*) => { 35 writeln!($($arg)*).unwrap() 36 }; 37 } 38 39 macro_rules! uwrite { 40 ($($arg:tt)*) => { 41 write!($($arg)*).unwrap() 42 }; 43 } 44 45 #[derive(Debug, Copy, Clone, PartialEq, Eq)] 46 enum CoreType { 47 I32, 48 I64, 49 F32, 50 F64, 51 } 52 53 impl CoreType { 54 /// This is the `join` operation specified in [the canonical 55 /// ABI](https://github.com/WebAssembly/component-model/blob/main/design/mvp/CanonicalABI.md#flattening) for 56 /// variant types. 57 fn join(self, other: Self) -> Self { 58 match (self, other) { 59 _ if self == other => self, 60 (Self::I32, Self::F32) | (Self::F32, Self::I32) => Self::I32, 61 _ => Self::I64, 62 } 63 } 64 } 65 66 impl fmt::Display for CoreType { 67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 68 match self { 69 Self::I32 => f.write_str("i32"), 70 Self::I64 => f.write_str("i64"), 71 Self::F32 => f.write_str("f32"), 72 Self::F64 => f.write_str("f64"), 73 } 74 } 75 } 76 77 /// Wraps a `Box<[T]>` and provides an `Arbitrary` implementation that always generates slices of length less than 78 /// or equal to the longest tuple for which Wasmtime generates a `ComponentType` impl 79 #[derive(Debug, Clone)] 80 pub struct VecInRange<T, const L: u32, const H: u32>(Vec<T>); 81 82 impl<T, const L: u32, const H: u32> VecInRange<T, L, H> { 83 fn new<'a>( 84 input: &mut Unstructured<'a>, 85 fuel: &mut u32, 86 generate: impl Fn(&mut Unstructured<'a>, &mut u32) -> arbitrary::Result<T>, 87 ) -> arbitrary::Result<Self> { 88 let mut ret = Vec::new(); 89 input.arbitrary_loop(Some(L), Some(H), |input| { 90 if *fuel > 0 { 91 *fuel = *fuel - 1; 92 ret.push(generate(input, fuel)?); 93 Ok(std::ops::ControlFlow::Continue(())) 94 } else { 95 Ok(std::ops::ControlFlow::Break(())) 96 } 97 })?; 98 Ok(Self(ret)) 99 } 100 } 101 102 impl<T, const L: u32, const H: u32> Deref for VecInRange<T, L, H> { 103 type Target = [T]; 104 105 fn deref(&self) -> &[T] { 106 self.0.deref() 107 } 108 } 109 110 /// Represents a component model interface type 111 #[expect(missing_docs, reason = "self-describing")] 112 #[derive(Debug, Clone)] 113 pub enum Type { 114 Bool, 115 S8, 116 U8, 117 S16, 118 U16, 119 S32, 120 U32, 121 S64, 122 U64, 123 Float32, 124 Float64, 125 Char, 126 String, 127 List(Box<Type>), 128 129 // Give records the ability to generate a generous amount of fields but 130 // don't let the fuzzer go too wild since `wasmparser`'s validator currently 131 // has hard limits in the 1000-ish range on the number of fields a record 132 // may contain. 133 Record(VecInRange<Type, 1, 200>), 134 135 // Tuples can only have up to 16 type parameters in wasmtime right now for 136 // the static API, but the standard library only supports `Debug` up to 11 137 // elements, so compromise at an even 10. 138 Tuple(VecInRange<Type, 1, 10>), 139 140 // Like records, allow a good number of variants, but variants require at 141 // least one case. 142 Variant(VecInRange<Option<Type>, 1, 200>), 143 Enum(u32), 144 145 Option(Box<Type>), 146 Result { 147 ok: Option<Box<Type>>, 148 err: Option<Box<Type>>, 149 }, 150 151 Flags(u32), 152 } 153 154 impl Type { 155 pub fn generate( 156 u: &mut Unstructured<'_>, 157 depth: u32, 158 fuel: &mut u32, 159 ) -> arbitrary::Result<Type> { 160 *fuel = fuel.saturating_sub(1); 161 let max = if depth == 0 || *fuel == 0 { 12 } else { 20 }; 162 Ok(match u.int_in_range(0..=max)? { 163 0 => Type::Bool, 164 1 => Type::S8, 165 2 => Type::U8, 166 3 => Type::S16, 167 4 => Type::U16, 168 5 => Type::S32, 169 6 => Type::U32, 170 7 => Type::S64, 171 8 => Type::U64, 172 9 => Type::Float32, 173 10 => Type::Float64, 174 11 => Type::Char, 175 12 => Type::String, 176 // ^-- if you add something here update the `depth == 0` case above 177 13 => Type::List(Box::new(Type::generate(u, depth - 1, fuel)?)), 178 14 => Type::Record(Type::generate_list(u, depth - 1, fuel)?), 179 15 => Type::Tuple(Type::generate_list(u, depth - 1, fuel)?), 180 16 => Type::Variant(VecInRange::new(u, fuel, |u, fuel| { 181 Type::generate_opt(u, depth - 1, fuel) 182 })?), 183 17 => { 184 let amt = u.int_in_range(1..=(*fuel).max(1).min(257))?; 185 *fuel -= amt; 186 Type::Enum(amt) 187 } 188 18 => Type::Option(Box::new(Type::generate(u, depth - 1, fuel)?)), 189 19 => Type::Result { 190 ok: Type::generate_opt(u, depth - 1, fuel)?.map(Box::new), 191 err: Type::generate_opt(u, depth - 1, fuel)?.map(Box::new), 192 }, 193 20 => { 194 let amt = u.int_in_range(1..=(*fuel).min(32))?; 195 *fuel -= amt; 196 Type::Flags(amt) 197 } 198 // ^-- if you add something here update the `depth != 0` case above 199 _ => unreachable!(), 200 }) 201 } 202 203 fn generate_opt( 204 u: &mut Unstructured<'_>, 205 depth: u32, 206 fuel: &mut u32, 207 ) -> arbitrary::Result<Option<Type>> { 208 Ok(if u.arbitrary()? { 209 Some(Type::generate(u, depth, fuel)?) 210 } else { 211 None 212 }) 213 } 214 215 fn generate_list<const L: u32, const H: u32>( 216 u: &mut Unstructured<'_>, 217 depth: u32, 218 fuel: &mut u32, 219 ) -> arbitrary::Result<VecInRange<Type, L, H>> { 220 VecInRange::new(u, fuel, |u, fuel| Type::generate(u, depth, fuel)) 221 } 222 223 /// Generates text format wasm into `s` to store a value of this type, in 224 /// its flat representation stored in the `locals` provided, to the local 225 /// named `ptr` at the `offset` provided. 226 /// 227 /// This will register helper functions necessary in `helpers`. The 228 /// `locals` iterator will be advanced for all locals consumed by this 229 /// store operation. 230 fn store_flat<'a>( 231 &'a self, 232 s: &mut String, 233 ptr: &str, 234 offset: u32, 235 locals: &mut dyn Iterator<Item = FlatSource>, 236 helpers: &mut IndexSet<Helper<'a>>, 237 ) { 238 enum Kind { 239 Primitive(&'static str), 240 PointerPair, 241 Helper, 242 } 243 let kind = match self { 244 Type::Bool | Type::S8 | Type::U8 => Kind::Primitive("i32.store8"), 245 Type::S16 | Type::U16 => Kind::Primitive("i32.store16"), 246 Type::S32 | Type::U32 | Type::Char => Kind::Primitive("i32.store"), 247 Type::S64 | Type::U64 => Kind::Primitive("i64.store"), 248 Type::Float32 => Kind::Primitive("f32.store"), 249 Type::Float64 => Kind::Primitive("f64.store"), 250 Type::String | Type::List(_) => Kind::PointerPair, 251 Type::Enum(n) if *n <= (1 << 8) => Kind::Primitive("i32.store8"), 252 Type::Enum(n) if *n <= (1 << 16) => Kind::Primitive("i32.store16"), 253 Type::Enum(_) => Kind::Primitive("i32.store"), 254 Type::Flags(n) if *n <= 8 => Kind::Primitive("i32.store8"), 255 Type::Flags(n) if *n <= 16 => Kind::Primitive("i32.store16"), 256 Type::Flags(n) if *n <= 32 => Kind::Primitive("i32.store"), 257 Type::Flags(_) => unreachable!(), 258 Type::Record(_) 259 | Type::Tuple(_) 260 | Type::Variant(_) 261 | Type::Option(_) 262 | Type::Result { .. } => Kind::Helper, 263 }; 264 265 match kind { 266 Kind::Primitive(op) => uwriteln!( 267 s, 268 "({op} offset={offset} (local.get {ptr}) {})", 269 locals.next().unwrap() 270 ), 271 Kind::PointerPair => { 272 let abi_ptr = locals.next().unwrap(); 273 let abi_len = locals.next().unwrap(); 274 uwriteln!(s, "(i32.store offset={offset} (local.get {ptr}) {abi_ptr})",); 275 let offset = offset + 4; 276 uwriteln!(s, "(i32.store offset={offset} (local.get {ptr}) {abi_len})",); 277 } 278 Kind::Helper => { 279 let (index, _) = helpers.insert_full(Helper(self)); 280 uwriteln!(s, "(i32.add (local.get {ptr}) (i32.const {offset}))"); 281 for _ in 0..self.lowered().len() { 282 let i = locals.next().unwrap(); 283 uwriteln!(s, "{i}"); 284 } 285 uwriteln!(s, "call $store_helper_{index}"); 286 } 287 } 288 } 289 290 /// Generates a text-format wasm function which takes a pointer and this 291 /// type's flat representation as arguments and then stores this value in 292 /// the first argument. 293 /// 294 /// This is used to store records/variants to cut down on the size of final 295 /// functions and make codegen here a bit easier. 296 fn store_flat_helper<'a>( 297 &'a self, 298 s: &mut String, 299 i: usize, 300 helpers: &mut IndexSet<Helper<'a>>, 301 ) { 302 uwrite!(s, "(func $store_helper_{i} (param i32)"); 303 let lowered = self.lowered(); 304 for ty in &lowered { 305 uwrite!(s, " (param {ty})"); 306 } 307 s.push_str("\n"); 308 let locals = (0..lowered.len() as u32).map(|i| i + 1).collect::<Vec<_>>(); 309 let record = |s: &mut String, helpers: &mut IndexSet<Helper<'a>>, types: &'a [Type]| { 310 let mut locals = locals.iter().cloned().map(FlatSource::Local); 311 for (offset, ty) in record_field_offsets(types) { 312 ty.store_flat(s, "0", offset, &mut locals, helpers); 313 } 314 assert!(locals.next().is_none()); 315 }; 316 let variant = |s: &mut String, 317 helpers: &mut IndexSet<Helper<'a>>, 318 types: &[Option<&'a Type>]| { 319 let (size, offset) = variant_memory_info(types.iter().cloned()); 320 // One extra block for out-of-bounds discriminants. 321 for _ in 0..types.len() + 1 { 322 s.push_str("block\n"); 323 } 324 325 // Store the discriminant in memory, then branch on it to figure 326 // out which case we're in. 327 let store = match size { 328 DiscriminantSize::Size1 => "i32.store8", 329 DiscriminantSize::Size2 => "i32.store16", 330 DiscriminantSize::Size4 => "i32.store", 331 }; 332 uwriteln!(s, "({store} (local.get 0) (local.get 1))"); 333 s.push_str("local.get 1\n"); 334 s.push_str("br_table"); 335 for i in 0..types.len() + 1 { 336 uwrite!(s, " {i}"); 337 } 338 s.push_str("\nend\n"); 339 340 // Store each payload individually while converting locals from 341 // their source types to the precise type necessary for this 342 // variant. 343 for ty in types { 344 if let Some(ty) = ty { 345 let ty_lowered = ty.lowered(); 346 let mut locals = locals[1..].iter().zip(&lowered[1..]).zip(&ty_lowered).map( 347 |((i, from), to)| FlatSource::LocalConvert { 348 local: *i, 349 from: *from, 350 to: *to, 351 }, 352 ); 353 ty.store_flat(s, "0", offset, &mut locals, helpers); 354 } 355 s.push_str("return\n"); 356 s.push_str("end\n"); 357 } 358 359 // Catch-all result which is for out-of-bounds discriminants. 360 s.push_str("unreachable\n"); 361 }; 362 match self { 363 Type::Bool 364 | Type::S8 365 | Type::U8 366 | Type::S16 367 | Type::U16 368 | Type::S32 369 | Type::U32 370 | Type::Char 371 | Type::S64 372 | Type::U64 373 | Type::Float32 374 | Type::Float64 375 | Type::String 376 | Type::List(_) 377 | Type::Flags(_) 378 | Type::Enum(_) => unreachable!(), 379 380 Type::Record(r) => record(s, helpers, r), 381 Type::Tuple(t) => record(s, helpers, t), 382 Type::Variant(v) => variant( 383 s, 384 helpers, 385 &v.iter().map(|t| t.as_ref()).collect::<Vec<_>>(), 386 ), 387 Type::Option(o) => variant(s, helpers, &[None, Some(&**o)]), 388 Type::Result { ok, err } => variant(s, helpers, &[ok.as_deref(), err.as_deref()]), 389 }; 390 s.push_str(")\n"); 391 } 392 393 /// Same as `store_flat`, except loads the flat values from `ptr+offset`. 394 /// 395 /// Results are placed directly on the wasm stack. 396 fn load_flat<'a>( 397 &'a self, 398 s: &mut String, 399 ptr: &str, 400 offset: u32, 401 helpers: &mut IndexSet<Helper<'a>>, 402 ) { 403 enum Kind { 404 Primitive(&'static str), 405 PointerPair, 406 Helper, 407 } 408 let kind = match self { 409 Type::Bool | Type::U8 => Kind::Primitive("i32.load8_u"), 410 Type::S8 => Kind::Primitive("i32.load8_s"), 411 Type::U16 => Kind::Primitive("i32.load16_u"), 412 Type::S16 => Kind::Primitive("i32.load16_s"), 413 Type::U32 | Type::S32 | Type::Char => Kind::Primitive("i32.load"), 414 Type::U64 | Type::S64 => Kind::Primitive("i64.load"), 415 Type::Float32 => Kind::Primitive("f32.load"), 416 Type::Float64 => Kind::Primitive("f64.load"), 417 Type::String | Type::List(_) => Kind::PointerPair, 418 Type::Enum(n) if *n <= (1 << 8) => Kind::Primitive("i32.load8_u"), 419 Type::Enum(n) if *n <= (1 << 16) => Kind::Primitive("i32.load16_u"), 420 Type::Enum(_) => Kind::Primitive("i32.load"), 421 Type::Flags(n) if *n <= 8 => Kind::Primitive("i32.load8_u"), 422 Type::Flags(n) if *n <= 16 => Kind::Primitive("i32.load16_u"), 423 Type::Flags(n) if *n <= 32 => Kind::Primitive("i32.load"), 424 Type::Flags(_) => unreachable!(), 425 426 Type::Record(_) 427 | Type::Tuple(_) 428 | Type::Variant(_) 429 | Type::Option(_) 430 | Type::Result { .. } => Kind::Helper, 431 }; 432 match kind { 433 Kind::Primitive(op) => uwriteln!(s, "({op} offset={offset} (local.get {ptr}))"), 434 Kind::PointerPair => { 435 uwriteln!(s, "(i32.load offset={offset} (local.get {ptr}))",); 436 let offset = offset + 4; 437 uwriteln!(s, "(i32.load offset={offset} (local.get {ptr}))",); 438 } 439 Kind::Helper => { 440 let (index, _) = helpers.insert_full(Helper(self)); 441 uwriteln!(s, "(i32.add (local.get {ptr}) (i32.const {offset}))"); 442 uwriteln!(s, "call $load_helper_{index}"); 443 } 444 } 445 } 446 447 /// Same as `store_flat_helper` but for loading the flat representation. 448 fn load_flat_helper<'a>( 449 &'a self, 450 s: &mut String, 451 i: usize, 452 helpers: &mut IndexSet<Helper<'a>>, 453 ) { 454 uwrite!(s, "(func $load_helper_{i} (param i32)"); 455 let lowered = self.lowered(); 456 for ty in &lowered { 457 uwrite!(s, " (result {ty})"); 458 } 459 s.push_str("\n"); 460 let record = |s: &mut String, helpers: &mut IndexSet<Helper<'a>>, types: &'a [Type]| { 461 for (offset, ty) in record_field_offsets(types) { 462 ty.load_flat(s, "0", offset, helpers); 463 } 464 }; 465 let variant = |s: &mut String, 466 helpers: &mut IndexSet<Helper<'a>>, 467 types: &[Option<&'a Type>]| { 468 let (size, offset) = variant_memory_info(types.iter().cloned()); 469 470 // Destination locals where the flat representation will be stored. 471 // These are automatically zero which handles unused fields too. 472 for (i, ty) in lowered.iter().enumerate() { 473 uwriteln!(s, " (local $r{i} {ty})"); 474 } 475 476 // Return block each case jumps to after setting all locals. 477 s.push_str("block $r\n"); 478 479 // One extra block for "out of bounds discriminant". 480 for _ in 0..types.len() + 1 { 481 s.push_str("block\n"); 482 } 483 484 // Load the discriminant and branch on it, storing it in 485 // `$r0` as well which is the first flat local representation. 486 let load = match size { 487 DiscriminantSize::Size1 => "i32.load8_u", 488 DiscriminantSize::Size2 => "i32.load16", 489 DiscriminantSize::Size4 => "i32.load", 490 }; 491 uwriteln!(s, "({load} (local.get 0))"); 492 s.push_str("local.tee $r0\n"); 493 s.push_str("br_table"); 494 for i in 0..types.len() + 1 { 495 uwrite!(s, " {i}"); 496 } 497 s.push_str("\nend\n"); 498 499 // For each payload, which is in its own block, load payloads from 500 // memory as necessary and convert them into the final locals. 501 for ty in types { 502 if let Some(ty) = ty { 503 let ty_lowered = ty.lowered(); 504 ty.load_flat(s, "0", offset, helpers); 505 for (i, (from, to)) in ty_lowered.iter().zip(&lowered[1..]).enumerate().rev() { 506 let i = i + 1; 507 match (from, to) { 508 (CoreType::F32, CoreType::I32) => { 509 s.push_str("i32.reinterpret_f32\n"); 510 } 511 (CoreType::I32, CoreType::I64) => { 512 s.push_str("i64.extend_i32_u\n"); 513 } 514 (CoreType::F32, CoreType::I64) => { 515 s.push_str("i32.reinterpret_f32\n"); 516 s.push_str("i64.extend_i32_u\n"); 517 } 518 (CoreType::F64, CoreType::I64) => { 519 s.push_str("i64.reinterpret_f64\n"); 520 } 521 (a, b) if a == b => {} 522 _ => unimplemented!("convert {from:?} to {to:?}"), 523 } 524 uwriteln!(s, "local.set $r{i}"); 525 } 526 } 527 s.push_str("br $r\n"); 528 s.push_str("end\n"); 529 } 530 531 // The catch-all block for out-of-bounds discriminants. 532 s.push_str("unreachable\n"); 533 s.push_str("end\n"); 534 for i in 0..lowered.len() { 535 uwriteln!(s, " local.get $r{i}"); 536 } 537 }; 538 539 match self { 540 Type::Bool 541 | Type::S8 542 | Type::U8 543 | Type::S16 544 | Type::U16 545 | Type::S32 546 | Type::U32 547 | Type::Char 548 | Type::S64 549 | Type::U64 550 | Type::Float32 551 | Type::Float64 552 | Type::String 553 | Type::List(_) 554 | Type::Flags(_) 555 | Type::Enum(_) => unreachable!(), 556 557 Type::Record(r) => record(s, helpers, r), 558 Type::Tuple(t) => record(s, helpers, t), 559 Type::Variant(v) => variant( 560 s, 561 helpers, 562 &v.iter().map(|t| t.as_ref()).collect::<Vec<_>>(), 563 ), 564 Type::Option(o) => variant(s, helpers, &[None, Some(&**o)]), 565 Type::Result { ok, err } => variant(s, helpers, &[ok.as_deref(), err.as_deref()]), 566 }; 567 s.push_str(")\n"); 568 } 569 } 570 571 #[derive(Clone)] 572 enum FlatSource { 573 Local(u32), 574 LocalConvert { 575 local: u32, 576 from: CoreType, 577 to: CoreType, 578 }, 579 } 580 581 impl fmt::Display for FlatSource { 582 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 583 match self { 584 FlatSource::Local(i) => write!(f, "(local.get {i})"), 585 FlatSource::LocalConvert { local, from, to } => { 586 match (from, to) { 587 (a, b) if a == b => write!(f, "(local.get {local})"), 588 (CoreType::I32, CoreType::F32) => { 589 write!(f, "(f32.reinterpret_i32 (local.get {local}))") 590 } 591 (CoreType::I64, CoreType::I32) => { 592 write!(f, "(i32.wrap_i64 (local.get {local}))") 593 } 594 (CoreType::I64, CoreType::F64) => { 595 write!(f, "(f64.reinterpret_i64 (local.get {local}))") 596 } 597 (CoreType::I64, CoreType::F32) => { 598 write!( 599 f, 600 "(f32.reinterpret_i32 (i32.wrap_i64 (local.get {local})))" 601 ) 602 } 603 _ => unimplemented!("convert {from:?} to {to:?}"), 604 } 605 // .. 606 } 607 } 608 } 609 } 610 611 fn lower_record<'a>(types: impl Iterator<Item = &'a Type>, vec: &mut Vec<CoreType>) { 612 for ty in types { 613 ty.lower(vec); 614 } 615 } 616 617 fn lower_variant<'a>(types: impl Iterator<Item = Option<&'a Type>>, vec: &mut Vec<CoreType>) { 618 vec.push(CoreType::I32); 619 let offset = vec.len(); 620 for ty in types { 621 let ty = match ty { 622 Some(ty) => ty, 623 None => continue, 624 }; 625 for (index, ty) in ty.lowered().iter().enumerate() { 626 let index = offset + index; 627 if index < vec.len() { 628 vec[index] = vec[index].join(*ty); 629 } else { 630 vec.push(*ty) 631 } 632 } 633 } 634 } 635 636 fn u32_count_from_flag_count(count: usize) -> usize { 637 match FlagsSize::from_count(count) { 638 FlagsSize::Size0 => 0, 639 FlagsSize::Size1 | FlagsSize::Size2 => 1, 640 FlagsSize::Size4Plus(n) => n.into(), 641 } 642 } 643 644 struct SizeAndAlignment { 645 size: usize, 646 alignment: u32, 647 } 648 649 impl Type { 650 fn lowered(&self) -> Vec<CoreType> { 651 let mut vec = Vec::new(); 652 self.lower(&mut vec); 653 vec 654 } 655 656 fn lower(&self, vec: &mut Vec<CoreType>) { 657 match self { 658 Type::Bool 659 | Type::U8 660 | Type::S8 661 | Type::S16 662 | Type::U16 663 | Type::S32 664 | Type::U32 665 | Type::Char 666 | Type::Enum(_) => vec.push(CoreType::I32), 667 Type::S64 | Type::U64 => vec.push(CoreType::I64), 668 Type::Float32 => vec.push(CoreType::F32), 669 Type::Float64 => vec.push(CoreType::F64), 670 Type::String | Type::List(_) => { 671 vec.push(CoreType::I32); 672 vec.push(CoreType::I32); 673 } 674 Type::Record(types) => lower_record(types.iter(), vec), 675 Type::Tuple(types) => lower_record(types.0.iter(), vec), 676 Type::Variant(types) => lower_variant(types.0.iter().map(|t| t.as_ref()), vec), 677 Type::Option(ty) => lower_variant([None, Some(&**ty)].into_iter(), vec), 678 Type::Result { ok, err } => { 679 lower_variant([ok.as_deref(), err.as_deref()].into_iter(), vec) 680 } 681 Type::Flags(count) => vec.extend( 682 iter::repeat(CoreType::I32).take(u32_count_from_flag_count(*count as usize)), 683 ), 684 } 685 } 686 687 fn size_and_alignment(&self) -> SizeAndAlignment { 688 match self { 689 Type::Bool | Type::S8 | Type::U8 => SizeAndAlignment { 690 size: 1, 691 alignment: 1, 692 }, 693 694 Type::S16 | Type::U16 => SizeAndAlignment { 695 size: 2, 696 alignment: 2, 697 }, 698 699 Type::S32 | Type::U32 | Type::Char | Type::Float32 => SizeAndAlignment { 700 size: 4, 701 alignment: 4, 702 }, 703 704 Type::S64 | Type::U64 | Type::Float64 => SizeAndAlignment { 705 size: 8, 706 alignment: 8, 707 }, 708 709 Type::String | Type::List(_) => SizeAndAlignment { 710 size: 8, 711 alignment: 4, 712 }, 713 714 Type::Record(types) => record_size_and_alignment(types.iter()), 715 716 Type::Tuple(types) => record_size_and_alignment(types.0.iter()), 717 718 Type::Variant(types) => variant_size_and_alignment(types.0.iter().map(|t| t.as_ref())), 719 720 Type::Enum(count) => variant_size_and_alignment((0..*count).map(|_| None)), 721 722 Type::Option(ty) => variant_size_and_alignment([None, Some(&**ty)].into_iter()), 723 724 Type::Result { ok, err } => { 725 variant_size_and_alignment([ok.as_deref(), err.as_deref()].into_iter()) 726 } 727 728 Type::Flags(count) => match FlagsSize::from_count(*count as usize) { 729 FlagsSize::Size0 => SizeAndAlignment { 730 size: 0, 731 alignment: 1, 732 }, 733 FlagsSize::Size1 => SizeAndAlignment { 734 size: 1, 735 alignment: 1, 736 }, 737 FlagsSize::Size2 => SizeAndAlignment { 738 size: 2, 739 alignment: 2, 740 }, 741 FlagsSize::Size4Plus(n) => SizeAndAlignment { 742 size: usize::from(n) * 4, 743 alignment: 4, 744 }, 745 }, 746 } 747 } 748 } 749 750 fn align_to(a: usize, align: u32) -> usize { 751 let align = align as usize; 752 (a + (align - 1)) & !(align - 1) 753 } 754 755 fn record_field_offsets<'a>( 756 types: impl IntoIterator<Item = &'a Type>, 757 ) -> impl Iterator<Item = (u32, &'a Type)> { 758 let mut offset = 0; 759 types.into_iter().map(move |ty| { 760 let SizeAndAlignment { size, alignment } = ty.size_and_alignment(); 761 let ret = align_to(offset, alignment); 762 offset = ret + size; 763 (ret as u32, ty) 764 }) 765 } 766 767 fn record_size_and_alignment<'a>(types: impl IntoIterator<Item = &'a Type>) -> SizeAndAlignment { 768 let mut offset = 0; 769 let mut align = 1; 770 for ty in types { 771 let SizeAndAlignment { size, alignment } = ty.size_and_alignment(); 772 offset = align_to(offset, alignment) + size; 773 align = align.max(alignment); 774 } 775 776 SizeAndAlignment { 777 size: align_to(offset, align), 778 alignment: align, 779 } 780 } 781 782 fn variant_size_and_alignment<'a>( 783 types: impl ExactSizeIterator<Item = Option<&'a Type>>, 784 ) -> SizeAndAlignment { 785 let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap(); 786 let mut alignment = u32::from(discriminant_size); 787 let mut size = 0; 788 for ty in types { 789 if let Some(ty) = ty { 790 let size_and_alignment = ty.size_and_alignment(); 791 alignment = alignment.max(size_and_alignment.alignment); 792 size = size.max(size_and_alignment.size); 793 } 794 } 795 796 SizeAndAlignment { 797 size: align_to( 798 align_to(usize::from(discriminant_size), alignment) + size, 799 alignment, 800 ), 801 alignment, 802 } 803 } 804 805 fn variant_memory_info<'a>( 806 types: impl ExactSizeIterator<Item = Option<&'a Type>>, 807 ) -> (DiscriminantSize, u32) { 808 let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap(); 809 let mut alignment = u32::from(discriminant_size); 810 for ty in types { 811 if let Some(ty) = ty { 812 let size_and_alignment = ty.size_and_alignment(); 813 alignment = alignment.max(size_and_alignment.alignment); 814 } 815 } 816 817 ( 818 discriminant_size, 819 align_to(usize::from(discriminant_size), alignment) as u32, 820 ) 821 } 822 823 /// Generates the internals of a core wasm module which imports a single 824 /// component function `IMPORT_FUNCTION` and exports a single component 825 /// function `EXPORT_FUNCTION`. 826 /// 827 /// The component function takes `params` as arguments and optionally returns 828 /// `result`. The `lift_abi` and `lower_abi` fields indicate the ABI in-use for 829 /// this operation. 830 fn make_import_and_export( 831 params: &[&Type], 832 result: Option<&Type>, 833 lift_abi: LiftAbi, 834 lower_abi: LowerAbi, 835 ) -> String { 836 let params_lowered = params 837 .iter() 838 .flat_map(|ty| ty.lowered()) 839 .collect::<Box<[_]>>(); 840 let result_lowered = result.map(|t| t.lowered()).unwrap_or(Vec::new()); 841 842 let mut wat = String::new(); 843 844 enum Location { 845 Flat, 846 Indirect(u32), 847 } 848 849 // Generate the core wasm type corresponding to the imported function being 850 // lowered with `lower_abi`. 851 wat.push_str(&format!("(type $import (func")); 852 let max_import_params = match lower_abi { 853 LowerAbi::Sync => MAX_FLAT_PARAMS, 854 LowerAbi::Async => MAX_FLAT_ASYNC_PARAMS, 855 }; 856 let (import_params_loc, nparams) = push_params(&mut wat, ¶ms_lowered, max_import_params); 857 let import_results_loc = match lower_abi { 858 LowerAbi::Sync => { 859 push_result_or_retptr(&mut wat, &result_lowered, nparams, MAX_FLAT_RESULTS) 860 } 861 LowerAbi::Async => { 862 let loc = if result.is_none() { 863 Location::Flat 864 } else { 865 wat.push_str(" (param i32)"); // result pointer 866 Location::Indirect(nparams) 867 }; 868 wat.push_str(" (result i32)"); // status code 869 loc 870 } 871 }; 872 wat.push_str("))\n"); 873 874 // Generate the import function. 875 wat.push_str(&format!( 876 r#"(import "host" "{IMPORT_FUNCTION}" (func $host (type $import)))"# 877 )); 878 879 // Do the same as above for the exported function's type which is lifted 880 // with `lift_abi`. 881 // 882 // Note that `export_results_loc` being `None` means that `task.return` is 883 // used to communicate results. 884 wat.push_str(&format!("(type $export (func")); 885 let (export_params_loc, _nparams) = push_params(&mut wat, ¶ms_lowered, MAX_FLAT_PARAMS); 886 let export_results_loc = match lift_abi { 887 LiftAbi::Sync => Some(push_group(&mut wat, "result", &result_lowered, MAX_FLAT_RESULTS).0), 888 LiftAbi::AsyncCallback => { 889 wat.push_str(" (result i32)"); // status code 890 None 891 } 892 LiftAbi::AsyncStackful => None, 893 }; 894 wat.push_str("))\n"); 895 896 // If the export is async, generate `task.return` as an import as well 897 // which is necesary to communicate the results. 898 if export_results_loc.is_none() { 899 wat.push_str(&format!("(type $task.return (func")); 900 push_params(&mut wat, &result_lowered, MAX_FLAT_PARAMS); 901 wat.push_str("))\n"); 902 wat.push_str(&format!( 903 r#"(import "" "task.return" (func $task.return (type $task.return)))"# 904 )); 905 } 906 907 wat.push_str(&format!( 908 r#" 909 (func (export "{EXPORT_FUNCTION}") (type $export) 910 (local $retptr i32) 911 (local $argptr i32) 912 "# 913 )); 914 let mut store_helpers = IndexSet::new(); 915 let mut load_helpers = IndexSet::new(); 916 917 match (export_params_loc, import_params_loc) { 918 // flat => flat is just moving locals around 919 (Location::Flat, Location::Flat) => { 920 for (index, _) in params_lowered.iter().enumerate() { 921 uwrite!(wat, "local.get {index}\n"); 922 } 923 } 924 925 // indirect => indirect is just moving locals around 926 (Location::Indirect(i), Location::Indirect(j)) => { 927 assert_eq!(j, 0); 928 uwrite!(wat, "local.get {i}\n"); 929 } 930 931 // flat => indirect means that all parameters are stored in memory as 932 // if it was a record of all the parameters. 933 (Location::Flat, Location::Indirect(_)) => { 934 let SizeAndAlignment { size, alignment } = 935 record_size_and_alignment(params.iter().cloned()); 936 wat.push_str(&format!( 937 r#" 938 (local.set $argptr 939 (call $realloc 940 (i32.const 0) 941 (i32.const 0) 942 (i32.const {alignment}) 943 (i32.const {size}))) 944 local.get $argptr 945 "# 946 )); 947 let mut locals = (0..params_lowered.len() as u32).map(FlatSource::Local); 948 for (offset, ty) in record_field_offsets(params.iter().cloned()) { 949 ty.store_flat(&mut wat, "$argptr", offset, &mut locals, &mut store_helpers); 950 } 951 assert!(locals.next().is_none()); 952 } 953 954 (Location::Indirect(_), Location::Flat) => unreachable!(), 955 } 956 957 // Pass a return-pointer if necessary. 958 match import_results_loc { 959 Location::Flat => {} 960 Location::Indirect(_) => { 961 let SizeAndAlignment { size, alignment } = result.unwrap().size_and_alignment(); 962 963 wat.push_str(&format!( 964 r#" 965 (local.set $retptr 966 (call $realloc 967 (i32.const 0) 968 (i32.const 0) 969 (i32.const {alignment}) 970 (i32.const {size}))) 971 local.get $retptr 972 "# 973 )); 974 } 975 } 976 977 wat.push_str("call $host\n"); 978 979 // Assert the lowered call is ready if an async code was returned. 980 // 981 // TODO: handle when the import isn't ready yet 982 if let LowerAbi::Async = lower_abi { 983 wat.push_str("i32.const 2\n"); 984 wat.push_str("i32.ne\n"); 985 wat.push_str("if unreachable end\n"); 986 } 987 988 // TODO: conditionally inject a yield here 989 990 match (import_results_loc, export_results_loc) { 991 // flat => flat results involves nothing, the results are already on 992 // the stack. 993 (Location::Flat, Some(Location::Flat)) => {} 994 995 // indirect => indirect results requires returning the `$retptr` the 996 // host call filled in. 997 (Location::Indirect(_), Some(Location::Indirect(_))) => { 998 wat.push_str("local.get $retptr\n"); 999 } 1000 1001 // indirect => flat requires loading the result from the return pointer 1002 (Location::Indirect(_), Some(Location::Flat)) => { 1003 result 1004 .unwrap() 1005 .load_flat(&mut wat, "$retptr", 0, &mut load_helpers); 1006 } 1007 1008 // flat => task.return is easy, the results are already there so just 1009 // call the function. 1010 (Location::Flat, None) => { 1011 wat.push_str("call $task.return\n"); 1012 } 1013 1014 // indirect => task.return needs to forward `$retptr` if the results 1015 // are indirect, or otherwise it must be loaded from memory to a flat 1016 // representation. 1017 (Location::Indirect(_), None) => { 1018 if result_lowered.len() <= MAX_FLAT_PARAMS { 1019 result 1020 .unwrap() 1021 .load_flat(&mut wat, "$retptr", 0, &mut load_helpers); 1022 } else { 1023 wat.push_str("local.get $retptr\n"); 1024 } 1025 wat.push_str("call $task.return\n"); 1026 } 1027 1028 (Location::Flat, Some(Location::Indirect(_))) => unreachable!(), 1029 } 1030 1031 if let LiftAbi::AsyncCallback = lift_abi { 1032 wat.push_str("i32.const 0\n"); // completed status code 1033 } 1034 1035 wat.push_str(")\n"); 1036 1037 // Generate a `callback` function for the callback ABI. 1038 // 1039 // TODO: fill this in 1040 if let LiftAbi::AsyncCallback = lift_abi { 1041 wat.push_str( 1042 r#" 1043 (func (export "callback") (param i32 i32 i32) (result i32) unreachable) 1044 "#, 1045 ); 1046 } 1047 1048 // Fill out all store/load helpers that were needed during generation 1049 // above. This is a fix-point-loop since each helper may end up requiring 1050 // more helpers. 1051 let mut i = 0; 1052 while i < store_helpers.len() { 1053 let ty = store_helpers[i].0; 1054 ty.store_flat_helper(&mut wat, i, &mut store_helpers); 1055 i += 1; 1056 } 1057 i = 0; 1058 while i < load_helpers.len() { 1059 let ty = load_helpers[i].0; 1060 ty.load_flat_helper(&mut wat, i, &mut load_helpers); 1061 i += 1; 1062 } 1063 1064 return wat; 1065 1066 fn push_params(wat: &mut String, params: &[CoreType], max_flat: usize) -> (Location, u32) { 1067 push_group(wat, "param", params, max_flat) 1068 } 1069 1070 fn push_group( 1071 wat: &mut String, 1072 name: &str, 1073 params: &[CoreType], 1074 max_flat: usize, 1075 ) -> (Location, u32) { 1076 let mut nparams = 0; 1077 let loc = if params.is_empty() { 1078 // nothing to emit... 1079 Location::Flat 1080 } else if params.len() <= max_flat { 1081 wat.push_str(&format!(" ({name}")); 1082 for ty in params { 1083 wat.push_str(&format!(" {ty}")); 1084 nparams += 1; 1085 } 1086 wat.push_str(")"); 1087 Location::Flat 1088 } else { 1089 wat.push_str(&format!(" ({name} i32)")); 1090 nparams += 1; 1091 Location::Indirect(0) 1092 }; 1093 (loc, nparams) 1094 } 1095 1096 fn push_result_or_retptr( 1097 wat: &mut String, 1098 results: &[CoreType], 1099 nparams: u32, 1100 max_flat: usize, 1101 ) -> Location { 1102 if results.is_empty() { 1103 // nothing to emit... 1104 Location::Flat 1105 } else if results.len() <= max_flat { 1106 wat.push_str(" (result"); 1107 for ty in results { 1108 wat.push_str(&format!(" {ty}")); 1109 } 1110 wat.push_str(")"); 1111 Location::Flat 1112 } else { 1113 wat.push_str(" (param i32)"); 1114 Location::Indirect(nparams) 1115 } 1116 } 1117 } 1118 1119 struct Helper<'a>(&'a Type); 1120 1121 impl Hash for Helper<'_> { 1122 fn hash<H: Hasher>(&self, h: &mut H) { 1123 std::ptr::hash(self.0, h); 1124 } 1125 } 1126 1127 impl PartialEq for Helper<'_> { 1128 fn eq(&self, other: &Self) -> bool { 1129 std::ptr::eq(self.0, other.0) 1130 } 1131 } 1132 1133 impl Eq for Helper<'_> {} 1134 1135 fn make_rust_name(name_counter: &mut u32) -> Ident { 1136 let name = format_ident!("Foo{name_counter}"); 1137 *name_counter += 1; 1138 name 1139 } 1140 1141 /// Generate a [`TokenStream`] containing the rust type name for a type. 1142 /// 1143 /// The `name_counter` parameter is used to generate names for each recursively visited type. The `declarations` 1144 /// parameter is used to accumulate declarations for each recursively visited type. 1145 pub fn rust_type(ty: &Type, name_counter: &mut u32, declarations: &mut TokenStream) -> TokenStream { 1146 match ty { 1147 Type::Bool => quote!(bool), 1148 Type::S8 => quote!(i8), 1149 Type::U8 => quote!(u8), 1150 Type::S16 => quote!(i16), 1151 Type::U16 => quote!(u16), 1152 Type::S32 => quote!(i32), 1153 Type::U32 => quote!(u32), 1154 Type::S64 => quote!(i64), 1155 Type::U64 => quote!(u64), 1156 Type::Float32 => quote!(Float32), 1157 Type::Float64 => quote!(Float64), 1158 Type::Char => quote!(char), 1159 Type::String => quote!(Box<str>), 1160 Type::List(ty) => { 1161 let ty = rust_type(ty, name_counter, declarations); 1162 quote!(Vec<#ty>) 1163 } 1164 Type::Record(types) => { 1165 let fields = types 1166 .iter() 1167 .enumerate() 1168 .map(|(index, ty)| { 1169 let name = format_ident!("f{index}"); 1170 let ty = rust_type(ty, name_counter, declarations); 1171 quote!(#name: #ty,) 1172 }) 1173 .collect::<TokenStream>(); 1174 1175 let name = make_rust_name(name_counter); 1176 1177 declarations.extend(quote! { 1178 #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)] 1179 #[component(record)] 1180 struct #name { 1181 #fields 1182 } 1183 }); 1184 1185 quote!(#name) 1186 } 1187 Type::Tuple(types) => { 1188 let fields = types 1189 .0 1190 .iter() 1191 .map(|ty| { 1192 let ty = rust_type(ty, name_counter, declarations); 1193 quote!(#ty,) 1194 }) 1195 .collect::<TokenStream>(); 1196 1197 quote!((#fields)) 1198 } 1199 Type::Variant(types) => { 1200 let cases = types 1201 .0 1202 .iter() 1203 .enumerate() 1204 .map(|(index, ty)| { 1205 let name = format_ident!("C{index}"); 1206 let ty = match ty { 1207 Some(ty) => { 1208 let ty = rust_type(ty, name_counter, declarations); 1209 quote!((#ty)) 1210 } 1211 None => quote!(), 1212 }; 1213 quote!(#name #ty,) 1214 }) 1215 .collect::<TokenStream>(); 1216 1217 let name = make_rust_name(name_counter); 1218 declarations.extend(quote! { 1219 #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)] 1220 #[component(variant)] 1221 enum #name { 1222 #cases 1223 } 1224 }); 1225 1226 quote!(#name) 1227 } 1228 Type::Enum(count) => { 1229 let cases = (0..*count) 1230 .map(|index| { 1231 let name = format_ident!("E{index}"); 1232 quote!(#name,) 1233 }) 1234 .collect::<TokenStream>(); 1235 1236 let name = make_rust_name(name_counter); 1237 let repr = match DiscriminantSize::from_count(*count as usize).unwrap() { 1238 DiscriminantSize::Size1 => quote!(u8), 1239 DiscriminantSize::Size2 => quote!(u16), 1240 DiscriminantSize::Size4 => quote!(u32), 1241 }; 1242 1243 declarations.extend(quote! { 1244 #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Copy, Clone, Arbitrary)] 1245 #[component(enum)] 1246 #[repr(#repr)] 1247 enum #name { 1248 #cases 1249 } 1250 }); 1251 1252 quote!(#name) 1253 } 1254 Type::Option(ty) => { 1255 let ty = rust_type(ty, name_counter, declarations); 1256 quote!(Option<#ty>) 1257 } 1258 Type::Result { ok, err } => { 1259 let ok = match ok { 1260 Some(ok) => rust_type(ok, name_counter, declarations), 1261 None => quote!(()), 1262 }; 1263 let err = match err { 1264 Some(err) => rust_type(err, name_counter, declarations), 1265 None => quote!(()), 1266 }; 1267 quote!(Result<#ok, #err>) 1268 } 1269 Type::Flags(count) => { 1270 let type_name = make_rust_name(name_counter); 1271 1272 let mut flags = TokenStream::new(); 1273 let mut names = TokenStream::new(); 1274 1275 for index in 0..*count { 1276 let name = format_ident!("F{index}"); 1277 flags.extend(quote!(const #name;)); 1278 names.extend(quote!(#type_name::#name,)) 1279 } 1280 1281 declarations.extend(quote! { 1282 wasmtime::component::flags! { 1283 #type_name { 1284 #flags 1285 } 1286 } 1287 1288 impl<'a> arbitrary::Arbitrary<'a> for #type_name { 1289 fn arbitrary(input: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> { 1290 let mut flags = #type_name::default(); 1291 for flag in [#names] { 1292 if input.arbitrary()? { 1293 flags |= flag; 1294 } 1295 } 1296 Ok(flags) 1297 } 1298 } 1299 }); 1300 1301 quote!(#type_name) 1302 } 1303 } 1304 } 1305 1306 #[derive(Default)] 1307 struct TypesBuilder<'a> { 1308 next: u32, 1309 worklist: Vec<(u32, &'a Type)>, 1310 } 1311 1312 impl<'a> TypesBuilder<'a> { 1313 fn write_ref(&mut self, ty: &'a Type, dst: &mut String) { 1314 match ty { 1315 // Primitive types can be referenced directly 1316 Type::Bool => dst.push_str("bool"), 1317 Type::S8 => dst.push_str("s8"), 1318 Type::U8 => dst.push_str("u8"), 1319 Type::S16 => dst.push_str("s16"), 1320 Type::U16 => dst.push_str("u16"), 1321 Type::S32 => dst.push_str("s32"), 1322 Type::U32 => dst.push_str("u32"), 1323 Type::S64 => dst.push_str("s64"), 1324 Type::U64 => dst.push_str("u64"), 1325 Type::Float32 => dst.push_str("float32"), 1326 Type::Float64 => dst.push_str("float64"), 1327 Type::Char => dst.push_str("char"), 1328 Type::String => dst.push_str("string"), 1329 1330 // Otherwise emit a reference to the type and remember to generate 1331 // the corresponding type alias later. 1332 Type::List(_) 1333 | Type::Record(_) 1334 | Type::Tuple(_) 1335 | Type::Variant(_) 1336 | Type::Enum(_) 1337 | Type::Option(_) 1338 | Type::Result { .. } 1339 | Type::Flags(_) => { 1340 let idx = self.next; 1341 self.next += 1; 1342 uwrite!(dst, "$t{idx}"); 1343 self.worklist.push((idx, ty)); 1344 } 1345 } 1346 } 1347 1348 fn write_decl(&mut self, idx: u32, ty: &'a Type) -> String { 1349 let mut decl = format!("(type $t{idx}' "); 1350 match ty { 1351 Type::Bool 1352 | Type::S8 1353 | Type::U8 1354 | Type::S16 1355 | Type::U16 1356 | Type::S32 1357 | Type::U32 1358 | Type::S64 1359 | Type::U64 1360 | Type::Float32 1361 | Type::Float64 1362 | Type::Char 1363 | Type::String => unreachable!(), 1364 1365 Type::List(ty) => { 1366 decl.push_str("(list "); 1367 self.write_ref(ty, &mut decl); 1368 decl.push_str(")"); 1369 } 1370 Type::Record(types) => { 1371 decl.push_str("(record"); 1372 for (index, ty) in types.iter().enumerate() { 1373 uwrite!(decl, r#" (field "f{index}" "#); 1374 self.write_ref(ty, &mut decl); 1375 decl.push_str(")"); 1376 } 1377 decl.push_str(")"); 1378 } 1379 Type::Tuple(types) => { 1380 decl.push_str("(tuple"); 1381 for ty in types.iter() { 1382 decl.push_str(" "); 1383 self.write_ref(ty, &mut decl); 1384 } 1385 decl.push_str(")"); 1386 } 1387 Type::Variant(types) => { 1388 decl.push_str("(variant"); 1389 for (index, ty) in types.iter().enumerate() { 1390 uwrite!(decl, r#" (case "C{index}""#); 1391 if let Some(ty) = ty { 1392 decl.push_str(" "); 1393 self.write_ref(ty, &mut decl); 1394 } 1395 decl.push_str(")"); 1396 } 1397 decl.push_str(")"); 1398 } 1399 Type::Enum(count) => { 1400 decl.push_str("(enum"); 1401 for index in 0..*count { 1402 uwrite!(decl, r#" "E{index}""#); 1403 } 1404 decl.push_str(")"); 1405 } 1406 Type::Option(ty) => { 1407 decl.push_str("(option "); 1408 self.write_ref(ty, &mut decl); 1409 decl.push_str(")"); 1410 } 1411 Type::Result { ok, err } => { 1412 decl.push_str("(result"); 1413 if let Some(ok) = ok { 1414 decl.push_str(" "); 1415 self.write_ref(ok, &mut decl); 1416 } 1417 if let Some(err) = err { 1418 decl.push_str(" (error "); 1419 self.write_ref(err, &mut decl); 1420 decl.push_str(")"); 1421 } 1422 decl.push_str(")"); 1423 } 1424 Type::Flags(count) => { 1425 decl.push_str("(flags"); 1426 for index in 0..*count { 1427 uwrite!(decl, r#" "F{index}""#); 1428 } 1429 decl.push_str(")"); 1430 } 1431 } 1432 decl.push_str(")\n"); 1433 uwriteln!(decl, "(import \"t{idx}\" (type $t{idx} (eq $t{idx}')))"); 1434 decl 1435 } 1436 } 1437 1438 /// Represents custom fragments of a WAT file which may be used to create a component for exercising [`TestCase`]s 1439 #[derive(Debug)] 1440 pub struct Declarations { 1441 /// Type declarations (if any) referenced by `params` and/or `result` 1442 pub types: Cow<'static, str>, 1443 /// Types to thread through when instantiating sub-components. 1444 pub type_instantiation_args: Cow<'static, str>, 1445 /// Parameter declarations used for the imported and exported functions 1446 pub params: Cow<'static, str>, 1447 /// Result declaration used for the imported and exported functions 1448 pub results: Cow<'static, str>, 1449 /// Implementation of the "caller" component, which invokes the `callee` 1450 /// composed component. 1451 pub caller_module: Cow<'static, str>, 1452 /// Implementation of the "callee" component, which invokes the host. 1453 pub callee_module: Cow<'static, str>, 1454 /// Options used for caller/calle ABI/etc. 1455 pub options: TestCaseOptions, 1456 } 1457 1458 impl Declarations { 1459 /// Generate a complete WAT file based on the specified fragments. 1460 pub fn make_component(&self) -> Box<str> { 1461 let Self { 1462 types, 1463 type_instantiation_args, 1464 params, 1465 results, 1466 caller_module, 1467 callee_module, 1468 options, 1469 } = self; 1470 let mk_component = |name: &str, 1471 module: &str, 1472 import_async: bool, 1473 export_async: bool, 1474 encoding: StringEncoding, 1475 lift_abi: LiftAbi, 1476 lower_abi: LowerAbi| { 1477 let import_async = if import_async { "async" } else { "" }; 1478 let export_async = if export_async { "async" } else { "" }; 1479 let lower_async_option = match lower_abi { 1480 LowerAbi::Sync => "", 1481 LowerAbi::Async => "async", 1482 }; 1483 let lift_async_option = match lift_abi { 1484 LiftAbi::Sync => "", 1485 LiftAbi::AsyncStackful => "async", 1486 LiftAbi::AsyncCallback => "async (callback (func $i \"callback\"))", 1487 }; 1488 1489 let mut intrinsic_defs = String::new(); 1490 let mut intrinsic_imports = String::new(); 1491 1492 match lift_abi { 1493 LiftAbi::Sync => {} 1494 LiftAbi::AsyncCallback | LiftAbi::AsyncStackful => { 1495 intrinsic_defs.push_str(&format!( 1496 r#" 1497 (core func $task.return (canon task.return {results} 1498 (memory $libc "memory") string-encoding={encoding})) 1499 "#, 1500 )); 1501 intrinsic_imports.push_str( 1502 r#" 1503 (with "" (instance (export "task.return" (func $task.return)))) 1504 "#, 1505 ); 1506 } 1507 } 1508 1509 format!( 1510 r#" 1511 (component ${name} 1512 {types} 1513 (type $import_sig (func {import_async} {params} {results})) 1514 (type $export_sig (func {export_async} {params} {results})) 1515 (import "{IMPORT_FUNCTION}" (func $f (type $import_sig))) 1516 1517 (core instance $libc (instantiate $libc)) 1518 1519 (core func $f_lower (canon lower 1520 (func $f) 1521 (memory $libc "memory") 1522 (realloc (func $libc "realloc")) 1523 string-encoding={encoding} 1524 {lower_async_option} 1525 )) 1526 1527 {intrinsic_defs} 1528 1529 (core module $m 1530 (memory (import "libc" "memory") 1) 1531 (func $realloc (import "libc" "realloc") (param i32 i32 i32 i32) (result i32)) 1532 1533 {module} 1534 ) 1535 1536 (core instance $i (instantiate $m 1537 (with "libc" (instance $libc)) 1538 (with "host" (instance (export "{IMPORT_FUNCTION}" (func $f_lower)))) 1539 {intrinsic_imports} 1540 )) 1541 1542 (func (export "{EXPORT_FUNCTION}") (type $export_sig) 1543 (canon lift 1544 (core func $i "{EXPORT_FUNCTION}") 1545 (memory $libc "memory") 1546 (realloc (func $libc "realloc")) 1547 string-encoding={encoding} 1548 {lift_async_option} 1549 ) 1550 ) 1551 ) 1552 "# 1553 ) 1554 }; 1555 1556 let c1 = mk_component( 1557 "callee", 1558 &callee_module, 1559 options.host_async, 1560 options.guest_callee_async, 1561 options.callee_encoding, 1562 options.callee_lift_abi, 1563 options.callee_lower_abi, 1564 ); 1565 let c2 = mk_component( 1566 "caller", 1567 &caller_module, 1568 options.guest_callee_async, 1569 options.guest_caller_async, 1570 options.caller_encoding, 1571 options.caller_lift_abi, 1572 options.caller_lower_abi, 1573 ); 1574 let host_async = if options.host_async { "async" } else { "" }; 1575 1576 format!( 1577 r#" 1578 (component 1579 (core module $libc 1580 (memory (export "memory") 1) 1581 {REALLOC_AND_FREE} 1582 ) 1583 1584 1585 {types} 1586 1587 (type $host_sig (func {host_async} {params} {results})) 1588 (import "{IMPORT_FUNCTION}" (func $f (type $host_sig))) 1589 1590 {c1} 1591 {c2} 1592 (instance $c1 (instantiate $callee 1593 {type_instantiation_args} 1594 (with "{IMPORT_FUNCTION}" (func $f)) 1595 )) 1596 (instance $c2 (instantiate $caller 1597 {type_instantiation_args} 1598 (with "{IMPORT_FUNCTION}" (func $c1 "{EXPORT_FUNCTION}")) 1599 )) 1600 (export "{EXPORT_FUNCTION}" (func $c2 "{EXPORT_FUNCTION}")) 1601 )"#, 1602 ) 1603 .into() 1604 } 1605 } 1606 1607 /// Represents a test case for calling a component function 1608 #[derive(Debug)] 1609 pub struct TestCase<'a> { 1610 /// The types of parameters to pass to the function 1611 pub params: Vec<&'a Type>, 1612 /// The result types of the function 1613 pub result: Option<&'a Type>, 1614 /// ABI options to use for this test case. 1615 pub options: TestCaseOptions, 1616 } 1617 1618 /// Collection of options which configure how the caller/callee/etc ABIs are 1619 /// all configured. 1620 #[derive(Debug, Arbitrary, Copy, Clone)] 1621 pub struct TestCaseOptions { 1622 /// Whether or not the guest caller component (the entrypoint) is using an 1623 /// `async` function type. 1624 pub guest_caller_async: bool, 1625 /// Whether or not the guest callee component (what the entrypoint calls) 1626 /// is using an `async` function type. 1627 pub guest_callee_async: bool, 1628 /// Whether or not the host is using an async function type (what the 1629 /// guest callee calls). 1630 pub host_async: bool, 1631 /// The string encoding of the caller component. 1632 pub caller_encoding: StringEncoding, 1633 /// The string encoding of the callee component. 1634 pub callee_encoding: StringEncoding, 1635 /// The ABI that the caller component is using to lift its export (the main 1636 /// entrypoint). 1637 pub caller_lift_abi: LiftAbi, 1638 /// The ABI that the callee component is using to lift its export (called 1639 /// by the caller). 1640 pub callee_lift_abi: LiftAbi, 1641 /// The ABI that the caller component is using to lower its import (the 1642 /// callee's export). 1643 pub caller_lower_abi: LowerAbi, 1644 /// The ABI that the callee component is using to lower its import (the 1645 /// host function). 1646 pub callee_lower_abi: LowerAbi, 1647 } 1648 1649 #[derive(Debug, Arbitrary, Copy, Clone)] 1650 pub enum LiftAbi { 1651 Sync, 1652 AsyncStackful, 1653 AsyncCallback, 1654 } 1655 1656 #[derive(Debug, Arbitrary, Copy, Clone)] 1657 pub enum LowerAbi { 1658 Sync, 1659 Async, 1660 } 1661 1662 impl<'a> TestCase<'a> { 1663 pub fn generate(types: &'a [Type], u: &mut Unstructured<'_>) -> arbitrary::Result<Self> { 1664 let max_params = if types.len() > 0 { 5 } else { 0 }; 1665 let params = (0..u.int_in_range(0..=max_params)?) 1666 .map(|_| u.choose(&types)) 1667 .collect::<arbitrary::Result<Vec<_>>>()?; 1668 let result = if types.len() > 0 && u.arbitrary()? { 1669 Some(u.choose(&types)?) 1670 } else { 1671 None 1672 }; 1673 1674 let mut options = u.arbitrary::<TestCaseOptions>()?; 1675 1676 // Sync tasks cannot call async functions via a sync lower, nor can they 1677 // block in other ways (e.g. by calling `waitable-set.wait`, returning 1678 // `CALLBACK_CODE_WAIT`, etc.) prior to returning. Therefore, 1679 // async-ness cascades to the callers: 1680 if options.host_async { 1681 options.guest_callee_async = true; 1682 } 1683 if options.guest_callee_async { 1684 options.guest_caller_async = true; 1685 } 1686 1687 Ok(Self { 1688 params, 1689 result, 1690 options, 1691 }) 1692 } 1693 1694 /// Generate a `Declarations` for this `TestCase` which may be used to build a component to execute the case. 1695 pub fn declarations(&self) -> Declarations { 1696 let mut builder = TypesBuilder::default(); 1697 1698 let mut params = String::new(); 1699 for (i, ty) in self.params.iter().enumerate() { 1700 params.push_str(&format!(" (param \"p{i}\" ")); 1701 builder.write_ref(ty, &mut params); 1702 params.push_str(")"); 1703 } 1704 1705 let mut results = String::new(); 1706 if let Some(ty) = self.result { 1707 results.push_str(&format!(" (result ")); 1708 builder.write_ref(ty, &mut results); 1709 results.push_str(")"); 1710 } 1711 1712 let caller_module = make_import_and_export( 1713 &self.params, 1714 self.result, 1715 self.options.caller_lift_abi, 1716 self.options.caller_lower_abi, 1717 ); 1718 let callee_module = make_import_and_export( 1719 &self.params, 1720 self.result, 1721 self.options.callee_lift_abi, 1722 self.options.callee_lower_abi, 1723 ); 1724 1725 let mut type_decls = Vec::new(); 1726 let mut type_instantiation_args = String::new(); 1727 while let Some((idx, ty)) = builder.worklist.pop() { 1728 type_decls.push(builder.write_decl(idx, ty)); 1729 uwriteln!(type_instantiation_args, "(with \"t{idx}\" (type $t{idx}))"); 1730 } 1731 1732 // Note that types are printed here in reverse order since they were 1733 // pushed onto `type_decls` as they were referenced meaning the last one 1734 // is the "base" one. 1735 let mut types = String::new(); 1736 for decl in type_decls.into_iter().rev() { 1737 types.push_str(&decl); 1738 types.push_str("\n"); 1739 } 1740 1741 Declarations { 1742 types: types.into(), 1743 type_instantiation_args: type_instantiation_args.into(), 1744 params: params.into(), 1745 results: results.into(), 1746 caller_module: caller_module.into(), 1747 callee_module: callee_module.into(), 1748 options: self.options, 1749 } 1750 } 1751 } 1752 1753 #[derive(Copy, Clone, Debug, Arbitrary)] 1754 pub enum StringEncoding { 1755 Utf8, 1756 Utf16, 1757 Latin1OrUtf16, 1758 } 1759 1760 impl fmt::Display for StringEncoding { 1761 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 1762 match self { 1763 StringEncoding::Utf8 => fmt::Display::fmt(&"utf8", f), 1764 StringEncoding::Utf16 => fmt::Display::fmt(&"utf16", f), 1765 StringEncoding::Latin1OrUtf16 => fmt::Display::fmt(&"latin1+utf16", f), 1766 } 1767 } 1768 } 1769 1770 impl ToTokens for TestCaseOptions { 1771 fn to_tokens(&self, tokens: &mut TokenStream) { 1772 let TestCaseOptions { 1773 guest_caller_async, 1774 guest_callee_async, 1775 host_async, 1776 caller_encoding, 1777 callee_encoding, 1778 caller_lift_abi, 1779 callee_lift_abi, 1780 caller_lower_abi, 1781 callee_lower_abi, 1782 } = self; 1783 tokens.extend(quote!(wasmtime_test_util::component_fuzz::TestCaseOptions { 1784 guest_caller_async: #guest_caller_async, 1785 guest_callee_async: #guest_callee_async, 1786 host_async: #host_async, 1787 caller_encoding: #caller_encoding, 1788 callee_encoding: #callee_encoding, 1789 caller_lift_abi: #caller_lift_abi, 1790 callee_lift_abi: #callee_lift_abi, 1791 caller_lower_abi: #caller_lower_abi, 1792 callee_lower_abi: #callee_lower_abi, 1793 })); 1794 } 1795 } 1796 1797 impl ToTokens for LowerAbi { 1798 fn to_tokens(&self, tokens: &mut TokenStream) { 1799 let me = match self { 1800 LowerAbi::Sync => quote!(Sync), 1801 LowerAbi::Async => quote!(Async), 1802 }; 1803 tokens.extend(quote!(wasmtime_test_util::component_fuzz::LowerAbi::#me)); 1804 } 1805 } 1806 1807 impl ToTokens for LiftAbi { 1808 fn to_tokens(&self, tokens: &mut TokenStream) { 1809 let me = match self { 1810 LiftAbi::Sync => quote!(Sync), 1811 LiftAbi::AsyncCallback => quote!(AsyncCallback), 1812 LiftAbi::AsyncStackful => quote!(AsyncStackful), 1813 }; 1814 tokens.extend(quote!(wasmtime_test_util::component_fuzz::LiftAbi::#me)); 1815 } 1816 } 1817 1818 impl ToTokens for StringEncoding { 1819 fn to_tokens(&self, tokens: &mut TokenStream) { 1820 let me = match self { 1821 StringEncoding::Utf8 => quote!(Utf8), 1822 StringEncoding::Utf16 => quote!(Utf16), 1823 StringEncoding::Latin1OrUtf16 => quote!(Latin1OrUtf16), 1824 }; 1825 tokens.extend(quote!(wasmtime_test_util::component_fuzz::StringEncoding::#me)); 1826 } 1827 } 1828 1829 #[cfg(test)] 1830 mod tests { 1831 use super::*; 1832 1833 #[test] 1834 fn arbtest() { 1835 arbtest::arbtest(|u| { 1836 let mut fuel = 100; 1837 let types = (0..5) 1838 .map(|_| Type::generate(u, 3, &mut fuel)) 1839 .collect::<arbitrary::Result<Vec<_>>>()?; 1840 let case = TestCase::generate(&types, u)?; 1841 let decls = case.declarations(); 1842 let component = decls.make_component(); 1843 let wasm = wat::parse_str(&component).unwrap_or_else(|e| { 1844 panic!("failed to parse generated component as wat: {e}\n\n{component}"); 1845 }); 1846 wasmparser::Validator::new_with_features(wasmparser::WasmFeatures::all()) 1847 .validate_all(&wasm) 1848 .unwrap_or_else(|e| { 1849 let mut wat = String::new(); 1850 let mut dst = wasmprinter::PrintFmtWrite(&mut wat); 1851 let to_print = if wasmprinter::Config::new() 1852 .print_offsets(true) 1853 .print_operand_stack(true) 1854 .print(&wasm, &mut dst) 1855 .is_ok() 1856 { 1857 &wat[..] 1858 } else { 1859 &component[..] 1860 }; 1861 panic!("generated component is not valid wasm: {e}\n\n{to_print}"); 1862 }); 1863 Ok(()) 1864 }) 1865 .budget_ms(1_000) 1866 // .seed(0x3c9050d4000000e9) 1867 ; 1868 } 1869 } 1870