1 use crate::{Ownership, types::TypeInfo};
2 use heck::*;
3 use wit_parser::*;
4 
/// How a WIT type is spelled as a Rust type in generated code.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum TypeMode {
    /// Spell an owning type, e.g. wasmtime's re-exported `String` / `Vec<T>`.
    Owned,
    /// Spell a borrowed type, e.g. `&'a str` / `&'a [T]`, using the given
    /// lifetime (written with its leading `'`). The anonymous lifetime `'_`
    /// is elided, so `AllBorrowed("'_")` renders as `&str` / `&[T]`.
    AllBorrowed(&'static str),
}
10 
11 pub trait RustGenerator<'a> {
12     fn resolve(&self) -> &'a Resolve;
13 
14     fn push_str(&mut self, s: &str);
15     fn info(&self, ty: TypeId) -> TypeInfo;
16     fn path_to_interface(&self, interface: InterfaceId) -> Option<String>;
17     fn is_imported_interface(&self, interface: InterfaceId) -> bool;
18     fn wasmtime_path(&self) -> String;
19 
20     /// This determines whether we generate owning types or (where appropriate)
21     /// borrowing types.
22     ///
23     /// For example, when generating a type which is only used as a parameter to
24     /// a guest-exported function, there is no need for it to own its fields.
25     /// However, constructing deeply-nested borrows (e.g. `&[&[&[&str]]]]` for
26     /// `list<list<list<string>>>`) can be very awkward, so by default we
27     /// generate owning types and use only shallow borrowing at the top level
28     /// inside function signatures.
29     fn ownership(&self) -> Ownership;
30 
31     fn print_ty(&mut self, ty: &Type, mode: TypeMode) {
32         self.push_str(&self.ty(ty, mode))
33     }
34     fn ty(&self, ty: &Type, mode: TypeMode) -> String {
35         match ty {
36             Type::Id(t) => self.tyid(*t, mode),
37             Type::Bool => "bool".to_string(),
38             Type::U8 => "u8".to_string(),
39             Type::U16 => "u16".to_string(),
40             Type::U32 => "u32".to_string(),
41             Type::U64 => "u64".to_string(),
42             Type::S8 => "i8".to_string(),
43             Type::S16 => "i16".to_string(),
44             Type::S32 => "i32".to_string(),
45             Type::S64 => "i64".to_string(),
46             Type::F32 => "f32".to_string(),
47             Type::F64 => "f64".to_string(),
48             Type::Char => "char".to_string(),
49             Type::String => match mode {
50                 TypeMode::AllBorrowed(lt) => {
51                     if lt != "'_" {
52                         format!("&{lt} str")
53                     } else {
54                         format!("&str")
55                     }
56                 }
57                 TypeMode::Owned => {
58                     let wt = self.wasmtime_path();
59                     format!("{wt}::component::__internal::String")
60                 }
61             },
62             Type::ErrorContext => {
63                 let wt = self.wasmtime_path();
64                 format!("{wt}::component::ErrorContext")
65             }
66         }
67     }
68 
69     fn print_optional_ty(&mut self, ty: Option<&Type>, mode: TypeMode) {
70         self.push_str(&self.optional_ty(ty, mode))
71     }
72     fn optional_ty(&self, ty: Option<&Type>, mode: TypeMode) -> String {
73         match ty {
74             Some(ty) => self.ty(ty, mode),
75             None => "()".to_string(),
76         }
77     }
78 
79     fn tyid(&self, id: TypeId, mode: TypeMode) -> String {
80         let info = self.info(id);
81         let lt = self.lifetime_for(&info, mode);
82         let ty = &self.resolve().types[id];
83         if ty.name.is_some() {
84             // If this type has a list internally, no lifetime is being printed,
85             // but we're in a borrowed mode, then that means we're in a borrowed
86             // context and don't want ownership of the type but we're using an
87             // owned type definition. Inject a `&` in front to indicate that, at
88             // the API level, ownership isn't required.
89             let mut out = String::new();
90             if info.has_list && lt.is_none() {
91                 if let TypeMode::AllBorrowed(lt) = mode {
92                     if lt != "'_" {
93                         out.push_str(&format!("&{lt} "))
94                     } else {
95                         out.push_str("&")
96                     }
97                 }
98             }
99             let name = if lt.is_some() {
100                 self.param_name(id)
101             } else {
102                 self.result_name(id)
103             };
104             out.push_str(&self.type_name_in_interface(ty.owner, &name));
105 
106             // If the type recursively owns data and it's a
107             // variant/record/list, then we need to place the
108             // lifetime parameter on the type as well.
109             if info.has_list && needs_generics(self.resolve(), &ty.kind) {
110                 out.push_str(&self.generics(lt));
111             }
112 
113             return out;
114 
115             fn needs_generics(resolve: &Resolve, ty: &TypeDefKind) -> bool {
116                 match ty {
117                     TypeDefKind::Variant(_)
118                     | TypeDefKind::Record(_)
119                     | TypeDefKind::Option(_)
120                     | TypeDefKind::Result(_)
121                     | TypeDefKind::Future(_)
122                     | TypeDefKind::Stream(_)
123                     | TypeDefKind::List(_)
124                     | TypeDefKind::Flags(_)
125                     | TypeDefKind::Enum(_)
126                     | TypeDefKind::Tuple(_)
127                     | TypeDefKind::Handle(_)
128                     | TypeDefKind::Resource => true,
129                     TypeDefKind::Type(Type::Id(t)) => {
130                         needs_generics(resolve, &resolve.types[*t].kind)
131                     }
132                     TypeDefKind::Type(Type::String) => true,
133                     TypeDefKind::Type(_) => false,
134                     TypeDefKind::Unknown => unreachable!(),
135                     TypeDefKind::FixedLengthList(..) => todo!(),
136                     TypeDefKind::Map(..) => todo!(),
137                 }
138             }
139         }
140 
141         match &ty.kind {
142             TypeDefKind::List(t) => self.list(t, mode),
143 
144             TypeDefKind::Option(t) => {
145                 format!("Option<{}>", self.ty(t, mode))
146             }
147 
148             TypeDefKind::Result(r) => {
149                 let ok = self.optional_ty(r.ok.as_ref(), mode);
150                 let err = self.optional_ty(r.err.as_ref(), mode);
151                 format!("Result<{ok},{err}>")
152             }
153 
154             TypeDefKind::Variant(_) => panic!("unsupported anonymous variant"),
155 
156             // Tuple-like records are mapped directly to Rust tuples of
157             // types. Note the trailing comma after each member to
158             // appropriately handle 1-tuples.
159             TypeDefKind::Tuple(t) => {
160                 let mut out = "(".to_string();
161                 for ty in t.types.iter() {
162                     out.push_str(&self.ty(ty, mode));
163                     out.push_str(",");
164                 }
165                 out.push_str(")");
166                 out
167             }
168             TypeDefKind::Record(_) => {
169                 panic!("unsupported anonymous type reference: record")
170             }
171             TypeDefKind::Flags(_) => {
172                 panic!("unsupported anonymous type reference: flags")
173             }
174             TypeDefKind::Enum(_) => {
175                 panic!("unsupported anonymous type reference: enum")
176             }
177             TypeDefKind::Future(ty) => {
178                 let wt = self.wasmtime_path();
179                 let t = self.optional_ty(ty.as_ref(), TypeMode::Owned);
180                 format!("{wt}::component::FutureReader<{t}>")
181             }
182             TypeDefKind::Stream(ty) => {
183                 let wt = self.wasmtime_path();
184                 let t = self.optional_ty(ty.as_ref(), TypeMode::Owned);
185                 format!("{wt}::component::StreamReader<{t}>")
186             }
187             TypeDefKind::Handle(handle) => self.handle(handle),
188             TypeDefKind::Resource => unreachable!(),
189 
190             TypeDefKind::Type(t) => self.ty(t, mode),
191             TypeDefKind::Unknown => unreachable!(),
192             TypeDefKind::FixedLengthList(..) => todo!(),
193             TypeDefKind::Map(..) => todo!(),
194         }
195     }
196 
197     fn type_name_in_interface(&self, owner: TypeOwner, name: &str) -> String {
198         let mut out = String::new();
199         if let TypeOwner::Interface(id) = owner {
200             if let Some(path) = self.path_to_interface(id) {
201                 out.push_str(&path);
202                 out.push_str("::");
203             }
204         }
205         out.push_str(name);
206         out
207     }
208 
209     fn print_list(&mut self, ty: &Type, mode: TypeMode) {
210         self.push_str(&self.list(ty, mode))
211     }
212     fn list(&self, ty: &Type, mode: TypeMode) -> String {
213         let next_mode = if matches!(self.ownership(), Ownership::Owning) {
214             TypeMode::Owned
215         } else {
216             mode
217         };
218         let ty = self.ty(ty, next_mode);
219         match mode {
220             TypeMode::AllBorrowed(lt) => {
221                 if lt != "'_" {
222                     format!("&{lt} [{ty}]")
223                 } else {
224                     format!("&[{ty}]")
225                 }
226             }
227             TypeMode::Owned => {
228                 let wt = self.wasmtime_path();
229                 format!("{wt}::component::__internal::Vec<{ty}>")
230             }
231         }
232     }
233 
234     fn print_stream(&mut self, ty: Option<&Type>) {
235         self.push_str(&self.stream(ty))
236     }
237     fn stream(&self, ty: Option<&Type>) -> String {
238         let wt = self.wasmtime_path();
239         let mut out = format!("{wt}::component::HostStream<");
240         out.push_str(&self.optional_ty(ty, TypeMode::Owned));
241         out.push_str(">");
242         out
243     }
244 
245     fn print_future(&mut self, ty: Option<&Type>) {
246         self.push_str(&self.future(ty))
247     }
248     fn future(&self, ty: Option<&Type>) -> String {
249         let wt = self.wasmtime_path();
250         let mut out = format!("{wt}::component::HostFuture<");
251         out.push_str(&self.optional_ty(ty, TypeMode::Owned));
252         out.push_str(">");
253         out
254     }
255 
256     fn print_handle(&mut self, handle: &Handle) {
257         self.push_str(&self.handle(handle))
258     }
259     fn handle(&self, handle: &Handle) -> String {
260         // Handles are either printed as `ResourceAny` for any guest-defined
261         // resource or `Resource<T>` for all host-defined resources. This means
262         // that this function needs to determine if `handle` points to a host
263         // or a guest resource which is determined by:
264         //
265         // * For world-owned resources, they're always imported.
266         // * For interface-owned resources, it depends on the how bindings were
267         //   last generated for this interface.
268         //
269         // Additionally type aliases via `use` are "peeled" here to find the
270         // original definition of the resource since that's the one that we
271         // care about for determining whether it's imported or not.
272         let resource = match handle {
273             Handle::Own(t) | Handle::Borrow(t) => *t,
274         };
275         let ty = &self.resolve().types[resource];
276         let def_id = super::resolve_type_definition_id(self.resolve(), resource);
277         let ty_def = &self.resolve().types[def_id];
278         let is_host_defined = match ty_def.owner {
279             TypeOwner::Interface(i) => self.is_imported_interface(i),
280             _ => true,
281         };
282         let wt = self.wasmtime_path();
283         if is_host_defined {
284             let mut out = format!("{wt}::component::Resource<");
285             out.push_str(&self.type_name_in_interface(
286                 ty.owner,
287                 &ty.name.as_ref().unwrap().to_upper_camel_case(),
288             ));
289             out.push_str(">");
290             out
291         } else {
292             format!("{wt}::component::ResourceAny")
293         }
294     }
295 
296     fn print_generics(&mut self, lifetime: Option<&str>) {
297         self.push_str(&self.generics(lifetime))
298     }
299     fn generics(&self, lifetime: Option<&str>) -> String {
300         if let Some(lt) = lifetime {
301             format!("<{lt},>")
302         } else {
303             String::new()
304         }
305     }
306 
307     fn modes_of(&self, ty: TypeId) -> Vec<(String, TypeMode)> {
308         let info = self.info(ty);
309         // Info only populated for types that are passed to and from functions. For
310         // types which are not, default to the ownership setting.
311         if !info.owned && !info.borrowed {
312             return vec![(
313                 self.param_name(ty),
314                 match self.ownership() {
315                     Ownership::Owning => TypeMode::Owned,
316                     Ownership::Borrowing { .. } => TypeMode::AllBorrowed("'a"),
317                 },
318             )];
319         }
320         let mut result = Vec::new();
321         let first_mode =
322             if info.owned || !info.borrowed || matches!(self.ownership(), Ownership::Owning) {
323                 TypeMode::Owned
324             } else {
325                 assert!(!self.uses_two_names(&info));
326                 TypeMode::AllBorrowed("'a")
327             };
328         result.push((self.result_name(ty), first_mode));
329         if self.uses_two_names(&info) {
330             result.push((self.param_name(ty), TypeMode::AllBorrowed("'a")));
331         }
332         result
333     }
334 
335     fn param_name(&self, ty: TypeId) -> String {
336         let info = self.info(ty);
337         let name = self.resolve().types[ty]
338             .name
339             .as_ref()
340             .unwrap()
341             .to_upper_camel_case();
342         if self.uses_two_names(&info) {
343             format!("{name}Param")
344         } else {
345             name
346         }
347     }
348 
349     fn result_name(&self, ty: TypeId) -> String {
350         let info = self.info(ty);
351         let name = self.resolve().types[ty]
352             .name
353             .as_ref()
354             .unwrap()
355             .to_upper_camel_case();
356         if self.uses_two_names(&info) {
357             format!("{name}Result")
358         } else {
359             name
360         }
361     }
362 
363     fn uses_two_names(&self, info: &TypeInfo) -> bool {
364         info.has_list
365             && info.borrowed
366             && info.owned
367             && matches!(
368                 self.ownership(),
369                 Ownership::Borrowing {
370                     duplicate_if_necessary: true
371                 }
372             )
373     }
374 
375     fn lifetime_for(&self, info: &TypeInfo, mode: TypeMode) -> Option<&'static str> {
376         if matches!(self.ownership(), Ownership::Owning) {
377             return None;
378         }
379         let lt = match mode {
380             TypeMode::AllBorrowed(s) => s,
381             _ => return None,
382         };
383         // No lifetimes needed unless this has a list.
384         if !info.has_list {
385             return None;
386         }
387         // If two names are used then this type will have an owned and a
388         // borrowed copy and the borrowed copy is being used, so it needs a
389         // lifetime. Otherwise if it's only borrowed and not owned then this can
390         // also use a lifetime since it's not needed in two contexts and only
391         // the borrowed version of the structure was generated.
392         if self.uses_two_names(info) || (info.borrowed && !info.owned) {
393             Some(lt)
394         } else {
395             None
396         }
397     }
398 
399     fn typedfunc_sig(&self, func: &Function, param_mode: TypeMode) -> String {
400         let mut out = "(".to_string();
401         for param in func.params.iter() {
402             out.push_str(&self.ty(&param.ty, param_mode));
403             out.push_str(", ");
404         }
405         out.push_str("), (");
406         if let Some(ty) = func.result {
407             out.push_str(&self.ty(&ty, TypeMode::Owned));
408             out.push_str(", ");
409         }
410         out.push_str(")");
411         out
412     }
413 }
414 
415 /// Translate `name` to a Rust `snake_case` identifier.
416 pub fn to_rust_ident(name: &str) -> String {
417     match name {
418         // Escape Rust keywords.
419         // Source: https://doc.rust-lang.org/reference/keywords.html
420         "as" => "as_".into(),
421         "break" => "break_".into(),
422         "const" => "const_".into(),
423         "continue" => "continue_".into(),
424         "crate" => "crate_".into(),
425         "else" => "else_".into(),
426         "enum" => "enum_".into(),
427         "extern" => "extern_".into(),
428         "false" => "false_".into(),
429         "fn" => "fn_".into(),
430         "for" => "for_".into(),
431         "if" => "if_".into(),
432         "impl" => "impl_".into(),
433         "in" => "in_".into(),
434         "let" => "let_".into(),
435         "loop" => "loop_".into(),
436         "match" => "match_".into(),
437         "mod" => "mod_".into(),
438         "move" => "move_".into(),
439         "mut" => "mut_".into(),
440         "pub" => "pub_".into(),
441         "ref" => "ref_".into(),
442         "return" => "return_".into(),
443         "self" => "self_".into(),
444         "static" => "static_".into(),
445         "struct" => "struct_".into(),
446         "super" => "super_".into(),
447         "trait" => "trait_".into(),
448         "true" => "true_".into(),
449         "type" => "type_".into(),
450         "unsafe" => "unsafe_".into(),
451         "use" => "use_".into(),
452         "where" => "where_".into(),
453         "while" => "while_".into(),
454         "async" => "async_".into(),
455         "await" => "await_".into(),
456         "dyn" => "dyn_".into(),
457         "abstract" => "abstract_".into(),
458         "become" => "become_".into(),
459         "box" => "box_".into(),
460         "do" => "do_".into(),
461         "final" => "final_".into(),
462         "macro" => "macro_".into(),
463         "override" => "override_".into(),
464         "priv" => "priv_".into(),
465         "typeof" => "typeof_".into(),
466         "unsized" => "unsized_".into(),
467         "virtual" => "virtual_".into(),
468         "yield" => "yield_".into(),
469         "try" => "try_".into(),
470         "gen" => "gen_".into(),
471         s => s.to_snake_case(),
472     }
473 }
474 
475 /// Translate `name` to a Rust `UpperCamelCase` identifier.
476 pub fn to_rust_upper_camel_case(name: &str) -> String {
477     match name {
478         // We use `Host` as the name of the trait for host implementations
479         // to fill in, so rename it if "Host" is used as a regular identifier.
480         "host" => "Host_".into(),
481         s => s.to_upper_camel_case(),
482     }
483 }
484