diff options
author | Alexander van Ratingen <470642+alvra@users.noreply.github.com> | 2022-04-21 22:16:02 +0200 |
---|---|---|
committer | Alexander van Ratingen <470642+alvra@users.noreply.github.com> | 2022-04-21 22:16:02 +0200 |
commit | 645d3f8d22136d58d65f034fc15ea748d751eb96 (patch) | |
tree | bf55549077f17e9167194f6cb21fed41ec30ba03 /derive-macro/src | |
download | generate-random-645d3f8d22136d58d65f034fc15ea748d751eb96.tar.xz |
Initial commit
Diffstat (limited to 'derive-macro/src')
-rw-r--r-- | derive-macro/src/handle_enum.rs | 55 | ||||
-rw-r--r-- | derive-macro/src/handle_struct.rs | 15 | ||||
-rw-r--r-- | derive-macro/src/lib.rs | 46 |
3 files changed, 116 insertions, 0 deletions
diff --git a/derive-macro/src/handle_enum.rs b/derive-macro/src/handle_enum.rs new file mode 100644 index 0000000..b2b14c0 --- /dev/null +++ b/derive-macro/src/handle_enum.rs @@ -0,0 +1,55 @@ +use proc_macro2::{Ident, Literal, TokenStream}; +use quote::quote; +use syn::{DataEnum, Variant}; +use super::generate_fields; + +fn variant_weight(variant: &Variant) -> Literal { + for attr in variant.attrs.iter() { + if attr.path.is_ident("weight") { + return attr.parse_args::<Literal>().expect("expected literal for `#[weight(...)]`") + } + } + Literal::u64_suffixed(1) +} + +pub fn generate(name: &Ident, ty: DataEnum) -> TokenStream { + let mut variant_weights = ty.variants.into_iter() + .map(|variant| (variant_weight(&variant), variant)); + + let mut arms = TokenStream::new(); + let mut total_weight = quote! { 0 }; + if let Some((weight, variant)) = variant_weights.next() { + let variant_name = variant.ident; + let fields = generate_fields(variant.fields); + arms.extend(quote! { + let end = #weight; + if 0 <= value && value < end { + return Self::#variant_name #fields + } + }); + total_weight = quote! { #weight }; + for (weight, variant) in variant_weights { + let variant_name = variant.ident; + let fields = generate_fields(variant.fields); + arms.extend(quote! { + let start = end; + let end = start + #weight; + if start <= value && value < end { + return Self::#variant_name #fields + } + }); + total_weight = quote! { #total_weight + #weight }; + } + } + + quote! { + impl generate_random::GenerateRandom for #name { + fn generate_random<R: rand::Rng + ?Sized>(rng: &mut R) -> Self { + let total_weight = #total_weight; + let value = rng.gen_range(0..total_weight); + #arms + unreachable!() + } + } + } +} diff --git a/derive-macro/src/handle_struct.rs b/derive-macro/src/handle_struct.rs new file mode 100644 index 0000000..988bd06 --- /dev/null +++ b/derive-macro/src/handle_struct.rs @@ -0,0 +1,15 @@ +use proc_macro2::{Ident, TokenStream}; +use quote::quote; +use syn::DataStruct; +use super::generate_fields; + +pub fn generate(name: &Ident, ty: DataStruct) -> TokenStream { + let fields = generate_fields(ty.fields); + quote! { + impl generate_random::GenerateRandom for #name { + fn generate_random<R: rand::Rng + ?Sized>(rng: &mut R) -> Self { + Self #fields + } + } + } +} diff --git a/derive-macro/src/lib.rs b/derive-macro/src/lib.rs new file mode 100644 index 0000000..9aa93f5 --- /dev/null +++ b/derive-macro/src/lib.rs @@ -0,0 +1,46 @@ +//! This crate provide the [`GenerateRandom`] derive macro +//! that implements the trait of the same name from the `generate-random` crate. +//! Refer to the documentation of that crate for more information. + +use syn::{DeriveInput, Data, Fields}; + +mod handle_struct; +mod handle_enum; + +#[proc_macro_derive(GenerateRandom, attributes(weight))] +pub fn derive_generate_random(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input: DeriveInput = syn::parse(input).unwrap(); + match input.data { + Data::Struct(ty) => handle_struct::generate(&input.ident, ty), + Data::Enum(ty) => handle_enum::generate(&input.ident, ty), + Data::Union(_) => panic!("Unions are not supported"), + }.into() +} + +fn generate_fields(fields: Fields) -> proc_macro2::TokenStream { + use quote::quote; + match fields { + Fields::Named(fields) => { + let fields = fields.named.into_iter() + .map(|field| { + let field = field.ident.unwrap(); + quote! { + #field: generate_random::GenerateRandom::generate_random(rng), + } + }) + .collect::<proc_macro2::TokenStream>(); + quote! { { #fields } } + } + Fields::Unnamed(fields) => { + let fields = fields.unnamed.into_iter() + .map(|_field| { + quote! { + generate_random::GenerateRandom::generate_random(rng), + } + }) + .collect::<proc_macro2::TokenStream>(); + quote! { ( #fields ) } + } + Fields::Unit => quote! {}, + } +} |