aboutsummaryrefslogtreecommitdiff
path: root/derive-macro/src/handle_enum.rs
blob: 8165ab5cdc8ae694d1afb8cf2792971adb05eb85 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
use super::generate_fields;
use proc_macro2::{Ident, Literal, TokenStream};
use quote::quote;
use syn::{DataEnum, Variant};

/// Returns the relative sampling weight of `variant`.
///
/// The weight comes from the first `#[weight(...)]` attribute on the variant;
/// any later `#[weight(...)]` attributes are not looked at. A variant with no
/// such attribute gets the default weight `1u64`.
///
/// # Panics
///
/// Panics if the first `#[weight(...)]` attribute's argument does not parse as
/// a single literal. Inside a derive macro, such a panic is reported by the
/// compiler as an error at the derive site.
fn variant_weight(variant: &Variant) -> Literal {
    let weight_attr = variant
        .attrs
        .iter()
        .find(|attr| attr.path.is_ident("weight"));

    match weight_attr {
        Some(attr) => attr
            .parse_args::<Literal>()
            .expect("expected literal for `#[weight(...)]`"),
        None => Literal::u64_suffixed(1),
    }
}

/// Generates a `generate_random::GenerateRandom` impl for the enum `name`.
///
/// In declaration order, each variant owns a contiguous bucket of the integer
/// range `0..total_weight`. The bucket size is the variant's weight, taken from
/// `variant_weight`. The generated `generate_random` draws one integer with
/// `rng.gen_range(0..total_weight)` and returns the variant whose bucket
/// contains it. That variant's fields are produced by `generate_fields`.
///
/// For an enum with no variants, `total_weight` is `0`, so the generated body
/// calls `rng.gen_range(0..0)`; `rand` documents `gen_range` as panicking on an
/// empty range.
pub fn generate(name: &Ident, ty: DataEnum) -> TokenStream {
    let mut arms = TokenStream::new();
    let mut weights = Vec::with_capacity(ty.variants.len());

    for variant in ty.variants {
        // Read the weight before `ident`/`fields` are moved out of `variant`.
        let weight = variant_weight(&variant);
        let variant_name = variant.ident;
        let fields = generate_fields(variant.fields);

        // Generated `end` is the exclusive upper bound of this variant's bucket.
        // Each arm shadows the previous arm's `end` with a running sum. The
        // previous arm has already returned whenever `value < end` held for it,
        // so at this point `value` is at least the previous bound. Only the upper
        // bound needs checking; the old lower-bound tests were always true.
        let bound = if weights.is_empty() {
            quote! { #weight }
        } else {
            quote! { end + #weight }
        };
        arms.extend(quote! {
            let end = #bound;
            if value < end {
                return Self::#variant_name #fields;
            }
        });
        weights.push(weight);
    }

    // Sum of all weights, spelled out as `w0 + w1 + ...`; `0` when there are
    // no variants (matching the previous behavior for empty enums).
    let total_weight = if weights.is_empty() {
        quote! { 0 }
    } else {
        quote! { #(#weights)+* }
    };

    quote! {
        impl generate_random::GenerateRandom for #name {
            fn generate_random<R: rand::Rng + ?Sized>(rng: &mut R) -> Self {
                let total_weight = #total_weight;
                let value = rng.gen_range(0..total_weight);
                #arms
                unreachable!()
            }
        }
    }
}