From 645d3f8d22136d58d65f034fc15ea748d751eb96 Mon Sep 17 00:00:00 2001 From: Alexander van Ratingen <470642+alvra@users.noreply.github.com> Date: Thu, 21 Apr 2022 22:16:02 +0200 Subject: Initial commit --- derive-macro/src/handle_enum.rs | 55 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 derive-macro/src/handle_enum.rs (limited to 'derive-macro/src/handle_enum.rs') diff --git a/derive-macro/src/handle_enum.rs b/derive-macro/src/handle_enum.rs new file mode 100644 index 0000000..b2b14c0 --- /dev/null +++ b/derive-macro/src/handle_enum.rs @@ -0,0 +1,55 @@ +use proc_macro2::{Ident, Literal, TokenStream}; +use quote::quote; +use syn::{DataEnum, Variant}; +use super::generate_fields; + +fn variant_weight(variant: &Variant) -> Literal { + for attr in variant.attrs.iter() { + if attr.path.is_ident("weight") { + return attr.parse_args::().expect("expected literal for `#[weight(...)]`") + } + } + Literal::u64_suffixed(1) +} + +pub fn generate(name: &Ident, ty: DataEnum) -> TokenStream { + let mut variant_weights = ty.variants.into_iter() + .map(|variant| (variant_weight(&variant), variant)); + + let mut arms = TokenStream::new(); + let mut total_weight = quote! { 0 }; + if let Some((weight, variant)) = variant_weights.next() { + let variant_name = variant.ident; + let fields = generate_fields(variant.fields); + arms.extend(quote! { + let end = #weight; + if 0 <= value && value < end { + return Self::#variant_name #fields + } + }); + total_weight = quote! { #weight }; + for (weight, variant) in variant_weights { + let variant_name = variant.ident; + let fields = generate_fields(variant.fields); + arms.extend(quote! { + let start = end; + let end = start + #weight; + if start <= value && value < end { + return Self::#variant_name #fields + } + }); + total_weight = quote! { #total_weight + #weight }; + } + } + + quote! { + impl generate_random::GenerateRandom for #name { + fn generate_random(rng: &mut R) -> Self { + let total_weight = #total_weight; + let value = rng.gen_range(0..total_weight); + #arms + unreachable!() + } + } + } +} -- cgit v1.2.3