use core::convert::TryFrom;

use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{Fields, Ident, ItemEnum, WhereClause};

use crate::attribute_helpers::{contains_initialize_with, contains_skip};

pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream2> {
    let name = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Start from the enum's own where clause (or an empty one) so that a
    // `BorshDeserialize` bound can be appended for every deserialized field type.
    let mut where_clause = where_clause.map_or_else(
        || WhereClause {
            where_token: Default::default(),
            predicates: Default::default(),
        },
        Clone::clone,
    );
    let init_method = contains_initialize_with(&input.attrs)?;
    let mut variant_arms = TokenStream2::new();
    // One match arm per variant; the variant's declaration position is its tag byte.
    for (variant_idx, variant) in input.variants.iter().enumerate() {
        let variant_idx = u8::try_from(variant_idx).expect("up to 256 enum variants are supported");
        let variant_ident = &variant.ident;
        let mut variant_header = TokenStream2::new();
        match &variant.fields {
            Fields::Named(fields) => {
                for field in &fields.named {
                    let field_name = field.ident.as_ref().unwrap();
                    if contains_skip(&field.attrs) {
                        // Skipped fields are not read from the buffer.
                        variant_header.extend(quote! {
                            #field_name: Default::default(),
                        });
                    } else {
                        let field_type = &field.ty;
                        where_clause.predicates.push(
                            syn::parse2(quote! {
                                #field_type: #cratename::BorshDeserialize
                            })
                            .unwrap(),
                        );

                        variant_header.extend(quote! {
                            #field_name: #cratename::BorshDeserialize::deserialize(buf)?,
                        });
                    }
                }
                variant_header = quote! { { #variant_header }};
            }
            Fields::Unnamed(fields) => {
                for field in fields.unnamed.iter() {
                    if contains_skip(&field.attrs) {
                        // Skipped fields are not read from the buffer.
                        variant_header.extend(quote! { Default::default(), });
                    } else {
                        let field_type = &field.ty;
                        where_clause.predicates.push(
                            syn::parse2(quote! {
                                #field_type: #cratename::BorshDeserialize
                            })
                            .unwrap(),
                        );

                        variant_header
                            .extend(quote! { #cratename::BorshDeserialize::deserialize(buf)?, });
                    }
                }
                variant_header = quote! { ( #variant_header )};
            }
            Fields::Unit => {}
        }
        variant_arms.extend(quote! {
            #variant_idx => #name::#variant_ident #variant_header ,
        });
    }
    // The generated impl first reads the one-byte variant tag.
    let variant_idx = quote! {
        let variant_idx: u8 = #cratename::BorshDeserialize::deserialize(buf)?;
    };
    if let Some(method_ident) = init_method {
        // An initialization method was requested on the enum (see `contains_initialize_with`):
        // call it on the decoded value before returning.
        Ok(quote! {
            impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause {
                fn deserialize(buf: &mut &[u8]) -> core::result::Result<Self, #cratename::maybestd::io::Error> {
                    #variant_idx
                    let mut return_value = match variant_idx {
                        #variant_arms
                        _ => {
                            let msg = #cratename::maybestd::format!("Unexpected variant index: {:?}", variant_idx);

                            return Err(#cratename::maybestd::io::Error::new(
                                #cratename::maybestd::io::ErrorKind::InvalidInput,
                                msg,
                            ));
                        }
                    };
                    return_value.#method_ident();
                    Ok(return_value)
                }
            }
        })
    } else {
        Ok(quote! {
            impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause {
                fn deserialize(buf: &mut &[u8]) -> core::result::Result<Self, #cratename::maybestd::io::Error> {
                    #variant_idx
                    let return_value = match variant_idx {
                        #variant_arms
                        _ => {
                            let msg = #cratename::maybestd::format!("Unexpected variant index: {:?}", variant_idx);

                            return Err(#cratename::maybestd::io::Error::new(
                                #cratename::maybestd::io::ErrorKind::InvalidInput,
                                msg,
                            ));
                        }
                    };
                    Ok(return_value)
                }
            }
        })
    }
}
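
// Illustrative sketch, not part of the macro: a hand-expanded approximation of the
// kind of impl body `enum_de` generates, written against `std` only so it compiles
// without the borsh crate. `Shape`, `read_u8`, `read_u32`, and `decode_shape` are
// hypothetical names. The real generated code calls
// `BorshDeserialize::deserialize(buf)` for the tag and for each field instead of
// these helpers. Multi-byte integers are assumed to be little-endian, as in borsh.
// Assumed layout: one u8 variant tag (the variant's declaration position), then the
// non-skipped fields in declaration order.
#[cfg(test)]
mod expanded_enum_de_sketch {
    use std::io::{Error, ErrorKind};

    #[derive(Debug, PartialEq)]
    pub enum Shape {
        // Tag 0: unit variant, nothing follows the tag byte.
        Empty,
        // Tag 1: tuple variant; the u8 stands in for a skipped field and is
        // filled with `Default::default()` rather than read from the buffer.
        Square(u32, u8),
        // Tag 2: named-field variant; fields are read in declaration order.
        Rect { w: u32, h: u32 },
    }

    // Reads one byte and advances the slice past it.
    fn read_u8(buf: &mut &[u8]) -> Result<u8, Error> {
        let slice: &[u8] = *buf;
        let (&first, rest) = slice
            .split_first()
            .ok_or_else(|| Error::new(ErrorKind::UnexpectedEof, "unexpected end of input"))?;
        *buf = rest;
        Ok(first)
    }

    // Reads a little-endian u32 and advances the slice past it.
    fn read_u32(buf: &mut &[u8]) -> Result<u32, Error> {
        let slice: &[u8] = *buf;
        if slice.len() < 4 {
            return Err(Error::new(ErrorKind::UnexpectedEof, "unexpected end of input"));
        }
        let (head, rest) = slice.split_at(4);
        *buf = rest;
        Ok(u32::from_le_bytes([head[0], head[1], head[2], head[3]]))
    }

    // Mirrors the `match variant_idx { ... }` that `enum_de` emits, including the
    // catch-all arm that rejects unknown tags with `InvalidInput`.
    pub fn decode_shape(buf: &mut &[u8]) -> Result<Shape, Error> {
        let variant_idx = read_u8(buf)?;
        let return_value = match variant_idx {
            0u8 => Shape::Empty,
            1u8 => Shape::Square(read_u32(buf)?, Default::default()),
            2u8 => Shape::Rect { w: read_u32(buf)?, h: read_u32(buf)? },
            _ => {
                let msg = format!("Unexpected variant index: {:?}", variant_idx);
                return Err(Error::new(ErrorKind::InvalidInput, msg));
            }
        };
        Ok(return_value)
    }

    #[test]
    fn decodes_each_variant_kind() {
        let mut buf: &[u8] = &[0];
        assert_eq!(decode_shape(&mut buf).unwrap(), Shape::Empty);
        assert!(buf.is_empty());

        // The skipped field is defaulted to 0, not read.
        let mut buf: &[u8] = &[1, 7, 0, 0, 0];
        assert_eq!(decode_shape(&mut buf).unwrap(), Shape::Square(7, 0));

        let mut buf: &[u8] = &[2, 2, 0, 0, 0, 3, 0, 0, 0];
        assert_eq!(decode_shape(&mut buf).unwrap(), Shape::Rect { w: 2, h: 3 });
        assert!(buf.is_empty());

        // Tag 3 has no arm, so the catch-all returns InvalidInput.
        let mut buf: &[u8] = &[3];
        assert_eq!(decode_shape(&mut buf).unwrap_err().kind(), ErrorKind::InvalidInput);
    }
}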