diff options
Diffstat (limited to 'derive/src')
-rw-r--r-- | derive/src/lib.rs | 213 |
1 files changed, 125 insertions, 88 deletions
diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 82d9644..ae9842c 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -1,3 +1,4 @@ +use convert_case::{Case, Casing}; use darling::{FromDeriveInput, FromField, FromMeta, FromVariant}; use proc_macro::TokenStream; use proc_macro2::TokenStream as TokStr; @@ -117,14 +118,6 @@ pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream { }); } - if let Some(repr) = args.repr { - out.extend(quote! { - #[repr(#repr)] - }); - } else if !args.custom { - panic!("missing repr for enum"); - } - out.extend(quote! { #[derive(Clone, PartialEq)] }); @@ -135,6 +128,20 @@ pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream { #[cfg_attr(feature = #deserializer, derive(MtDeserialize))] }); } + + if let Some(repr) = args.repr { + if repr == parse_quote! { str } { + out.extend(quote! { + #[mt(string_repr)] + }); + } else { + out.extend(quote! { + #[repr(#repr)] + }); + } + } else if !args.custom { + panic!("missing repr for enum"); + } } out.extend(quote! { @@ -180,6 +187,8 @@ struct MtArgs { zlib: bool, zstd: bool, // TODO default: bool, // type must implement Default + + string_repr: bool, // for enums } type Fields<'a> = Vec<(TokStr, &'a syn::Field)>; @@ -392,14 +401,35 @@ fn get_fields_struct(input: &syn::Fields) -> (Fields, TokStr) { (fields, fields_struct) } -fn get_repr(input: &syn::DeriveInput) -> syn::Type { - input - .attrs - .iter() - .find(|a| a.path.is_ident("repr")) - .expect("missing repr") - .parse_args() - .expect("invalid repr") +fn get_repr(input: &syn::DeriveInput, args: &MtArgs) -> syn::Type { + if args.string_repr { + parse_quote! { &str } + } else { + input + .attrs + .iter() + .find(|a| a.path.is_ident("repr")) + .expect("missing repr") + .parse_args() + .expect("invalid repr") + } +} + +fn iter_variants(e: &syn::DataEnum, args: &MtArgs, mut f: impl FnMut(&syn::Variant, &syn::Expr)) { + let mut discr = parse_quote! { 0 }; + + for v in e.variants.iter() { + discr = if args.string_repr { + let lit = v.ident.to_string().to_case(Case::Snake); + parse_quote! { #lit } + } else { + v.discriminant.clone().map(|x| x.1).unwrap_or(discr) + }; + + f(&v, &discr); + + discr = parse_quote! { 1 + #discr }; + } } #[proc_macro_derive(MtSerialize, attributes(mt))] @@ -407,40 +437,38 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as syn::DeriveInput); let typename = &input.ident; - let code = serialize_args(MtArgs::from_derive_input(&input), |_| match &input.data { - syn::Data::Enum(e) => { - let repr = get_repr(&input); - let variants: TokStr = e.variants - .iter() - .fold((parse_quote! { 0 }, TokStr::new()), |(discr, before), v| { - let discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr); - let (fields, fields_struct) = get_fields_struct(&v.fields); - - let code = serialize_args(MtArgs::from_variant(v), |_| - serialize_fields(&fields)); - let variant = &v.ident; - - ( - parse_quote! { 1 + #discr }, - quote! { - #before - #typename::#variant #fields_struct => { - mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?; - #code - } - } - ) - }).1; + let code = serialize_args(MtArgs::from_derive_input(&input), |args| { + match &input.data { + syn::Data::Enum(e) => { + let repr = get_repr(&input, &args); + let mut variants = TokStr::new(); + + iter_variants(&e, &args, |v, discr| { + let (fields, fields_struct) = get_fields_struct(&v.fields); + let code = + serialize_args(MtArgs::from_variant(v), |_| serialize_fields(&fields)); + let ident = &v.ident; + + variants.extend(quote! { + #typename::#ident #fields_struct => { + mt_ser::MtSerialize::mt_serialize::<mt_ser::DefCfg>(&((#discr) as #repr), __writer)?; + #code + } + }); + }); - quote! { - match self { - #variants + quote! { + match self { + #variants + } } } - } - syn::Data::Struct(s) => serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f })), - _ => { - panic!("only enum and struct supported"); + syn::Data::Struct(s) => { + serialize_fields(&get_fields(&s.fields, |f| quote! { &self.#f })) + } + _ => { + panic!("only enum and struct supported"); + } } }); @@ -461,60 +489,69 @@ pub fn derive_deserialize(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as syn::DeriveInput); let typename = &input.ident; - let code = deserialize_args(MtArgs::from_derive_input(&input), |_| match &input.data { - syn::Data::Enum(e) => { - let repr = get_repr(&input); - let type_str = typename.to_string(); + let code = deserialize_args(MtArgs::from_derive_input(&input), |args| { + match &input.data { + syn::Data::Enum(e) => { + let repr = get_repr(&input, &args); - let mut consts = TokStr::new(); - let mut arms = TokStr::new(); - let mut discr = parse_quote! { 0 }; + let mut consts = TokStr::new(); + let mut arms = TokStr::new(); - for v in e.variants.iter() { - discr = v.discriminant.clone().map(|x| x.1).unwrap_or(discr); + iter_variants(&e, &args, |v, discr| { + let ident = &v.ident; + let (fields, fields_struct) = get_fields_struct(&v.fields); + let code = deserialize_args(MtArgs::from_variant(v), |_| { + let fields_code = deserialize_fields(&fields); - let ident = &v.ident; - let (fields, fields_struct) = get_fields_struct(&v.fields); - let code = deserialize_args(MtArgs::from_variant(v), |_| { - let fields_code = deserialize_fields(&fields); - - quote! { - #fields_code - Ok(Self::#ident #fields_struct) - } - }); + quote! { + #fields_code + Ok(Self::#ident #fields_struct) + } + }); - consts.extend(quote! { - const #ident: #repr = #discr; - }); + consts.extend(quote! { + const #ident: #repr = #discr; + }); - arms.extend(quote! { - #ident => { #code } + arms.extend(quote! { + #ident => { #code } + }); }); - discr = parse_quote! { 1 + #discr }; - } + let type_str = typename.to_string(); + let discr_match = if args.string_repr { + quote! { + let __discr = String::mt_deserialize::<DefCfg>(__reader)?; + match __discr.as_str() + } + } else { + quote! { + let __discr = mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)?; + match __discr + } + }; - quote! { - #consts + quote! { + #consts - match mt_ser::MtDeserialize::mt_deserialize::<DefCfg>(__reader)? { - #arms - x => Err(mt_ser::DeserializeError::InvalidEnumVariant(#type_str, x as u64)) + #discr_match { + #arms + _ => Err(mt_ser::DeserializeError::InvalidEnum(#type_str, Box::new(__discr))) + } } } - } - syn::Data::Struct(s) => { - let (fields, fields_struct) = get_fields_struct(&s.fields); - let code = deserialize_fields(&fields); + syn::Data::Struct(s) => { + let (fields, fields_struct) = get_fields_struct(&s.fields); + let code = deserialize_fields(&fields); - quote! { - #code - Ok(Self #fields_struct) + quote! { + #code + Ok(Self #fields_struct) + } + } + _ => { + panic!("only enum and struct supported"); } - } - _ => { - panic!("only enum and struct supported"); } }); |