From c44e3d5764284b2cdef0e5634f23c8944b7e1fa2 Mon Sep 17 00:00:00 2001 From: Kevin <4733573+kczimm@users.noreply.github.com.> Date: Fri, 25 Apr 2025 22:08:40 -0500 Subject: [PATCH 1/3] infer placeholder datatype for IN lists --- datafusion/expr/src/expr.rs | 101 ++++++++++++++++++++++++++++++------ 1 file changed, 84 insertions(+), 17 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 9f6855b69824..d6fd45419361 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1747,23 +1747,34 @@ impl Expr { pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<(Expr, bool)> { let mut has_placeholder = false; self.transform(|mut expr| { - // Default to assuming the arguments are the same type - if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { - rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; - rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; - }; - if let Expr::Between(Between { - expr, - negated: _, - low, - high, - }) = &mut expr - { - rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; - rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; - } - if let Expr::Placeholder(_) = &expr { - has_placeholder = true; + match &mut expr { + // Default to assuming the arguments are the same type + Expr::BinaryExpr(BinaryExpr { left, op: _, right }) => { + rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; + rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; + } + Expr::Between(Between { + expr, + negated: _, + low, + high, + }) => { + rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; + rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; + } + Expr::InList(InList { + expr, + list, + negated: _, + }) => { + for item in list.iter_mut() { + rewrite_placeholder(item, expr.as_ref(), schema)?; + } + } + Expr::Placeholder(_) => { + has_placeholder = true; + } + _ => {} } Ok(Transformed::yes(expr)) }) @@ -3185,10 +3196,66 @@ mod test { case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility, }; + use arrow::datatypes::{Field, Schema}; use sqlparser::ast; use sqlparser::ast::{Ident, IdentWithAlias}; use std::any::Any; + #[test] + fn infer_placeholder_in_clause() { + // SELECT * FROM employees WHERE department_id IN ($1, $2, $3); + let column = col("department_id"); + let param_placeholders = vec![ + Expr::Placeholder(Placeholder { + id: "$1".to_string(), + data_type: None, + }), + Expr::Placeholder(Placeholder { + id: "$2".to_string(), + data_type: None, + }), + Expr::Placeholder(Placeholder { + id: "$3".to_string(), + data_type: None, + }), + ]; + let in_list = Expr::InList(InList { + expr: Box::new(column), + list: param_placeholders, + negated: false, + }); + + let schema = Arc::new(Schema::new(vec![ + Field::new("name", DataType::Utf8, true), + Field::new("department_id", DataType::Int32, true), + ])); + let df_schema = DFSchema::try_from(schema).unwrap(); + + let (inferred_expr, contains_placeholder) = + in_list.infer_placeholder_types(&df_schema).unwrap(); + + assert!(contains_placeholder); + + match inferred_expr { + Expr::InList(in_list) => { + for expr in in_list.list { + match expr { + Expr::Placeholder(placeholder) => { + assert_eq!( + placeholder.data_type, + Some(DataType::Int32), + "Placeholder {} should infer Int32", + placeholder.id + ); + } + _ => panic!("Expected Placeholder expression"), + } + } + } + _ => panic!("Expected InList expression"), + } + } + #[test] #[allow(deprecated)] fn format_case_when() -> Result<()> { From 2b5e760b6696b89d13f8c09840d934b82a759fff Mon Sep 17 00:00:00 2001 From: Kevin <4733573+kczimm@users.noreply.github.com.> Date: Fri, 25 Apr 2025 22:38:06 -0500 Subject: [PATCH 2/3] infer placeholder datatype for Expr::Like --- datafusion/expr/src/expr.rs | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index d6fd45419361..2a9d2eace04b 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1771,6 +1771,10 @@ impl Expr { rewrite_placeholder(item, expr.as_ref(), schema)?; } } + Expr::Like(Like { expr, pattern, .. }) + | Expr::SimilarTo(Like { expr, pattern, .. }) => { + rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?; + } Expr::Placeholder(_) => { has_placeholder = true; } @@ -3256,6 +3260,34 @@ mod test { } } + #[test] + fn infer_placeholder_like() { + // name LIKE $1 + let schema = + Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)])); + let df_schema = DFSchema::try_from(schema).unwrap(); + let expr = Expr::Like(Like { + expr: Box::new(col("name")), + pattern: Box::new(Expr::Placeholder(Placeholder { + id: "$1".to_string(), + data_type: None, + })), + negated: false, + case_insensitive: false, + escape_char: None, + }); + let (inferred_expr, _) = expr.infer_placeholder_types(&df_schema).unwrap(); + match inferred_expr { + Expr::Like(like) => match *like.pattern { + Expr::Placeholder(placeholder) => { + assert_eq!(placeholder.data_type, Some(DataType::Utf8)); + } + _ => panic!("Expected Placeholder"), + }, + _ => panic!("Expected Like"), + } + } + #[test] #[allow(deprecated)] fn format_case_when() -> Result<()> { From 3484bd5839b1c53e06bbbe98ff5f5a1a2c444dc1 Mon Sep 17 00:00:00 2001 From: Kevin <4733573+kczimm@users.noreply.github.com.> Date: Fri, 25 Apr 2025 22:45:38 -0500 Subject: [PATCH 3/3] add tests for Expr::SimilarTo --- datafusion/expr/src/expr.rs | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 2a9d2eace04b..b8e4204a9c9e 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -3261,12 +3261,13 @@ mod test { } #[test] - fn infer_placeholder_like() { + fn infer_placeholder_like_and_similar_to() { // name LIKE $1 let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)])); let df_schema = DFSchema::try_from(schema).unwrap(); - let expr = Expr::Like(Like { + + let like = Like { expr: Box::new(col("name")), pattern: Box::new(Expr::Placeholder(Placeholder { id: "$1".to_string(), @@ -3275,7 +3276,10 @@ mod test { negated: false, case_insensitive: false, escape_char: None, - }); + }; + + let expr = Expr::Like(like.clone()); + let (inferred_expr, _) = expr.infer_placeholder_types(&df_schema).unwrap(); match inferred_expr { Expr::Like(like) => match *like.pattern { @@ -3286,6 +3290,25 @@ mod test { }, _ => panic!("Expected Like"), } + + // name SIMILAR TO $1 + let expr = Expr::SimilarTo(like); + + let (inferred_expr, _) = expr.infer_placeholder_types(&df_schema).unwrap(); + match inferred_expr { + Expr::SimilarTo(like) => match *like.pattern { + Expr::Placeholder(placeholder) => { + assert_eq!( + placeholder.data_type, + Some(DataType::Utf8), + "Placeholder {} should infer Utf8", + placeholder.id + ); + } + _ => panic!("Expected Placeholder expression"), + }, + _ => panic!("Expected SimilarTo expression"), + } } #[test]