diff --git a/core/src/codegen.rs b/core/src/codegen.rs index 2ae7672..77d1d0d 100644 --- a/core/src/codegen.rs +++ b/core/src/codegen.rs @@ -12,10 +12,13 @@ pub(crate) fn codegen(expanded: ExpandedConditionalQueryAs) -> proc_macro2::Toke let variant = format_ident!("Variant{}", idx); let output_type = &expanded.output_type; let query_fragments = &arm.query_fragments; - let run_time_bindings = arm - .run_time_bindings - .iter() - .map(|(name, type_override)| quote!(#name #type_override)); + let run_time_bindings = + arm.run_time_bindings + .iter() + .map(|(name, type_override)| match type_override { + Some(ty) => quote!(#name as #ty), + None => quote!(#name), + }); match_arms.push(quote! { (#(#patterns,)*) => { @@ -204,4 +207,25 @@ mod tests { let expanded = crate::expand::expand(lowered).unwrap(); let _codegened = codegen(expanded); } + + #[test] + fn type_override() { + let parsed = syn::parse_str::( + r#" + SomeType, + "{some_binding:ty}", + "#, + ) + .unwrap(); + let analyzed = crate::analyze::analyze(parsed.clone()).unwrap(); + let lowered = crate::lower::lower(analyzed); + let expanded = crate::expand::expand(lowered).unwrap(); + let codegened = codegen(expanded); + + let stringified = codegened.to_string(); + assert!( + stringified.contains(" some_binding as ty"), + "binding type override was not correctly generated: {stringified}" + ); + } } diff --git a/core/src/expand.rs b/core/src/expand.rs index 8015a88..91da91b 100644 --- a/core/src/expand.rs +++ b/core/src/expand.rs @@ -254,15 +254,11 @@ fn expand_run_time_bindings( let binding_name = &fragment_str[..end_of_binding]; let (binding_name, type_override) = if let Some(offset) = binding_name.find(':') { let (binding_name, type_override) = binding_name.split_at(offset); - let type_override = - type_override - .parse::() - .map_err(|err| { - ExpandError::BindingReferenceTypeOverrideParseError( - err, - fragment.span(), - ) - })?; + let type_override = type_override[1..] + .parse::() + .map_err(|err| { + ExpandError::BindingReferenceTypeOverrideParseError(err, fragment.span()) + })?; (binding_name.trim(), Some(type_override)) } else { (binding_name, None) @@ -337,7 +333,7 @@ mod tests { let parsed = syn::parse_str::( r#" SomeType, - "some {foo} {bar} {foo} query", + "some {foo:ty} {bar} {foo} query", "#, ) .unwrap(); @@ -345,6 +341,7 @@ mod tests { let lowered = crate::lower::lower(analyzed); let expanded = expand(lowered).unwrap(); + // Check that run-time binding references are generated properly. assert_eq!( expanded.match_arms[0] .query_fragments @@ -372,5 +369,26 @@ mod tests { ], } ); + + // Check that type overrides are parsed properly. + let run_time_bindings: Vec<_> = expanded.match_arms[0] + .run_time_bindings + .iter() + .map(|(ident, ts)| (ident.to_string(), ts.as_ref().map(|ts| ts.to_string()))) + .collect(); + assert_eq!( + run_time_bindings, + match DATABASE_TYPE { + DatabaseType::PostgreSql => vec![ + ("foo".to_string(), Some("ty".to_string())), + ("bar".to_string(), None), + ], + DatabaseType::MySql | DatabaseType::Sqlite => vec![ + ("foo".to_string(), Some("ty".to_string())), + ("bar".to_string(), None), + ("foo".to_string(), Some("ty".to_string())), + ], + } + ); } }