1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
use crate::add_like; use crate::mul_helpers::generics_and_exprs; use crate::utils::{AttrParams, MultiFieldData, RefType, State}; use proc_macro2::{Span, TokenStream}; use quote::quote; use std::collections::HashSet; use std::iter; use syn::{DeriveInput, Ident, Result}; pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> { let mut state = State::with_attr_params( input, trait_name, quote!(::core::ops), trait_name.to_lowercase(), AttrParams::struct_(vec!["forward"]), )?; if state.default_info.forward { return Ok(add_like::expand(input, trait_name)); } let scalar_ident = &Ident::new("__RhsT", Span::call_site()); state.add_trait_path_type_param(quote!(#scalar_ident)); let multi_field_data = state.enabled_fields_data(); let MultiFieldData { input_type, field_types, ty_generics, trait_path, trait_path_with_params, method_ident, .. } = multi_field_data.clone(); let tys = field_types.iter().collect::<HashSet<_>>(); let tys = tys.iter(); let scalar_iter = iter::repeat(scalar_ident); let trait_path_iter = iter::repeat(trait_path); let type_where_clauses = quote! { where #(#tys: #trait_path_iter<#scalar_iter, Output=#tys>),* }; let (generics, initializers) = generics_and_exprs( multi_field_data.clone(), scalar_ident, type_where_clauses, RefType::No, ); let body = multi_field_data.initializer(&initializers); let (impl_generics, _, where_clause) = generics.split_for_impl(); Ok(quote!( impl#impl_generics #trait_path_with_params for #input_type#ty_generics #where_clause { type Output = #input_type#ty_generics; #[inline] fn #method_ident(self, rhs: #scalar_ident) -> #input_type#ty_generics { #body } } )) }