1 use crate::codegen_settings::{CodegenSettings, ErrorType};
2 use crate::lifetimes::anon_lifetime;
3 use crate::module_trait::passed_by_reference;
4 use crate::names;
5 use crate::types::WiggleType;
6 use proc_macro2::{Ident, Span, TokenStream};
7 use quote::quote;
8 use std::mem;
9 use witx::Instruction;
10 
define_func( module: &witx::Module, func: &witx::InterfaceFunc, settings: &CodegenSettings, ) -> TokenStream11 pub fn define_func(
12     module: &witx::Module,
13     func: &witx::InterfaceFunc,
14     settings: &CodegenSettings,
15 ) -> TokenStream {
16     let (ts, _bounds) = _define_func(module, func, settings);
17     ts
18 }
19 
func_bounds( module: &witx::Module, func: &witx::InterfaceFunc, settings: &CodegenSettings, ) -> Vec<Ident>20 pub fn func_bounds(
21     module: &witx::Module,
22     func: &witx::InterfaceFunc,
23     settings: &CodegenSettings,
24 ) -> Vec<Ident> {
25     let (_ts, bounds) = _define_func(module, func, settings);
26     bounds
27 }
28 
_define_func( module: &witx::Module, func: &witx::InterfaceFunc, settings: &CodegenSettings, ) -> (TokenStream, Vec<Ident>)29 fn _define_func(
30     module: &witx::Module,
31     func: &witx::InterfaceFunc,
32     settings: &CodegenSettings,
33 ) -> (TokenStream, Vec<Ident>) {
34     let ident = names::func(&func.name);
35 
36     let (wasm_params, wasm_results) = func.wasm_signature();
37     let param_names = (0..wasm_params.len())
38         .map(|i| Ident::new(&format!("arg{i}"), Span::call_site()))
39         .collect::<Vec<_>>();
40     let abi_params = wasm_params.iter().zip(&param_names).map(|(arg, name)| {
41         let wasm = names::wasm_type(*arg);
42         quote!(#name : #wasm)
43     });
44 
45     let abi_ret = match wasm_results.len() {
46         0 => quote!(()),
47         1 => {
48             let ty = names::wasm_type(wasm_results[0]);
49             quote!(#ty)
50         }
51         _ => unimplemented!(),
52     };
53 
54     let mut body = TokenStream::new();
55     let mut bounds = vec![names::trait_name(&module.name)];
56     func.call_interface(
57         &module.name,
58         &mut Rust {
59             src: &mut body,
60             params: &param_names,
61             block_storage: Vec::new(),
62             blocks: Vec::new(),
63             module,
64             funcname: func.name.as_str(),
65             settings,
66             bounds: &mut bounds,
67         },
68     );
69 
70     let mod_name = &module.name.as_str();
71     let func_name = &func.name.as_str();
72     let mk_span = quote!(
73         let _span = wiggle::tracing::span!(
74             wiggle::tracing::Level::TRACE,
75             "wiggle abi",
76             module = #mod_name,
77             function = #func_name
78         );
79     );
80     let ctx_type = if settings.mutable {
81         quote!(&mut)
82     } else {
83         quote!(&)
84     };
85     if settings.get_async(&module, &func).is_sync() {
86         let traced_body = if settings.tracing.enabled_for(&mod_name, &func_name) {
87             quote!(
88                 #mk_span
89                 _span.in_scope(|| {
90                   #body
91                 })
92             )
93         } else {
94             quote!(#body)
95         };
96         (
97             quote!(
98                 #[allow(unreachable_code)] // deals with warnings in noreturn functions
99                 pub fn #ident(
100                     ctx: #ctx_type (impl #(#bounds)+*),
101                     memory: &mut wiggle::GuestMemory<'_>,
102                     #(#abi_params),*
103                 ) -> wiggle::error::Result<#abi_ret> {
104                     use std::convert::TryFrom as _;
105                     #traced_body
106                 }
107             ),
108             bounds,
109         )
110     } else {
111         let traced_body = if settings.tracing.enabled_for(&mod_name, &func_name) {
112             quote!(
113                 use wiggle::tracing::Instrument as _;
114                 #mk_span
115                 async move {
116                     #body
117                 }.instrument(_span).await
118             )
119         } else {
120             quote!(
121                 #body
122             )
123         };
124         (
125             quote!(
126                 #[allow(unreachable_code)] // deals with warnings in noreturn functions
127                 pub async fn #ident(
128                     ctx: #ctx_type (impl #(#bounds)+*),
129                     memory: &mut wiggle::GuestMemory<'_>,
130                     #(#abi_params),*
131                 ) -> wiggle::error::Result<#abi_ret> {
132                     use std::convert::TryFrom as _;
133                     #traced_body
134                 }
135             ),
136             bounds,
137         )
138     }
139 }
140 
141 struct Rust<'a> {
142     src: &'a mut TokenStream,
143     params: &'a [Ident],
144     block_storage: Vec<TokenStream>,
145     blocks: Vec<TokenStream>,
146     module: &'a witx::Module,
147     funcname: &'a str,
148     settings: &'a CodegenSettings,
149     bounds: &'a mut Vec<Ident>,
150 }
151 
152 impl Rust<'_> {
bound(&mut self, i: Ident)153     fn bound(&mut self, i: Ident) {
154         if !self.bounds.contains(&i) {
155             self.bounds.push(i);
156         }
157     }
158 }
159 
160 impl witx::Bindgen for Rust<'_> {
161     type Operand = TokenStream;
162 
push_block(&mut self)163     fn push_block(&mut self) {
164         let prev = mem::replace(self.src, TokenStream::new());
165         self.block_storage.push(prev);
166     }
167 
finish_block(&mut self, operand: Option<TokenStream>)168     fn finish_block(&mut self, operand: Option<TokenStream>) {
169         let to_restore = self.block_storage.pop().unwrap();
170         let src = mem::replace(self.src, to_restore);
171         match operand {
172             None => self.blocks.push(src),
173             Some(s) => {
174                 if src.is_empty() {
175                     self.blocks.push(s);
176                 } else {
177                     self.blocks.push(quote!({ #src; #s }));
178                 }
179             }
180         }
181     }
182 
183     // This is only used for `call_wasm` at this time.
allocate_space(&mut self, _: usize, _: &witx::NamedType)184     fn allocate_space(&mut self, _: usize, _: &witx::NamedType) {
185         unimplemented!()
186     }
187 
emit( &mut self, inst: &Instruction<'_>, operands: &mut Vec<TokenStream>, results: &mut Vec<TokenStream>, )188     fn emit(
189         &mut self,
190         inst: &Instruction<'_>,
191         operands: &mut Vec<TokenStream>,
192         results: &mut Vec<TokenStream>,
193     ) {
194         let wrap_err = |location: &str| {
195             let modulename = self.module.name.as_str();
196             let funcname = self.funcname;
197             quote! {
198                 |e| {
199                     wiggle::GuestError::InFunc {
200                         modulename: #modulename,
201                         funcname: #funcname,
202                         location: #location,
203                         err: Box::new(wiggle::GuestError::from(e)),
204                     }
205                 }
206             }
207         };
208 
209         let mut try_from = |ty: TokenStream| {
210             let val = operands.pop().unwrap();
211             let wrap_err = wrap_err(&format!("convert {ty}"));
212             results.push(quote!(#ty::try_from(#val).map_err(#wrap_err)?));
213         };
214 
215         match inst {
216             Instruction::GetArg { nth } => {
217                 let param = &self.params[*nth];
218                 results.push(quote!(#param));
219             }
220 
221             Instruction::PointerFromI32 { ty } | Instruction::ConstPointerFromI32 { ty } => {
222                 let val = operands.pop().unwrap();
223                 let pointee_type = names::type_ref(ty, anon_lifetime());
224                 results.push(quote! {
225                     wiggle::GuestPtr::<#pointee_type>::new(#val as u32)
226                 });
227             }
228 
229             Instruction::ListFromPointerLength { ty } => {
230                 let ptr = &operands[0];
231                 let len = &operands[1];
232                 let ty = match &**ty.type_() {
233                     witx::Type::Builtin(witx::BuiltinType::Char) => quote!(str),
234                     _ => {
235                         let ty = names::type_ref(ty, anon_lifetime());
236                         quote!([#ty])
237                     }
238                 };
239                 results.push(quote! {
240                     wiggle::GuestPtr::<#ty>::new((#ptr as u32, #len as u32));
241                 })
242             }
243 
244             Instruction::CallInterface { func, .. } => {
245                 // Use the `tracing` crate to log all arguments that are going
246                 // out, and afterwards we call the function with those bindings.
247                 let mut args = Vec::new();
248                 for (i, param) in func.params.iter().enumerate() {
249                     let name = names::func_param(&param.name);
250                     let val = &operands[i];
251                     self.src.extend(quote!(let #name = #val;));
252                     if passed_by_reference(param.tref.type_()) {
253                         args.push(quote!(&#name));
254                     } else {
255                         args.push(quote!(#name));
256                     }
257                 }
258                 if self
259                     .settings
260                     .tracing
261                     .enabled_for(self.module.name.as_str(), self.funcname)
262                     && func.params.len() > 0
263                 {
264                     let args = func
265                         .params
266                         .iter()
267                         .map(|param| {
268                             let name = names::func_param(&param.name);
269                             if param.impls_display() {
270                                 quote!( #name = wiggle::tracing::field::display(&#name) )
271                             } else {
272                                 quote!( #name = wiggle::tracing::field::debug(&#name) )
273                             }
274                         })
275                         .collect::<Vec<_>>();
276                     self.src.extend(quote! {
277                         wiggle::tracing::event!(wiggle::tracing::Level::TRACE, #(#args),*);
278                     });
279                 }
280 
281                 let trait_name = names::trait_name(&self.module.name);
282                 let ident = names::func(&func.name);
283                 if self.settings.get_async(&self.module, &func).is_sync() {
284                     self.src.extend(quote! {
285                         let ret = #trait_name::#ident(ctx, memory, #(#args),*);
286                     })
287                 } else {
288                     self.src.extend(quote! {
289                         let ret = #trait_name::#ident(ctx, memory, #(#args),*).await;
290                     })
291                 };
292                 if self
293                     .settings
294                     .tracing
295                     .enabled_for(self.module.name.as_str(), self.funcname)
296                 {
297                     self.src.extend(quote! {
298                         wiggle::tracing::event!(
299                             wiggle::tracing::Level::TRACE,
300                             result = wiggle::tracing::field::debug(&ret),
301                         );
302                     });
303                 }
304 
305                 if func.results.len() > 0 {
306                     results.push(quote!(ret));
307                 } else if func.noreturn {
308                     self.src.extend(quote!(return Err(ret);));
309                 }
310             }
311 
312             // Lowering an enum is typically simple but if we have an error
313             // transformation registered for this enum then what we're actually
314             // doing is lowering from a user-defined error type to the error
315             // enum, and *then* we lower to an i32.
316             Instruction::EnumLower { ty } => {
317                 let val = operands.pop().unwrap();
318                 let val = match self.settings.errors.for_name(ty) {
319                     Some(ErrorType::User(custom)) => {
320                         let method = names::user_error_conversion_method(&custom);
321                         self.bound(quote::format_ident!("UserErrorConversion"));
322                         quote!(UserErrorConversion::#method(ctx, #val)?)
323                     }
324                     Some(ErrorType::Generated(_)) => quote!(#val.downcast()?),
325                     None => val,
326                 };
327                 results.push(quote!(#val as i32));
328             }
329 
330             Instruction::ResultLower { err: err_ty, .. } => {
331                 let err = self.blocks.pop().unwrap();
332                 let ok = self.blocks.pop().unwrap();
333                 let val = operands.pop().unwrap();
334                 let err_typename = names::type_ref(err_ty.unwrap(), anon_lifetime());
335                 results.push(quote! {
336                     match #val {
337                         Ok(e) => { #ok; <#err_typename as wiggle::GuestErrorType>::success() as i32 }
338                         Err(e) => { #err }
339                     }
340                 });
341             }
342 
343             Instruction::VariantPayload => results.push(quote!(e)),
344 
345             Instruction::Return { amt: 0 } => {
346                 self.src.extend(quote!(return Ok(())));
347             }
348             Instruction::Return { amt: 1 } => {
349                 let val = operands.pop().unwrap();
350                 self.src.extend(quote!(return Ok(#val)));
351             }
352             Instruction::Return { .. } => unimplemented!(),
353 
354             Instruction::TupleLower { amt } => {
355                 let names = (0..*amt)
356                     .map(|i| Ident::new(&format!("t{i}"), Span::call_site()))
357                     .collect::<Vec<_>>();
358                 let val = operands.pop().unwrap();
359                 self.src.extend(quote!( let (#(#names,)*) = #val;));
360                 results.extend(names.iter().map(|i| quote!(#i)));
361             }
362 
363             Instruction::Store { ty } => {
364                 let ptr = operands.pop().unwrap();
365                 let val = operands.pop().unwrap();
366                 let wrap_err = wrap_err(&format!("write {}", ty.name.as_str()));
367                 let pointee_type = names::type_(&ty.name);
368                 self.src.extend(quote! {
369                     memory.write(
370                         wiggle::GuestPtr::<#pointee_type>::new(#ptr as u32),
371                         #val,
372                     )
373                     .map_err(#wrap_err)?;
374                 });
375             }
376 
377             Instruction::Load { ty } => {
378                 let ptr = operands.pop().unwrap();
379                 let wrap_err = wrap_err(&format!("read {}", ty.name.as_str()));
380                 let pointee_type = names::type_(&ty.name);
381                 results.push(quote! {
382                     memory.read(wiggle::GuestPtr::<#pointee_type>::new(#ptr as u32))
383                         .map_err(#wrap_err)?
384                 });
385             }
386 
387             Instruction::HandleFromI32 { ty } => {
388                 let val = operands.pop().unwrap();
389                 let ty = names::type_(&ty.name);
390                 results.push(quote!(#ty::from(#val)));
391             }
392 
393             // Smaller-than-32 numerical conversions are done with `TryFrom` to
394             // ensure we're not losing bits.
395             Instruction::U8FromI32 => try_from(quote!(u8)),
396             Instruction::S8FromI32 => try_from(quote!(i8)),
397             Instruction::Char8FromI32 => try_from(quote!(u8)),
398             Instruction::U16FromI32 => try_from(quote!(u16)),
399             Instruction::S16FromI32 => try_from(quote!(i16)),
400 
401             // Conversions with matching bit-widths but different signededness
402             // use `as` since we're basically just reinterpreting the bits.
403             Instruction::U32FromI32 | Instruction::UsizeFromI32 => {
404                 let val = operands.pop().unwrap();
405                 results.push(quote!(#val as u32));
406             }
407             Instruction::U64FromI64 => {
408                 let val = operands.pop().unwrap();
409                 results.push(quote!(#val as u64));
410             }
411 
412             // Conversions to enums/bitflags use `TryFrom` to ensure that the
413             // values are valid coming in.
414             Instruction::EnumLift { ty }
415             | Instruction::BitflagsFromI64 { ty }
416             | Instruction::BitflagsFromI32 { ty } => {
417                 let ty = names::type_(&ty.name);
418                 try_from(quote!(#ty))
419             }
420 
421             // No conversions necessary for these, the native wasm type matches
422             // our own representation.
423             Instruction::If32FromF32
424             | Instruction::If64FromF64
425             | Instruction::S32FromI32
426             | Instruction::S64FromI64 => results.push(operands.pop().unwrap()),
427 
428             // There's a number of other instructions we could implement but
429             // they're not exercised by WASI at this time. As necessary we can
430             // add code to implement them.
431             other => panic!("no implementation for {other:?}"),
432         }
433     }
434 }
435