Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 139 additions & 17 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1747,23 +1747,38 @@ impl Expr {
pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<(Expr, bool)> {
let mut has_placeholder = false;
self.transform(|mut expr| {
// Default to assuming the arguments are the same type
if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr {
rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?;
rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?;
};
if let Expr::Between(Between {
expr,
negated: _,
low,
high,
}) = &mut expr
{
rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?;
rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?;
}
if let Expr::Placeholder(_) = &expr {
has_placeholder = true;
match &mut expr {
// Default to assuming the arguments are the same type
Expr::BinaryExpr(BinaryExpr { left, op: _, right }) => {
rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?;
rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?;
}
Expr::Between(Between {
expr,
negated: _,
low,
high,
}) => {
rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?;
rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?;
}
Expr::InList(InList {
expr,
list,
negated: _,
}) => {
for item in list.iter_mut() {
rewrite_placeholder(item, expr.as_ref(), schema)?;
}
}
Expr::Like(Like { expr, pattern, .. })
| Expr::SimilarTo(Like { expr, pattern, .. }) => {
rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?;
}
Expr::Placeholder(_) => {
has_placeholder = true;
}
_ => {}
}
Ok(Transformed::yes(expr))
})
Expand Down Expand Up @@ -3185,10 +3200,117 @@ mod test {
case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue,
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility,
};
use arrow::datatypes::{Field, Schema};
use sqlparser::ast;
use sqlparser::ast::{Ident, IdentWithAlias};
use std::any::Any;

#[test]
fn infer_placeholder_in_clause() {
// SELECT * FROM employees WHERE department_id IN ($1, $2, $3);
let column = col("department_id");
let param_placeholders = vec![
Expr::Placeholder(Placeholder {
id: "$1".to_string(),
data_type: None,
}),
Expr::Placeholder(Placeholder {
id: "$2".to_string(),
data_type: None,
}),
Expr::Placeholder(Placeholder {
id: "$3".to_string(),
data_type: None,
}),
];
let in_list = Expr::InList(InList {
expr: Box::new(column),
list: param_placeholders,
negated: false,
});

let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("department_id", DataType::Int32, true),
]));
let df_schema = DFSchema::try_from(schema).unwrap();

let (inferred_expr, contains_placeholder) =
in_list.infer_placeholder_types(&df_schema).unwrap();

assert!(contains_placeholder);

match inferred_expr {
Expr::InList(in_list) => {
for expr in in_list.list {
match expr {
Expr::Placeholder(placeholder) => {
assert_eq!(
placeholder.data_type,
Some(DataType::Int32),
"Placeholder {} should infer Int32",
placeholder.id
);
}
_ => panic!("Expected Placeholder expression"),
}
}
}
_ => panic!("Expected InList expression"),
}
}

#[test]
fn infer_placeholder_like_and_similar_to() {
// name LIKE $1
let schema =
Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)]));
let df_schema = DFSchema::try_from(schema).unwrap();

let like = Like {
expr: Box::new(col("name")),
pattern: Box::new(Expr::Placeholder(Placeholder {
id: "$1".to_string(),
data_type: None,
})),
negated: false,
case_insensitive: false,
escape_char: None,
};

let expr = Expr::Like(like.clone());

let (inferred_expr, _) = expr.infer_placeholder_types(&df_schema).unwrap();
match inferred_expr {
Expr::Like(like) => match *like.pattern {
Expr::Placeholder(placeholder) => {
assert_eq!(placeholder.data_type, Some(DataType::Utf8));
}
_ => panic!("Expected Placeholder"),
},
_ => panic!("Expected Like"),
}

// name SIMILAR TO $1
let expr = Expr::SimilarTo(like);

let (inferred_expr, _) = expr.infer_placeholder_types(&df_schema).unwrap();
match inferred_expr {
Expr::SimilarTo(like) => match *like.pattern {
Expr::Placeholder(placeholder) => {
assert_eq!(
placeholder.data_type,
Some(DataType::Utf8),
"Placeholder {} should infer Utf8",
placeholder.id
);
}
_ => panic!("Expected Placeholder expression"),
},
_ => panic!("Expected SimilarTo expression"),
}
}

#[test]
#[allow(deprecated)]
fn format_case_when() -> Result<()> {
Expand Down