summaryrefslogtreecommitdiff
path: root/derive/src
diff options
context:
space:
mode:
Diffstat (limited to 'derive/src')
-rw-r--r--derive/src/lib.rs170
1 files changed, 170 insertions, 0 deletions
diff --git a/derive/src/lib.rs b/derive/src/lib.rs
new file mode 100644
index 0000000..8daabf0
--- /dev/null
+++ b/derive/src/lib.rs
@@ -0,0 +1,170 @@
+use darling::FromMeta;
+use proc_macro::{self, TokenStream};
+use quote::{quote, ToTokens};
+use syn::{parse_macro_input, parse_quote};
+
+#[derive(Debug, FromMeta, Copy, Clone, Eq, PartialEq)]
+#[darling(rename_all = "snake_case")]
+enum To {
+ Clt,
+ Srv,
+}
+
+#[derive(Debug, FromMeta)]
+struct MacroArgs {
+ to: To,
+ repr: Option<syn::Type>,
+ tag: Option<String>,
+ content: Option<String>,
+ #[darling(default)]
+ enumset: bool,
+}
+
+fn wrap_attr(attr: &mut syn::Attribute) {
+ match attr.path.get_ident().map(|i| i.to_string()).as_deref() {
+ Some("mt") => {
+ let path = attr.path.clone();
+ let tokens = attr.tokens.clone();
+
+ *attr = parse_quote! {
+ #[cfg_attr(any(feature = "client", feature = "server"), #path #tokens)]
+ };
+ }
+ Some("serde") => {
+ let path = attr.path.clone();
+ let tokens = attr.tokens.clone();
+
+ *attr = parse_quote! {
+ #[cfg_attr(feature = "serde", #path #tokens)]
+ };
+ }
+ _ => {}
+ }
+}
+
+#[proc_macro_attribute]
+pub fn mt_derive(attr: TokenStream, item: TokenStream) -> TokenStream {
+ let item2 = item.clone();
+
+ let attr_args = parse_macro_input!(attr as syn::AttributeArgs);
+ let mut input = parse_macro_input!(item2 as syn::Item);
+
+ let args = match MacroArgs::from_list(&attr_args) {
+ Ok(v) => v,
+ Err(e) => {
+ return TokenStream::from(e.write_errors());
+ }
+ };
+
+ let (serializer, deserializer) = match args.to {
+ To::Clt => ("server", "client"),
+ To::Srv => ("client", "server"),
+ };
+
+ let mut out = quote! {
+ #[derive(Debug)]
+ #[cfg_attr(feature = "random", derive(GenerateRandom))]
+ #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+ };
+
+ macro_rules! iter {
+ ($t:expr, $f:expr) => {
+ $t.iter_mut().for_each($f)
+ };
+ }
+
+ match &mut input {
+ syn::Item::Enum(e) => {
+ iter!(e.attrs, wrap_attr);
+ iter!(e.variants, |v| {
+ iter!(v.attrs, wrap_attr);
+ iter!(v.fields, |f| iter!(f.attrs, wrap_attr));
+ });
+
+ let repr = args.repr.expect("missing repr for enum");
+
+ if args.enumset {
+ let repr_str = repr.to_token_stream().to_string();
+
+ out.extend(quote! {
+ #[derive(EnumSetType)]
+ #[enumset(repr = #repr_str, serialize_as_map)]
+ })
+ } else {
+ let has_payload = e
+ .variants
+ .iter()
+ .find_map(|v| if v.fields.is_empty() { None } else { Some(()) })
+ .is_some();
+
+ if has_payload {
+ let tag = args.tag.expect("missing tag for enum with payload");
+
+ out.extend(quote! {
+ #[derive(Clone)]
+ #[cfg_attr(feature = "serde", serde(tag = #tag))]
+ });
+
+ if let Some(content) = args.content {
+ out.extend(quote! {
+ #[cfg_attr(feature = "serde", serde(content = #content))]
+ });
+ }
+ } else {
+ out.extend(quote! {
+ #[derive(Copy, Clone, PartialEq, Eq)]
+ });
+ }
+
+ out.extend(quote! {
+ #[repr(#repr)]
+ #[cfg_attr(feature = #serializer, derive(MtSerialize))]
+ #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
+ });
+ }
+
+ out.extend(quote! {
+ #[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
+ });
+ }
+ syn::Item::Struct(s) => {
+ iter!(s.attrs, wrap_attr);
+ iter!(s.fields, |f| iter!(f.attrs, wrap_attr));
+
+ out.extend(quote! {
+ #[derive(Clone)]
+ #[cfg_attr(feature = #serializer, derive(MtSerialize))]
+ #[cfg_attr(feature = #deserializer, derive(MtDeserialize))]
+ });
+ }
+ _ => panic!("only enum and struct supported"),
+ }
+
+ out.extend(input.to_token_stream());
+ out.into()
+}
+
+#[proc_macro_derive(MtSerialize, attributes(mt))]
+pub fn derive_serialize(input: TokenStream) -> TokenStream {
+ let syn::DeriveInput { ident, .. } = parse_macro_input!(input);
+ let output = quote! {
+ impl MtSerialize for #ident {
+ fn mt_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<(), mt_data::SerializeError> {
+ Err(mt_data::SerializeError::Unimplemented)
+ }
+ }
+ };
+ output.into()
+}
+
+#[proc_macro_derive(MtDeserialize, attributes(mt))]
+pub fn derive_deserialize(input: TokenStream) -> TokenStream {
+ let syn::DeriveInput { ident, .. } = parse_macro_input!(input);
+ quote! {
+ impl MtDeserialize for #ident {
+ fn mt_deserialize<R: std::io::Read>(reader: &mut R) -> Result<Self, mt_data::DeserializeError> {
+ Err(mt_data::DeserializeError::Unimplemented)
+ }
+ }
+ }.into()
+}