1 use proc_macro2::{Span, TokenStream};
2 use quote::ToTokens;
3 use std::collections::HashMap;
4 use std::env;
5 use std::path::{Path, PathBuf};
6 use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
7 use syn::parse::{Error, Parse, ParseStream, Result};
8 use syn::punctuated::Punctuated;
9 use syn::{Token, braced, token};
10 use wasmtime_wit_bindgen::{
11     FunctionConfig, FunctionFilter, FunctionFlags, Opts, Ownership, TrappableError,
12 };
13 use wit_parser::{PackageId, Resolve, UnresolvedPackageGroup, WorldId};
14 
15 pub struct Config {
16     opts: Opts,
17     resolve: Resolve,
18     world: WorldId,
19     files: Vec<PathBuf>,
20     include_generated_code_from_file: bool,
21 }
22 
expand(input: &Config) -> Result<TokenStream>23 pub fn expand(input: &Config) -> Result<TokenStream> {
24     let mut src = match input.opts.generate(&input.resolve, input.world) {
25         Ok(s) => s,
26         Err(e) => return Err(Error::new(Span::call_site(), e.to_string())),
27     };
28 
29     if input.opts.stringify {
30         return Ok(quote::quote!(#src));
31     }
32 
33     // If a magical `WASMTIME_DEBUG_BINDGEN` environment variable is set then
34     // place a formatted version of the expanded code into a file. This file
35     // will then show up in rustc error messages for any codegen issues and can
36     // be inspected manually.
37     if input.include_generated_code_from_file
38         || input.opts.debug
39         || std::env::var("WASMTIME_DEBUG_BINDGEN").is_ok()
40     {
41         static INVOCATION: AtomicUsize = AtomicUsize::new(0);
42         let root = Path::new(env!("DEBUG_OUTPUT_DIR"));
43         let world_name = &input.resolve.worlds[input.world].name;
44         let n = INVOCATION.fetch_add(1, Relaxed);
45         let path = root.join(format!("{world_name}{n}.rs"));
46 
47         std::fs::write(&path, &src).unwrap();
48 
49         // optimistically format the code but don't require success
50         drop(
51             std::process::Command::new("rustfmt")
52                 .arg(&path)
53                 .arg("--edition=2021")
54                 .output(),
55         );
56 
57         src = format!("include!({path:?});");
58     }
59     let mut contents = src.parse::<TokenStream>().unwrap();
60 
61     // Include a dummy `include_bytes!` for any files we read so rustc knows that
62     // we depend on the contents of those files.
63     for file in input.files.iter() {
64         contents.extend(
65             format!(
66                 "const _: &[u8] = include_bytes!(r#\"{}\"#);\n",
67                 file.display()
68             )
69             .parse::<TokenStream>()
70             .unwrap(),
71         );
72     }
73 
74     Ok(contents)
75 }
76 
77 impl Parse for Config {
parse(input: ParseStream<'_>) -> Result<Self>78     fn parse(input: ParseStream<'_>) -> Result<Self> {
79         let call_site = Span::call_site();
80         let mut opts = Opts::default();
81         let mut world = None;
82         let mut inline = None;
83         let mut paths = Vec::new();
84         let mut imports_configured = false;
85         let mut exports_configured = false;
86         let mut include_generated_code_from_file = false;
87 
88         if input.peek(token::Brace) {
89             let content;
90             syn::braced!(content in input);
91             let fields = Punctuated::<Opt, Token![,]>::parse_terminated(&content)?;
92             for field in fields.into_pairs() {
93                 match field.into_value() {
94                     Opt::Path(p) => {
95                         paths.extend(p.into_iter().map(|p| p.value()));
96                     }
97                     Opt::World(s) => {
98                         if world.is_some() {
99                             return Err(Error::new(s.span(), "cannot specify second world"));
100                         }
101                         world = Some(s.value());
102                     }
103                     Opt::Inline(s) => {
104                         if inline.is_some() {
105                             return Err(Error::new(s.span(), "cannot specify second source"));
106                         }
107                         inline = Some(s.value());
108                     }
109                     Opt::Debug(val) => opts.debug = val,
110                     Opt::TrappableErrorType(val) => opts.trappable_error_type = val,
111                     Opt::Ownership(val) => opts.ownership = val,
112                     Opt::Interfaces(s) => {
113                         if inline.is_some() {
114                             return Err(Error::new(s.span(), "cannot specify a second source"));
115                         }
116                         inline = Some(format!(
117                             "
118                                 package wasmtime:component-macro-synthesized;
119 
120                                 world interfaces {{
121                                     {}
122                                 }}
123                             ",
124                             s.value()
125                         ));
126 
127                         if world.is_some() {
128                             return Err(Error::new(
129                                 s.span(),
130                                 "cannot specify a world with `interfaces`",
131                             ));
132                         }
133                         world = Some("wasmtime:component-macro-synthesized/interfaces".to_string());
134 
135                         opts.only_interfaces = true;
136                     }
137                     Opt::With(val) => opts.with.extend(val),
138                     Opt::AdditionalDerives(paths) => {
139                         opts.additional_derive_attributes = paths
140                             .into_iter()
141                             .map(|p| p.into_token_stream().to_string())
142                             .collect()
143                     }
144                     Opt::Stringify(val) => opts.stringify = val,
145                     Opt::SkipMutForwardingImpls(val) => opts.skip_mut_forwarding_impls = val,
146                     Opt::RequireStoreDataSend(val) => opts.require_store_data_send = val,
147                     Opt::WasmtimeCrate(f) => {
148                         opts.wasmtime_crate = Some(f.into_token_stream().to_string())
149                     }
150                     Opt::Anyhow(val) => {
151                         opts.anyhow = val;
152                     }
153                     Opt::IncludeGeneratedCodeFromFile(i) => include_generated_code_from_file = i,
154                     Opt::Imports(config, span) => {
155                         if imports_configured {
156                             return Err(Error::new(span, "cannot specify imports configuration"));
157                         }
158                         opts.imports = config;
159                         imports_configured = true;
160                     }
161                     Opt::Exports(config, span) => {
162                         if exports_configured {
163                             return Err(Error::new(span, "cannot specify exports configuration"));
164                         }
165                         opts.exports = config;
166                         exports_configured = true;
167                     }
168                 }
169             }
170         } else {
171             world = input.parse::<Option<syn::LitStr>>()?.map(|s| s.value());
172             if input.parse::<Option<syn::token::In>>()?.is_some() {
173                 paths.push(input.parse::<syn::LitStr>()?.value());
174             }
175         }
176         let (resolve, pkgs, files) = parse_source(&paths, &inline)
177             .map_err(|err| Error::new(call_site, format!("{err:?}")))?;
178 
179         let world = resolve
180             .select_world(&pkgs, world.as_deref())
181             .map_err(|e| Error::new(call_site, format!("{e:?}")))?;
182         Ok(Config {
183             opts,
184             resolve,
185             world,
186             files,
187             include_generated_code_from_file,
188         })
189     }
190 }
191 
parse_source( paths: &Vec<String>, inline: &Option<String>, ) -> anyhow::Result<(Resolve, Vec<PackageId>, Vec<PathBuf>)>192 fn parse_source(
193     paths: &Vec<String>,
194     inline: &Option<String>,
195 ) -> anyhow::Result<(Resolve, Vec<PackageId>, Vec<PathBuf>)> {
196     let mut resolve = Resolve::default();
197     resolve.all_features = true;
198     let mut files = Vec::new();
199     let mut pkgs = Vec::new();
200     let root = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap());
201     let default = root.join("wit");
202 
203     let parse = |resolve: &mut Resolve,
204                  files: &mut Vec<PathBuf>,
205                  pkgs: &mut Vec<PackageId>,
206                  paths: &[PathBuf]|
207      -> anyhow::Result<_> {
208         for path in paths {
209             let p = root.join(path);
210             // Try to normalize the path to make the error message more understandable when
211             // the path is not correct. Fallback to the original path if normalization fails
212             // (probably return an error somewhere else).
213             let normalized_path = match std::fs::canonicalize(&p) {
214                 Ok(p) => p,
215                 Err(_) => p.to_path_buf(),
216             };
217             let (pkg, sources) = resolve.push_path(normalized_path)?;
218             pkgs.push(pkg);
219             files.extend(sources.paths().map(|p| p.to_owned()));
220         }
221         Ok(())
222     };
223 
224     if paths.is_empty() {
225         if default.exists() {
226             parse(&mut resolve, &mut files, &mut pkgs, &[default])?;
227         }
228     } else {
229         parse(
230             &mut resolve,
231             &mut files,
232             &mut pkgs,
233             &paths.iter().map(|s| s.into()).collect::<Vec<_>>(),
234         )?;
235     }
236 
237     if let Some(inline) = inline {
238         pkgs.truncate(0);
239         pkgs.push(resolve.push_group(UnresolvedPackageGroup::parse("macro-input", inline)?)?);
240     }
241 
242     Ok((resolve, pkgs, files))
243 }
244 
245 mod kw {
246     syn::custom_keyword!(inline);
247     syn::custom_keyword!(path);
248     syn::custom_keyword!(tracing);
249     syn::custom_keyword!(verbose_tracing);
250     syn::custom_keyword!(trappable_error_type);
251     syn::custom_keyword!(world);
252     syn::custom_keyword!(ownership);
253     syn::custom_keyword!(interfaces);
254     syn::custom_keyword!(with);
255     syn::custom_keyword!(except_imports);
256     syn::custom_keyword!(only_imports);
257     syn::custom_keyword!(additional_derives);
258     syn::custom_keyword!(stringify);
259     syn::custom_keyword!(skip_mut_forwarding_impls);
260     syn::custom_keyword!(require_store_data_send);
261     syn::custom_keyword!(wasmtime_crate);
262     syn::custom_keyword!(anyhow);
263     syn::custom_keyword!(include_generated_code_from_file);
264     syn::custom_keyword!(debug);
265     syn::custom_keyword!(imports);
266     syn::custom_keyword!(exports);
267     syn::custom_keyword!(store);
268     syn::custom_keyword!(trappable);
269     syn::custom_keyword!(ignore_wit);
270     syn::custom_keyword!(exact);
271 }
272 
273 enum Opt {
274     World(syn::LitStr),
275     Path(Vec<syn::LitStr>),
276     Inline(syn::LitStr),
277     TrappableErrorType(Vec<TrappableError>),
278     Ownership(Ownership),
279     Interfaces(syn::LitStr),
280     With(HashMap<String, String>),
281     AdditionalDerives(Vec<syn::Path>),
282     Stringify(bool),
283     SkipMutForwardingImpls(bool),
284     RequireStoreDataSend(bool),
285     WasmtimeCrate(syn::Path),
286     Anyhow(bool),
287     IncludeGeneratedCodeFromFile(bool),
288     Debug(bool),
289     Imports(FunctionConfig, Span),
290     Exports(FunctionConfig, Span),
291 }
292 
293 impl Parse for Opt {
parse(input: ParseStream<'_>) -> Result<Self>294     fn parse(input: ParseStream<'_>) -> Result<Self> {
295         let l = input.lookahead1();
296         if l.peek(kw::debug) {
297             input.parse::<kw::debug>()?;
298             input.parse::<Token![:]>()?;
299             Ok(Opt::Debug(input.parse::<syn::LitBool>()?.value))
300         } else if l.peek(kw::path) {
301             input.parse::<kw::path>()?;
302             input.parse::<Token![:]>()?;
303 
304             let mut paths: Vec<syn::LitStr> = vec![];
305 
306             let l = input.lookahead1();
307             if l.peek(syn::LitStr) {
308                 paths.push(input.parse()?);
309             } else if l.peek(syn::token::Bracket) {
310                 let contents;
311                 syn::bracketed!(contents in input);
312                 let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
313 
314                 paths.extend(list);
315             } else {
316                 return Err(l.error());
317             };
318 
319             Ok(Opt::Path(paths))
320         } else if l.peek(kw::inline) {
321             input.parse::<kw::inline>()?;
322             input.parse::<Token![:]>()?;
323             Ok(Opt::Inline(input.parse()?))
324         } else if l.peek(kw::world) {
325             input.parse::<kw::world>()?;
326             input.parse::<Token![:]>()?;
327             Ok(Opt::World(input.parse()?))
328         } else if l.peek(kw::ownership) {
329             input.parse::<kw::ownership>()?;
330             input.parse::<Token![:]>()?;
331             let ownership = input.parse::<syn::Ident>()?;
332             Ok(Opt::Ownership(match ownership.to_string().as_str() {
333                 "Owning" => Ownership::Owning,
334                 "Borrowing" => Ownership::Borrowing {
335                     duplicate_if_necessary: {
336                         let contents;
337                         braced!(contents in input);
338                         let field = contents.parse::<syn::Ident>()?;
339                         match field.to_string().as_str() {
340                             "duplicate_if_necessary" => {
341                                 contents.parse::<Token![:]>()?;
342                                 contents.parse::<syn::LitBool>()?.value
343                             }
344                             name => {
345                                 return Err(Error::new(
346                                     field.span(),
347                                     format!(
348                                         "unrecognized `Ownership::Borrowing` field: `{name}`; \
349                                          expected `duplicate_if_necessary`"
350                                     ),
351                                 ));
352                             }
353                         }
354                     },
355                 },
356                 name => {
357                     return Err(Error::new(
358                         ownership.span(),
359                         format!(
360                             "unrecognized ownership: `{name}`; \
361                              expected `Owning` or `Borrowing`"
362                         ),
363                     ));
364                 }
365             }))
366         } else if l.peek(kw::trappable_error_type) {
367             input.parse::<kw::trappable_error_type>()?;
368             input.parse::<Token![:]>()?;
369             let contents;
370             let _lbrace = braced!(contents in input);
371             let fields: Punctuated<_, Token![,]> =
372                 contents.parse_terminated(trappable_error_field_parse, Token![,])?;
373             Ok(Opt::TrappableErrorType(Vec::from_iter(fields)))
374         } else if l.peek(kw::interfaces) {
375             input.parse::<kw::interfaces>()?;
376             input.parse::<Token![:]>()?;
377             Ok(Opt::Interfaces(input.parse::<syn::LitStr>()?))
378         } else if l.peek(kw::with) {
379             input.parse::<kw::with>()?;
380             input.parse::<Token![:]>()?;
381             let contents;
382             let _lbrace = braced!(contents in input);
383             let fields: Punctuated<(String, String), Token![,]> =
384                 contents.parse_terminated(with_field_parse, Token![,])?;
385             Ok(Opt::With(HashMap::from_iter(fields)))
386         } else if l.peek(kw::additional_derives) {
387             input.parse::<kw::additional_derives>()?;
388             input.parse::<Token![:]>()?;
389             let contents;
390             syn::bracketed!(contents in input);
391             let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
392             Ok(Opt::AdditionalDerives(list.iter().cloned().collect()))
393         } else if l.peek(kw::stringify) {
394             input.parse::<kw::stringify>()?;
395             input.parse::<Token![:]>()?;
396             Ok(Opt::Stringify(input.parse::<syn::LitBool>()?.value))
397         } else if l.peek(kw::skip_mut_forwarding_impls) {
398             input.parse::<kw::skip_mut_forwarding_impls>()?;
399             input.parse::<Token![:]>()?;
400             Ok(Opt::SkipMutForwardingImpls(
401                 input.parse::<syn::LitBool>()?.value,
402             ))
403         } else if l.peek(kw::require_store_data_send) {
404             input.parse::<kw::require_store_data_send>()?;
405             input.parse::<Token![:]>()?;
406             Ok(Opt::RequireStoreDataSend(
407                 input.parse::<syn::LitBool>()?.value,
408             ))
409         } else if l.peek(kw::wasmtime_crate) {
410             input.parse::<kw::wasmtime_crate>()?;
411             input.parse::<Token![:]>()?;
412             Ok(Opt::WasmtimeCrate(input.parse()?))
413         } else if l.peek(kw::anyhow) {
414             input.parse::<kw::anyhow>()?;
415             input.parse::<Token![:]>()?;
416             Ok(Opt::Anyhow(input.parse::<syn::LitBool>()?.value))
417         } else if l.peek(kw::include_generated_code_from_file) {
418             input.parse::<kw::include_generated_code_from_file>()?;
419             input.parse::<Token![:]>()?;
420             Ok(Opt::IncludeGeneratedCodeFromFile(
421                 input.parse::<syn::LitBool>()?.value,
422             ))
423         } else if l.peek(kw::imports) {
424             let span = input.parse::<kw::imports>()?.span;
425             input.parse::<Token![:]>()?;
426             Ok(Opt::Imports(parse_function_config(input)?, span))
427         } else if l.peek(kw::exports) {
428             let span = input.parse::<kw::exports>()?.span;
429             input.parse::<Token![:]>()?;
430             Ok(Opt::Exports(parse_function_config(input)?, span))
431         } else {
432             Err(l.error())
433         }
434     }
435 }
436 
trappable_error_field_parse(input: ParseStream<'_>) -> Result<TrappableError>437 fn trappable_error_field_parse(input: ParseStream<'_>) -> Result<TrappableError> {
438     let wit_path = input.parse::<syn::LitStr>()?.value();
439     input.parse::<Token![=>]>()?;
440     let rust_type_name = input.parse::<syn::Path>()?.to_token_stream().to_string();
441     Ok(TrappableError {
442         wit_path,
443         rust_type_name,
444     })
445 }
446 
with_field_parse(input: ParseStream<'_>) -> Result<(String, String)>447 fn with_field_parse(input: ParseStream<'_>) -> Result<(String, String)> {
448     let interface = input.parse::<syn::LitStr>()?.value();
449     input.parse::<Token![:]>()?;
450     let start = input.span();
451     let path = input.parse::<syn::Path>()?;
452 
453     // It's not possible for the segments of a path to be empty
454     let span = start
455         .join(path.segments.last().unwrap().ident.span())
456         .unwrap_or(start);
457 
458     let mut buf = String::new();
459     let append = |buf: &mut String, segment: syn::PathSegment| -> Result<()> {
460         if segment.arguments != syn::PathArguments::None {
461             return Err(Error::new(
462                 span,
463                 "Module path must not contain angles or parens",
464             ));
465         }
466 
467         buf.push_str(&segment.ident.to_string());
468 
469         Ok(())
470     };
471 
472     if path.leading_colon.is_some() {
473         buf.push_str("::");
474     }
475 
476     let mut segments = path.segments.into_iter();
477 
478     if let Some(segment) = segments.next() {
479         append(&mut buf, segment)?;
480     }
481 
482     for segment in segments {
483         buf.push_str("::");
484         append(&mut buf, segment)?;
485     }
486 
487     Ok((interface, buf))
488 }
489 
parse_function_config(input: ParseStream<'_>) -> Result<FunctionConfig>490 fn parse_function_config(input: ParseStream<'_>) -> Result<FunctionConfig> {
491     let content;
492     syn::braced!(content in input);
493     let mut ret = FunctionConfig::new();
494 
495     let list = Punctuated::<FunctionConfigSyntax, Token![,]>::parse_terminated(&content)?;
496     for item in list.into_iter() {
497         ret.push(item.filter, item.flags);
498     }
499 
500     return Ok(ret);
501 
502     struct FunctionConfigSyntax {
503         filter: FunctionFilter,
504         flags: FunctionFlags,
505     }
506 
507     impl Parse for FunctionConfigSyntax {
508         fn parse(input: ParseStream<'_>) -> Result<Self> {
509             let l = input.lookahead1();
510             let filter = if l.peek(syn::LitStr) {
511                 FunctionFilter::Name(input.parse::<syn::LitStr>()?.value())
512             } else if l.peek(Token![default]) {
513                 input.parse::<Token![default]>()?;
514                 FunctionFilter::Default
515             } else {
516                 return Err(l.error());
517             };
518 
519             input.parse::<Token![:]>()?;
520 
521             let mut flags = FunctionFlags::empty();
522             while !input.is_empty() {
523                 let l = input.lookahead1();
524                 if l.peek(Token![async]) {
525                     input.parse::<Token![async]>()?;
526                     flags |= FunctionFlags::ASYNC;
527                 } else if l.peek(kw::tracing) {
528                     input.parse::<kw::tracing>()?;
529                     flags |= FunctionFlags::TRACING;
530                 } else if l.peek(kw::verbose_tracing) {
531                     input.parse::<kw::verbose_tracing>()?;
532                     flags |= FunctionFlags::VERBOSE_TRACING;
533                 } else if l.peek(kw::store) {
534                     input.parse::<kw::store>()?;
535                     flags |= FunctionFlags::STORE;
536                 } else if l.peek(kw::trappable) {
537                     input.parse::<kw::trappable>()?;
538                     flags |= FunctionFlags::TRAPPABLE;
539                 } else if l.peek(kw::ignore_wit) {
540                     input.parse::<kw::ignore_wit>()?;
541                     flags |= FunctionFlags::IGNORE_WIT;
542                 } else if l.peek(kw::exact) {
543                     input.parse::<kw::exact>()?;
544                     flags |= FunctionFlags::EXACT;
545                 } else {
546                     return Err(l.error());
547                 }
548 
549                 if input.peek(Token![|]) {
550                     input.parse::<Token![|]>()?;
551                 } else {
552                     break;
553                 }
554             }
555 
556             Ok(FunctionConfigSyntax { filter, flags })
557         }
558     }
559 }
560