strum_macros/macros/strings/
from_string.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields};
4
5use crate::helpers::{
6    non_enum_error, occurrence_error, HasStrumVariantProperties, HasTypeProperties,
7};
8
9pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
10    let name = &ast.ident;
11    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
12    let variants = match &ast.data {
13        Data::Enum(v) => &v.variants,
14        _ => return Err(non_enum_error()),
15    };
16
17    let type_properties = ast.get_type_properties()?;
18    let strum_module_path = type_properties.crate_module_path();
19
20    let mut default_kw = None;
21    let mut default =
22        quote! { ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) };
23
24    let mut phf_exact_match_arms = Vec::new();
25    let mut standard_match_arms = Vec::new();
26    for variant in variants {
27        let ident = &variant.ident;
28        let variant_properties = variant.get_variant_properties()?;
29
30        if variant_properties.disabled.is_some() {
31            continue;
32        }
33
34        if let Some(kw) = variant_properties.default {
35            if let Some(fst_kw) = default_kw {
36                return Err(occurrence_error(fst_kw, kw, "default"));
37            }
38
39            match &variant.fields {
40                Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {}
41                _ => {
42                    return Err(syn::Error::new_spanned(
43                        variant,
44                        "Default only works on newtype structs with a single String field",
45                    ))
46                }
47            }
48
49            default_kw = Some(kw);
50            default = quote! {
51                ::core::result::Result::Ok(#name::#ident(s.into()))
52            };
53            continue;
54        }
55
56        let params = match &variant.fields {
57            Fields::Unit => quote! {},
58            Fields::Unnamed(fields) => {
59                let defaults =
60                    ::core::iter::repeat(quote!(Default::default())).take(fields.unnamed.len());
61                quote! { (#(#defaults),*) }
62            }
63            Fields::Named(fields) => {
64                let fields = fields
65                    .named
66                    .iter()
67                    .map(|field| field.ident.as_ref().unwrap());
68                quote! { {#(#fields: Default::default()),*} }
69            }
70        };
71
72        let is_ascii_case_insensitive = variant_properties
73            .ascii_case_insensitive
74            .unwrap_or(type_properties.ascii_case_insensitive);
75
76        // If we don't have any custom variants, add the default serialized name.
77        for serialization in variant_properties.get_serializations(type_properties.case_style) {
78            if type_properties.use_phf {
79                phf_exact_match_arms.push(quote! { #serialization => #name::#ident #params, });
80
81                if is_ascii_case_insensitive {
82                    // Store the lowercase and UPPERCASE variants in the phf map to capture 
83                    let ser_string = serialization.value();
84
85                    let lower =
86                        syn::LitStr::new(&ser_string.to_ascii_lowercase(), serialization.span());
87                    let upper =
88                        syn::LitStr::new(&ser_string.to_ascii_uppercase(), serialization.span());
89                    phf_exact_match_arms.push(quote! { #lower => #name::#ident #params, });
90                    phf_exact_match_arms.push(quote! { #upper => #name::#ident #params, });
91                    standard_match_arms.push(quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, });
92                }
93            } else {
94                standard_match_arms.push(if !is_ascii_case_insensitive {
95                    quote! { #serialization => #name::#ident #params, }
96                } else {
97                    quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, }
98                });
99            }
100        }
101    }
102
103    let phf_body = if phf_exact_match_arms.is_empty() {
104        quote!()
105    } else {
106        quote! {
107            use #strum_module_path::_private_phf_reexport_for_macro_if_phf_feature as phf;
108            static PHF: phf::Map<&'static str, #name> = phf::phf_map! {
109                #(#phf_exact_match_arms)*
110            };
111            if let Some(value) = PHF.get(s).cloned() {
112                return ::core::result::Result::Ok(value);
113            }
114        }
115    };
116    let standard_match_body = if standard_match_arms.is_empty() {
117        default
118    } else {
119        quote! {
120            ::core::result::Result::Ok(match s {
121                #(#standard_match_arms)*
122                _ => return #default,
123            })
124        }
125    };
126
127    let from_str = quote! {
128        #[allow(clippy::use_self)]
129        impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause {
130            type Err = #strum_module_path::ParseError;
131            fn from_str(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::str::FromStr>::Err> {
132                #phf_body
133                #standard_match_body
134            }
135        }
136    };
137
138    let try_from_str = try_from_str(
139        name,
140        &impl_generics,
141        &ty_generics,
142        where_clause,
143        &strum_module_path,
144    );
145
146    Ok(quote! {
147        #from_str
148        #try_from_str
149    })
150}
151
152#[rustversion::before(1.34)]
153fn try_from_str(
154    _name: &proc_macro2::Ident,
155    _impl_generics: &syn::ImplGenerics,
156    _ty_generics: &syn::TypeGenerics,
157    _where_clause: Option<&syn::WhereClause>,
158    _strum_module_path: &syn::Path,
159) -> TokenStream {
160    Default::default()
161}
162
163#[rustversion::since(1.34)]
164fn try_from_str(
165    name: &proc_macro2::Ident,
166    impl_generics: &syn::ImplGenerics,
167    ty_generics: &syn::TypeGenerics,
168    where_clause: Option<&syn::WhereClause>,
169    strum_module_path: &syn::Path,
170) -> TokenStream {
171    quote! {
172        #[allow(clippy::use_self)]
173        impl #impl_generics ::core::convert::TryFrom<&str> for #name #ty_generics #where_clause {
174            type Error = #strum_module_path::ParseError;
175            fn try_from(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::convert::TryFrom<&str>>::Error> {
176                ::core::str::FromStr::from_str(s)
177            }
178        }
179    }
180}