Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions core/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ pub(crate) fn codegen(expanded: ExpandedConditionalQueryAs) -> proc_macro2::Toke
let variant = format_ident!("Variant{}", idx);
let output_type = &expanded.output_type;
let query_fragments = &arm.query_fragments;
let run_time_bindings = arm
.run_time_bindings
.iter()
.map(|(name, type_override)| quote!(#name #type_override));
let run_time_bindings =
arm.run_time_bindings
.iter()
.map(|(name, type_override)| match type_override {
Some(ty) => quote!(#name as #ty),
None => quote!(#name),
});

match_arms.push(quote! {
(#(#patterns,)*) => {
Expand Down Expand Up @@ -204,4 +207,25 @@ mod tests {
let expanded = crate::expand::expand(lowered).unwrap();
let _codegened = codegen(expanded);
}

#[test]
fn type_override() {
let parsed = syn::parse_str::<crate::parse::ParsedConditionalQueryAs>(
r#"
SomeType,
"{some_binding:ty}",
"#,
)
.unwrap();
let analyzed = crate::analyze::analyze(parsed.clone()).unwrap();
let lowered = crate::lower::lower(analyzed);
let expanded = crate::expand::expand(lowered).unwrap();
let codegened = codegen(expanded);

let stringified = codegened.to_string();
assert!(
stringified.contains(" some_binding as ty"),
"binding type override was not correctly generated: {stringified}"
);
}
}
38 changes: 28 additions & 10 deletions core/src/expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,15 +254,11 @@ fn expand_run_time_bindings(
let binding_name = &fragment_str[..end_of_binding];
let (binding_name, type_override) = if let Some(offset) = binding_name.find(':') {
let (binding_name, type_override) = binding_name.split_at(offset);
let type_override =
type_override
.parse::<proc_macro2::TokenStream>()
.map_err(|err| {
ExpandError::BindingReferenceTypeOverrideParseError(
err,
fragment.span(),
)
})?;
let type_override = type_override[1..]
.parse::<proc_macro2::TokenStream>()
.map_err(|err| {
ExpandError::BindingReferenceTypeOverrideParseError(err, fragment.span())
})?;
(binding_name.trim(), Some(type_override))
} else {
(binding_name, None)
Expand Down Expand Up @@ -337,14 +333,15 @@ mod tests {
let parsed = syn::parse_str::<crate::parse::ParsedConditionalQueryAs>(
r#"
SomeType,
"some {foo} {bar} {foo} query",
"some {foo:ty} {bar} {foo} query",
"#,
)
.unwrap();
let analyzed = crate::analyze::analyze(parsed.clone()).unwrap();
let lowered = crate::lower::lower(analyzed);
let expanded = expand(lowered).unwrap();

// Check that run-time binding references are generated properly.
assert_eq!(
expanded.match_arms[0]
.query_fragments
Expand Down Expand Up @@ -372,5 +369,26 @@ mod tests {
],
}
);

// Check that type overrides are parsed properly.
let run_time_bindings: Vec<_> = expanded.match_arms[0]
.run_time_bindings
.iter()
.map(|(ident, ts)| (ident.to_string(), ts.as_ref().map(|ts| ts.to_string())))
.collect();
assert_eq!(
run_time_bindings,
match DATABASE_TYPE {
DatabaseType::PostgreSql => vec![
("foo".to_string(), Some("ty".to_string())),
("bar".to_string(), None),
],
DatabaseType::MySql | DatabaseType::Sqlite => vec![
("foo".to_string(), Some("ty".to_string())),
("bar".to_string(), None),
("foo".to_string(), Some("ty".to_string())),
],
}
);
}
}