From 669f972df9398f2355675daab276c412546e0b31 Mon Sep 17 00:00:00 2001 From: Remo Pas Date: Tue, 10 Jan 2023 20:43:15 +0100 Subject: [PATCH] add `from` and `try_from` attributes --- postgres-from-row-derive/src/lib.rs | 88 +++++++++++++++++++++-------- 1 file changed, 63 insertions(+), 25 deletions(-) diff --git a/postgres-from-row-derive/src/lib.rs b/postgres-from-row-derive/src/lib.rs index 4cd9d4b..db037fd 100644 --- a/postgres-from-row-derive/src/lib.rs +++ b/postgres-from-row-derive/src/lib.rs @@ -1,5 +1,5 @@ use darling::ast::{self, Style}; -use darling::{FromDeriveInput, FromField}; +use darling::{FromDeriveInput, FromField, ToTokens}; use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, DeriveInput, Ident}; @@ -60,18 +60,33 @@ impl DeriveFromRow { Style::Struct => fields.fields, }; - let from_row_fields = fields.iter().map(|f| f.generate_from_row(&module)); - let try_from_row_fields = fields.iter().map(|f| f.generate_try_from_row(&module)); + let from_row_fields = fields + .iter() + .map(|f| f.generate_from_row(&module)) + .collect::>>()?; + + let try_from_row_fields = fields + .iter() + .map(|f| f.generate_try_from_row(&module)) + .collect::>>()?; + let original_predicates = where_clause.clone().map(|w| &w.predicates).into_iter(); let mut predicates = Vec::new(); for field in fields.iter() { + let target_ty = &field.target_ty()?; let ty = &field.ty; predicates.push(if field.flatten { - quote! (#ty: postgres_from_row::FromRow) + quote! (#target_ty: postgres_from_row::FromRow) } else { - quote! (#ty: for<'a> #module::types::FromSql<'a>) + quote! (#target_ty: for<'a> #module::types::FromSql<'a>) }); + + if field.from.is_some() { + predicates.push(quote!(#ty: std::convert::From<#target_ty>)) + } else if field.try_from.is_some() { + predicates.push(quote!(#ty: std::convert::From<#target_ty>)) + } } Ok(quote! { @@ -101,38 +116,61 @@ struct FromRowField { ty: syn::Type, #[darling(default)] flatten: bool, + try_from: Option, + from: Option, } impl FromRowField { - fn generate_from_row(&self, module: &Ident) -> proc_macro2::TokenStream { + fn target_ty(&self) -> syn::Result { + if let Some(from) = &self.from { + Ok(from.parse()?) + } else if let Some(try_from) = &self.try_from { + Ok(try_from.parse()?) + } else { + Ok(self.ty.to_token_stream()) + } + } + + fn generate_from_row(&self, module: &Ident) -> syn::Result { let ident = self.ident.as_ref().unwrap(); let str_ident = ident.to_string(); - let ty = &self.ty; + let field_ty = &self.ty; - if self.flatten { - quote! { - #ident: <#ty as postgres_from_row::FromRow>::from_row(row) - } + let target_ty = self.target_ty()?; + + let mut base = if self.flatten { + quote!(<#target_ty as postgres_from_row::FromRow>::from_row(row)) } else { - quote! { - #ident: #module::Row::get::<&str, #ty>(row, #str_ident) - } - } + quote!(#module::Row::get::<&str, #target_ty>(row, #str_ident)) + }; + + if self.from.is_some() { + base = quote!(<#field_ty as std::convert::From<#target_ty>>::from(#base)); + } else if self.try_from.is_some() { + base = quote!(<#field_ty as std::convert::TryFrom<#target_ty>>::try_from(#base).expect("could not convert column")); + }; + + Ok(quote!(#ident: #base)) } - fn generate_try_from_row(&self, module: &Ident) -> proc_macro2::TokenStream { + fn generate_try_from_row(&self, module: &Ident) -> syn::Result { let ident = self.ident.as_ref().unwrap(); let str_ident = ident.to_string(); - let ty = &self.ty; + let field_ty = &self.ty; + let target_ty = self.target_ty()?; - if self.flatten { - quote! { - #ident: <#ty as postgres_from_row::FromRow>::try_from_row(row)? - } + let mut base = if self.flatten { + quote!(<#target_ty as postgres_from_row::FromRow>::try_from_row(row)?) } else { - quote! { - #ident: #module::Row::try_get::<&str, #ty>(row, #str_ident)? - } - } + quote!(#module::Row::try_get::<&str, #target_ty>(row, #str_ident)?) + }; + + if self.from.is_some() { + base = quote!(<#field_ty as std::convert::From<#target_ty>>::from(#base)); + } else if self.try_from.is_some() { + base = quote!(<#field_ty as std::convert::TryFrom<#target_ty>>::try_from(#base)?); + }; + + Ok(quote!(#ident: #base)) } }