diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 7c5b0ca..e800b74 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -15,7 +15,7 @@ proc-macro = true proc-macro2 = "1.0" syn = { version = "1.0", features = ["full"] } quote = "1.0" -uuid = { version = "0.8.1", features = ["v4"] } +uuid = { version = "0.8", features = ["v4"] } [dev-dependencies] intertrait = { version = "=0.2.1", path = ".." } diff --git a/macros/src/gen_caster.rs b/macros/src/gen_caster.rs index f3fdc87..7ef02ff 100644 --- a/macros/src/gen_caster.rs +++ b/macros/src/gen_caster.rs @@ -1,7 +1,6 @@ use std::str::from_utf8_unchecked; use proc_macro2::TokenStream; -use syn::Path; use uuid::adapter::Simple; use uuid::Uuid; @@ -9,7 +8,7 @@ use quote::format_ident; use quote::quote; use quote::ToTokens; -pub fn generate_caster(ty: &impl ToTokens, trait_: &Path, sync: bool) -> TokenStream { +pub fn generate_caster(ty: &impl ToTokens, trait_: &impl ToTokens, sync: bool) -> TokenStream { let mut fn_buf = [0u8; FN_BUF_LEN]; let fn_ident = format_ident!("{}", new_fn_name(&mut fn_buf)); let new_caster = if sync { diff --git a/macros/src/item_impl.rs b/macros/src/item_impl.rs index f9fa52d..44c9166 100644 --- a/macros/src/item_impl.rs +++ b/macros/src/item_impl.rs @@ -1,20 +1,25 @@ +use std::collections::HashSet; + +use PathArguments::AngleBracketed; use proc_macro2::TokenStream; +use quote::{quote, quote_spanned, ToTokens}; +use syn::{AngleBracketedGenericArguments, Binding, GenericArgument, ImplItem, ItemImpl, Path, PathArguments}; +use syn::punctuated::Punctuated; use syn::spanned::Spanned; -use syn::ItemImpl; - -use quote::{quote, quote_spanned}; +use syn::Token; use crate::args::Flag; use crate::gen_caster::generate_caster; -use std::collections::HashSet; pub fn process(flags: &HashSet, input: ItemImpl) -> TokenStream { let ItemImpl { ref self_ty, ref trait_, + ref items, .. } = input; + let generated = match trait_ { None => quote_spanned! { self_ty.span() => compile_error!("#[cast_to] should only be on an impl of a trait"); @@ -23,7 +28,10 @@ pub fn process(flags: &HashSet, input: ItemImpl) -> TokenStream { (Some(bang), _, _) => quote_spanned! { bang.span() => compile_error!("#[cast_to] is not for !Trait impl"); }, - (None, path, _) => generate_caster(self_ty, path, flags.contains(&Flag::Sync)), + (None, path, _) => { + let path = fully_bound_trait(path, items); + generate_caster(self_ty, &path, flags.contains(&Flag::Sync)) + } }, }; @@ -32,3 +40,37 @@ pub fn process(flags: &HashSet, input: ItemImpl) -> TokenStream { #generated } } + +fn fully_bound_trait(path: &Path, items: &Vec) -> impl ToTokens { + let bindings = items.iter() + .filter_map(|item| if let ImplItem::Type(assoc_ty) = item { + Some(GenericArgument::Binding(Binding { + ident: assoc_ty.ident.to_owned(), + eq_token: Default::default(), + ty: assoc_ty.ty.to_owned(), + })) + } else { + None + }).collect::>(); + + let mut path = path.clone(); + + if bindings.is_empty() { + return path; + } + + if let Some(last) = path.segments.last_mut() { + match &mut last.arguments { + PathArguments::None => last.arguments = AngleBracketed(AngleBracketedGenericArguments { + colon2_token: None, + lt_token: Default::default(), + args: bindings, + gt_token: Default::default(), + }), + AngleBracketed(args) => args.args.extend(bindings), + _ => {} + } + } + path +} + diff --git a/tests/on-trait-impl-assoc-type1.rs b/tests/on-trait-impl-assoc-type1.rs new file mode 100644 index 0000000..e63b546 --- /dev/null +++ b/tests/on-trait-impl-assoc-type1.rs @@ -0,0 +1,33 @@ +use std::fmt::Debug; + +use intertrait::*; +use intertrait::cast::*; + +struct I32Data(i32); + +trait Source: CastFrom {} + +trait Producer { + type Output: Debug; + + fn produce(&self) -> Self::Output; +} + +#[cast_to] +impl Producer for I32Data { + type Output = i32; + + fn produce(&self) -> Self::Output { + self.0 + } +} + +impl Source for I32Data {} + +#[test] +fn test_cast_to_on_trait_impl_with_assoc_type1() { + let data = I32Data(100); + let source: &dyn Source = &data; + let producer = source.cast::>(); + assert_eq!(producer.unwrap().produce(), data.0); +} diff --git a/tests/on-trait-impl-assoc-type2.rs b/tests/on-trait-impl-assoc-type2.rs new file mode 100644 index 0000000..198766b --- /dev/null +++ b/tests/on-trait-impl-assoc-type2.rs @@ -0,0 +1,35 @@ +use std::fmt::Debug; + +use intertrait::*; +use intertrait::cast::*; + +struct Data; + +trait Source: CastFrom {} + +trait Concat { + type I1: Debug; + type I2: Debug; + + fn concat(&self, a: Self::I1, b: Self::I2) -> String; +} + +#[cast_to] +impl Concat for Data { + type I1 = i32; + type I2 = &'static str; + + fn concat(&self, a: Self::I1, b: Self::I2) -> String { + format!("Data: {} - {}", a, b) + } +} + +impl Source for Data {} + +#[test] +fn test_cast_to_on_trait_impl_with_assoc_type2() { + let data = Data; + let source: &dyn Source = &data; + let concat = source.cast::>(); + assert_eq!(concat.unwrap().concat(101, "hello"), "Data: 101 - hello"); +} diff --git a/tests/on-trait-impl-assoc-type3.rs b/tests/on-trait-impl-assoc-type3.rs new file mode 100644 index 0000000..079e2bd --- /dev/null +++ b/tests/on-trait-impl-assoc-type3.rs @@ -0,0 +1,35 @@ +use std::fmt::Debug; + +use intertrait::*; +use intertrait::cast::*; + +struct Data; + +trait Source: CastFrom {} + +trait Concat { + type I1: Debug; + type I2: Debug; + + fn concat(&self, prefix: T, a: Self::I1, b: Self::I2) -> String; +} + +#[cast_to] +impl Concat for Data { + type I1 = i32; + type I2 = &'static str; + + fn concat(&self, prefix: String, a: Self::I1, b: Self::I2) -> String { + format!("{}: {} - {}", prefix, a, b) + } +} + +impl Source for Data {} + +#[test] +fn test_cast_to_on_trait_impl_with_assoc_type3() { + let data = Data; + let source: &dyn Source = &data; + let concat = source.cast::>(); + assert_eq!(concat.unwrap().concat("Data".to_owned(), 101, "hello"), "Data: 101 - hello"); +}