diff options
Diffstat (limited to 'derive-macro/src')
-rw-r--r-- | derive-macro/src/handle_enum.rs | 66 |
1 files changed, 46 insertions, 20 deletions
diff --git a/derive-macro/src/handle_enum.rs b/derive-macro/src/handle_enum.rs index 8165ab5..ecda02b 100644 --- a/derive-macro/src/handle_enum.rs +++ b/derive-macro/src/handle_enum.rs @@ -15,35 +15,40 @@ fn variant_weight(variant: &Variant) -> Literal { } pub fn generate(name: &Ident, ty: DataEnum) -> TokenStream { - let mut variant_weights = ty + let variant_weights = ty .variants .into_iter() - .map(|variant| (variant_weight(&variant), variant)); + .enumerate() + .map(|(i, variant)| (i, variant_weight(&variant), variant)); let mut arms = TokenStream::new(); + let mut arms_variant = TokenStream::new(); + let mut arms_variant_name = TokenStream::new(); + let mut num_variants: usize = 0; + let mut total_weight = quote! { 0 }; - if let Some((weight, variant)) = variant_weights.next() { + for (index, weight, variant) in variant_weights { let variant_name = variant.ident; - let fields = generate_fields(variant.fields); arms.extend(quote! { - let end = #weight; - if 0 <= value && value < end { - return Self::#variant_name #fields + let start = end; + let end = start + #weight; + if start <= value && value < end { + return generate_random::GenerateRandomVariant::generate_random_variant(rng, #index); } }); - total_weight = quote! { #weight }; - for (weight, variant) in variant_weights { - let variant_name = variant.ident; - let fields = generate_fields(variant.fields); - arms.extend(quote! { - let start = end; - let end = start + #weight; - if start <= value && value < end { - return Self::#variant_name #fields - } - }); - total_weight = quote! { #total_weight + #weight }; - } + + let fields = generate_fields(variant.fields); + arms_variant.extend(quote! { + #index => Self::#variant_name #fields, + }); + + let variant_str = variant_name.to_string(); + arms_variant_name.extend(quote! { + #index => #variant_str, + }); + + total_weight = quote! { #total_weight + #weight }; + num_variants += 1; } quote! { @@ -51,9 +56,30 @@ pub fn generate(name: &Ident, ty: DataEnum) -> TokenStream { fn generate_random<R: rand::Rng + ?Sized>(rng: &mut R) -> Self { let total_weight = #total_weight; let value = rng.gen_range(0..total_weight); + let end = 0; #arms unreachable!() } } + + impl generate_random::GenerateRandomVariant for #name { + fn num_variants() -> usize { + #num_variants + } + + fn variant_name(variant: usize) -> &'static str { + match variant { + #arms_variant_name + _ => "", + } + } + + fn generate_random_variant<R: rand::Rng + ?Sized>(rng: &mut R, variant: usize) -> Self { + match variant { + #arms_variant + _ => generate_random::GenerateRandom::generate_random(rng), + } + } + } } } |