diff options
Diffstat (limited to 'derive')
-rw-r--r-- | derive/src/lib.rs | 96 |
1 files changed, 61 insertions, 35 deletions
diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 27af31b..0ccc8ca 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -187,6 +187,7 @@ struct MtArgs { zlib: bool, zstd: bool, typename: Option<syn::Ident>, // remote derive + bounds: Option<syn::WhereClause>, } type Fields<'a> = Vec<(TokStr, &'a syn::Field)>; @@ -438,19 +439,46 @@ fn iter_variants(e: &syn::DataEnum, args: &MtArgs, mut f: impl FnMut(&syn::Varia } } +fn make_impl( + traitname: TokStr, + input: &syn::DeriveInput, + typename: &syn::Ident, + args: &MtArgs, + code: TokStr, +) -> TokenStream { + let generics = &input.generics; + let bounds = args.bounds.clone().or_else(|| { + if generics.params.is_empty() { + None + } else { + Some( + syn::parse( + generics + .params + .iter() + .rfold(quote! { where }, |before, t| match t { + syn::GenericParam::Type(x) => quote! { #before #x: #traitname, }, + _ => before, + }) + .into(), + ) + .expect("invalid where clause"), + ) + } + }); + + quote! { + #[automatically_derived] + impl #generics #traitname for #typename #generics #bounds { #code } + } + .into() +} + #[proc_macro_derive(MtSerialize, attributes(mt))] pub fn derive_serialize(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as syn::DeriveInput); let args = MtArgs::from_derive_input(&input).unwrap(); let typename = args.typename.as_ref().unwrap_or(&input.ident); - let generics = &input.generics; - let mut generics_bounded = generics.clone(); - for t in generics_bounded.params.iter_mut() { - match t { - syn::GenericParam::Type(x) => *t = parse_quote! { #x: mt_ser::MtSerialize }, - _ => {} - } - } let mut code = match &input.data { syn::Data::Enum(e) => { @@ -489,16 +517,19 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream { serialize_args(&args, &mut code); - quote! { - #[automatically_derived] - impl #generics_bounded mt_ser::MtSerialize for #typename #generics { - fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> { - #code - - Ok(()) - } - } - }.into() + make_impl( + quote! { mt_ser::MtSerialize }, + &input, + typename, + &args, + quote! { + fn mt_serialize<C: mt_ser::MtCfg>(&self, __writer: &mut impl std::io::Write) -> Result<(), mt_ser::SerializeError> { + #code + + Ok(()) + } + }, + ) } #[proc_macro_derive(MtDeserialize, attributes(mt))] @@ -506,14 +537,6 @@ pub fn derive_deserialize(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as syn::DeriveInput); let args = MtArgs::from_derive_input(&input).unwrap(); let typename = args.typename.as_ref().unwrap_or(&input.ident); - let generics = &input.generics; - let mut generics_bounded = generics.clone(); - for t in generics_bounded.params.iter_mut() { - match t { - syn::GenericParam::Type(x) => *t = parse_quote! { #x: mt_ser::MtDeserialize }, - _ => {} - } - } let mut code = match &input.data { syn::Data::Enum(e) => { @@ -582,13 +605,16 @@ pub fn derive_deserialize(input: TokenStream) -> TokenStream { deserialize_args(&args, &mut code); - quote! { - #[automatically_derived] - impl #generics_bounded mt_ser::MtDeserialize for #typename #generics { - #[allow(non_upper_case_globals)] - fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> { - #code - } - } - }.into() + make_impl( + quote! { mt_ser::MtDeserialize }, + &input, + typename, + &args, + quote! { + #[allow(non_upper_case_globals)] + fn mt_deserialize<C: mt_ser::MtCfg>(__reader: &mut impl std::io::Read) -> Result<Self, mt_ser::DeserializeError> { + #code + } + }, + ) } |