1fbf8fb32SBenno Lossin // SPDX-License-Identifier: GPL-2.0
2fbf8fb32SBenno Lossin
3*7cb5dee4SBenno Lossin #[cfg(not(kernel))]
4*7cb5dee4SBenno Lossin use proc_macro2 as proc_macro;
5*7cb5dee4SBenno Lossin
6fbf8fb32SBenno Lossin use crate::helpers::{parse_generics, Generics};
7fbf8fb32SBenno Lossin use proc_macro::{TokenStream, TokenTree};
8fbf8fb32SBenno Lossin
derive(input: TokenStream) -> TokenStream9fbf8fb32SBenno Lossin pub(crate) fn derive(input: TokenStream) -> TokenStream {
10fbf8fb32SBenno Lossin let (
11fbf8fb32SBenno Lossin Generics {
12fbf8fb32SBenno Lossin impl_generics,
13fbf8fb32SBenno Lossin decl_generics: _,
14fbf8fb32SBenno Lossin ty_generics,
15fbf8fb32SBenno Lossin },
16fbf8fb32SBenno Lossin mut rest,
17fbf8fb32SBenno Lossin ) = parse_generics(input);
18fbf8fb32SBenno Lossin // This should be the body of the struct `{...}`.
19fbf8fb32SBenno Lossin let last = rest.pop();
20fbf8fb32SBenno Lossin // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
21fbf8fb32SBenno Lossin let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
22fbf8fb32SBenno Lossin // Are we inside of a generic where we want to add `Zeroable`?
23fbf8fb32SBenno Lossin let mut in_generic = !impl_generics.is_empty();
24fbf8fb32SBenno Lossin // Have we already inserted `Zeroable`?
25fbf8fb32SBenno Lossin let mut inserted = false;
26fbf8fb32SBenno Lossin // Level of `<>` nestings.
27fbf8fb32SBenno Lossin let mut nested = 0;
28fbf8fb32SBenno Lossin for tt in impl_generics {
29fbf8fb32SBenno Lossin match &tt {
30fbf8fb32SBenno Lossin // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
31fbf8fb32SBenno Lossin TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
32fbf8fb32SBenno Lossin if in_generic && !inserted {
33dbd5058bSBenno Lossin new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
34fbf8fb32SBenno Lossin }
35fbf8fb32SBenno Lossin in_generic = true;
36fbf8fb32SBenno Lossin inserted = false;
37fbf8fb32SBenno Lossin new_impl_generics.push(tt);
38fbf8fb32SBenno Lossin }
39fbf8fb32SBenno Lossin // If we find `'`, then we are entering a lifetime.
40fbf8fb32SBenno Lossin TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
41fbf8fb32SBenno Lossin in_generic = false;
42fbf8fb32SBenno Lossin new_impl_generics.push(tt);
43fbf8fb32SBenno Lossin }
44fbf8fb32SBenno Lossin TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
45fbf8fb32SBenno Lossin new_impl_generics.push(tt);
46fbf8fb32SBenno Lossin if in_generic {
47dbd5058bSBenno Lossin new_impl_generics.extend(quote! { ::pin_init::Zeroable + });
48fbf8fb32SBenno Lossin inserted = true;
49fbf8fb32SBenno Lossin }
50fbf8fb32SBenno Lossin }
51fbf8fb32SBenno Lossin TokenTree::Punct(p) if p.as_char() == '<' => {
52fbf8fb32SBenno Lossin nested += 1;
53fbf8fb32SBenno Lossin new_impl_generics.push(tt);
54fbf8fb32SBenno Lossin }
55fbf8fb32SBenno Lossin TokenTree::Punct(p) if p.as_char() == '>' => {
56fbf8fb32SBenno Lossin assert!(nested > 0);
57fbf8fb32SBenno Lossin nested -= 1;
58fbf8fb32SBenno Lossin new_impl_generics.push(tt);
59fbf8fb32SBenno Lossin }
60fbf8fb32SBenno Lossin _ => new_impl_generics.push(tt),
61fbf8fb32SBenno Lossin }
62fbf8fb32SBenno Lossin }
63fbf8fb32SBenno Lossin assert_eq!(nested, 0);
64fbf8fb32SBenno Lossin if in_generic && !inserted {
65dbd5058bSBenno Lossin new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
66fbf8fb32SBenno Lossin }
67fbf8fb32SBenno Lossin quote! {
68dbd5058bSBenno Lossin ::pin_init::__derive_zeroable!(
69fbf8fb32SBenno Lossin parse_input:
70fbf8fb32SBenno Lossin @sig(#(#rest)*),
71fbf8fb32SBenno Lossin @impl_generics(#(#new_impl_generics)*),
72fbf8fb32SBenno Lossin @ty_generics(#(#ty_generics)*),
73fbf8fb32SBenno Lossin @body(#last),
74fbf8fb32SBenno Lossin );
75fbf8fb32SBenno Lossin }
76fbf8fb32SBenno Lossin }
77