diff --git a/Cargo.toml b/Cargo.toml index ad358d7..9331cba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,3 +35,5 @@ postgres-from-row-derive = { path = "postgres-from-row-derive", version = "=0.5. tokio-postgres = { version = "0.7.8", default_features = false } postgres-from-row-derive.workspace = true +[dev-dependencies] +tokio-postgres = { version = "0.7.8", default_features = false, features = ["with-serde_json-1"] } diff --git a/postgres-from-row-derive/src/lib.rs b/postgres-from-row-derive/src/lib.rs index 7db7f17..5c7660b 100644 --- a/postgres-from-row-derive/src/lib.rs +++ b/postgres-from-row-derive/src/lib.rs @@ -1,3 +1,5 @@ +use std::str::FromStr; + use darling::{ast::Data, Error, FromDeriveInput, FromField, ToTokens}; use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; @@ -125,14 +127,37 @@ struct FromRowField { /// Override the name of the actual sql column instead of using `self.ident`. /// Is not compatible with `flatten` since no column is needed there. rename: Option, + /// Optionally use this function to convert the value from the database into a struct field. + from_fn: Option, + /// Optionally use this function to convert the value from the database into a struct field. + try_from_fn: Option, } impl FromRowField { /// Checks wether this field has a valid combination of attributes fn validate(&self) -> Result<()> { - if self.from.is_some() && self.try_from.is_some() { + match ( + &self.from, + &self.from_fn, + &self.try_from, + &self.try_from_fn, + ) { + (Some(_), None, None, None) => {} + (None, Some(_), None, None) => {} + (None, None, Some(_), None) => {} + (None, None, None, Some(_)) => {} + (None, None, None, None) => {} + _ => { + return Err(Error::custom( + r#"can't use the `#[from_row(*from*)]` attributes together"#, + ) + .into()); + } + } + + if self.flatten && (self.from.is_some() || self.try_from.is_some() || self.from_fn.is_some() || self.try_from_fn.is_some()) { return Err(Error::custom( - r#"can't combine `#[from_row(from = "..")]` with `#[from_row(try_from = "..")]`"#, + r#"can't combine `#[from_row(flatten)]` with one of the `#[from_row(*from*)]` attributes`"#, ) .into()); } @@ -181,11 +206,13 @@ impl FromRowField { let target_ty = &self.target_ty()?; let ty = &self.ty; - predicates.push(if self.flatten { - quote! (#target_ty: postgres_from_row::FromRow) - } else { - quote! (#target_ty: for<'__from_row_lifetime> postgres_from_row::tokio_postgres::types::FromSql<'__from_row_lifetime>) - }); + if self.try_from_fn.is_none() && self.from_fn.is_none() { + predicates.push(if self.flatten { + quote! (#target_ty: postgres_from_row::FromRow) + } else { + quote! (#target_ty: for<'__from_row_lifetime> postgres_from_row::tokio_postgres::types::FromSql<'__from_row_lifetime>) + }); + } if self.from.is_some() { predicates.push(quote!(#ty: std::convert::From<#target_ty>)) @@ -205,7 +232,11 @@ impl FromRowField { let ident = self.ident.as_ref().unwrap(); let column_name = self.column_name(); let field_ty = &self.ty; - let target_ty = self.target_ty()?; + let target_ty = if self.from_fn.is_none() && self.try_from_fn.is_none() { + self.target_ty()? + } else { + quote!(_) + }; let mut base = if self.flatten { quote!(<#target_ty as postgres_from_row::FromRow>::from_row(row)) @@ -213,7 +244,13 @@ impl FromRowField { quote!(postgres_from_row::tokio_postgres::Row::get::<&str, #target_ty>(row, #column_name)) }; - if self.from.is_some() { + if let Some(from_fn) = &self.from_fn { + let from_fn = TokenStream2::from_str(&from_fn)?; + base = quote!(#from_fn(#base)); + } else if let Some(try_from_fn) = &self.try_from_fn { + let try_from_fn = TokenStream2::from_str(&try_from_fn)?; + base = quote!(#try_from_fn(#base).expect("could not convert column")); + } else if self.from.is_some() { base = quote!(<#field_ty as std::convert::From<#target_ty>>::from(#base)); } else if self.try_from.is_some() { base = quote!(<#field_ty as std::convert::TryFrom<#target_ty>>::try_from(#base).expect("could not convert column")); @@ -227,7 +264,11 @@ impl FromRowField { let ident = self.ident.as_ref().unwrap(); let column_name = self.column_name(); let field_ty = &self.ty; - let target_ty = self.target_ty()?; + let target_ty = if self.from_fn.is_none() && self.try_from_fn.is_none() { + self.target_ty()? + } else { + quote!(_) + }; let mut base = if self.flatten { quote!(<#target_ty as postgres_from_row::FromRow>::try_from_row(row)?) @@ -235,7 +276,13 @@ impl FromRowField { quote!(postgres_from_row::tokio_postgres::Row::try_get::<&str, #target_ty>(row, #column_name)?) }; - if self.from.is_some() { + if let Some(from_fn) = &self.from_fn { + let from_fn = TokenStream2::from_str(&from_fn)?; + base = quote!(#from_fn(#base)); + } else if let Some(try_from_fn) = &self.try_from_fn { + let try_from_fn = TokenStream2::from_str(&try_from_fn)?; + base = quote!(#try_from_fn(#base)?); + } else if self.from.is_some() { base = quote!(<#field_ty as std::convert::From<#target_ty>>::from(#base)); } else if self.try_from.is_some() { base = quote!(<#field_ty as std::convert::TryFrom<#target_ty>>::try_from(#base)?); diff --git a/tests/integration.rs b/tests/integration.rs index 7fad406..3ce2335 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1,5 +1,7 @@ +use std::collections::HashMap; + use postgres_from_row::FromRow; -use tokio_postgres::Row; +use tokio_postgres::{types::Json, Row}; #[derive(FromRow)] #[allow(dead_code)] @@ -8,6 +10,12 @@ pub struct Todo { text: String, #[from_row(flatten)] user: User, + #[from_row(from_fn = "json")] + json: HashMap, +} + +fn json(wrapper: Json) -> T { + wrapper.0 } #[derive(FromRow)]