Skip to content

Commit a71975d

Browse files
committed
update
1 parent 92d1880 commit a71975d

File tree

1 file changed

+26
-33
lines changed

1 file changed

+26
-33
lines changed

datafusion/src/optimizer/simplify_expressions.rs

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -284,10 +284,10 @@ impl OptimizerRule for SimplifyExpressions {
284284
// expressions that references non-projected columns within the same project plan or its
285285
// children plans.
286286
let mut simplifier =
287-
super::simplify_expressions::Simplifier::new(plan.all_schemas());
287+
Simplifier::new(plan.all_schemas());
288288

289289
let mut const_evaluator =
290-
super::simplify_expressions::ConstEvaluator::new(execution_props);
290+
ConstEvaluator::new(execution_props);
291291

292292
let new_inputs = plan
293293
.inputs()
@@ -649,6 +649,8 @@ impl<'a> ExprRewriter for Simplifier<'a> {
649649

650650
let new_expr = match expr {
651651
BinaryExpr { left, op, right } => match op {
652+
// true = A --> A
653+
// false = A --> !A
652654
Operator::Eq => match (left.as_ref(), right.as_ref()) {
653655
(Literal(Boolean(l)), Literal(Boolean(r))) => match (l, r) {
654656
(Some(l), Some(r)) => Literal(Boolean(Some(l == r))),
@@ -670,6 +672,8 @@ impl<'a> ExprRewriter for Simplifier<'a> {
670672
right,
671673
},
672674
},
675+
// true != A --> !A
676+
// false != A --> A
673677
Operator::NotEq => match (left.as_ref(), right.as_ref()) {
674678
(Literal(Boolean(l)), Literal(Boolean(r))) => match (l, r) {
675679
(Some(l), Some(r)) => Literal(Boolean(Some(l != r))),
@@ -745,6 +749,8 @@ impl<'a> ExprRewriter for Simplifier<'a> {
745749
}
746750
}
747751

752+
753+
748754
#[cfg(test)]
749755
mod tests {
750756
use std::sync::Arc;
@@ -910,12 +916,12 @@ mod tests {
910916
}
911917

912918
#[test]
913-
fn test_simplify_do_not_simplify_arithmetic_expr() {
919+
fn test_simplify_simplify_arithmetic_expr() {
914920
let expr_plus = binary_expr(lit(1), Operator::Plus, lit(1));
915921
let expr_eq = binary_expr(lit(1), Operator::Eq, lit(1));
916922

917-
assert_eq!(simplify(expr_plus.clone()), expr_plus);
918-
assert_eq!(simplify(expr_eq.clone()), expr_eq);
923+
assert_eq!(simplify(expr_plus), lit(2));
924+
assert_eq!(simplify(expr_eq), lit(true));
919925
}
920926

921927
// ------------------------------
@@ -1162,7 +1168,15 @@ mod tests {
11621168
fn simplify(expr: Expr) -> Expr {
11631169
let schema = expr_test_schema();
11641170
let mut rewriter = Simplifier::new(vec![&schema]);
1165-
expr.rewrite(&mut rewriter).expect("expected to simplify")
1171+
1172+
let execution_props = ExecutionProps::new();
1173+
let mut const_evaluator =
1174+
ConstEvaluator::new(&execution_props);
1175+
1176+
expr.rewrite(&mut rewriter)
1177+
.expect("expected to simplify")
1178+
.rewrite(&mut const_evaluator)
1179+
.expect("expected to const evaluate")
11661180
}
11671181

11681182
fn expr_test_schema() -> DFSchemaRef {
@@ -1238,16 +1252,8 @@ mod tests {
12381252
// Make sure c1 column to be used in tests is not boolean type
12391253
assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
12401254

1241-
// don't fold c1 = true
1242-
assert_eq!(simplify(col("c1").eq(lit(true))), col("c1").eq(lit(true)),);
1243-
1244-
// don't fold c1 = false
1245-
assert_eq!(simplify(col("c1").eq(lit(false))), col("c1").eq(lit(false)),);
1246-
1247-
// test constant operands
1248-
assert_eq!(simplify(lit(1).eq(lit(true))), lit(1).eq(lit(true)),);
1249-
1250-
assert_eq!(simplify(lit("a").eq(lit(false))), lit("a").eq(lit(false)),);
1255+
// don't fold c1 = foo
1256+
assert_eq!(simplify(col("c1").eq(lit("foo"))), col("c1").eq(lit("foo")),);
12511257
}
12521258

12531259
#[test]
@@ -1278,21 +1284,8 @@ mod tests {
12781284
assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
12791285

12801286
assert_eq!(
1281-
simplify(col("c1").not_eq(lit(true))),
1282-
col("c1").not_eq(lit(true)),
1283-
);
1284-
1285-
assert_eq!(
1286-
simplify(col("c1").not_eq(lit(false))),
1287-
col("c1").not_eq(lit(false)),
1288-
);
1289-
1290-
// test constants
1291-
assert_eq!(simplify(lit(1).not_eq(lit(true))), lit(1).not_eq(lit(true)),);
1292-
1293-
assert_eq!(
1294-
simplify(lit("a").not_eq(lit(false))),
1295-
lit("a").not_eq(lit(false)),
1287+
simplify(col("c1").not_eq(lit("foo"))),
1288+
col("c1").not_eq(lit("foo")),
12961289
);
12971290
}
12981291

@@ -1303,15 +1296,15 @@ mod tests {
13031296
expr: None,
13041297
when_then_expr: vec![(
13051298
Box::new(col("c2").not_eq(lit(false))),
1306-
Box::new(lit("ok").eq(lit(true))),
1299+
Box::new(lit("ok").eq(lit("not_ok"))),
13071300
)],
13081301
else_expr: Some(Box::new(col("c2").eq(lit(true)))),
13091302
}),
13101303
Expr::Case {
13111304
expr: None,
13121305
when_then_expr: vec![(
13131306
Box::new(col("c2")),
1314-
Box::new(lit("ok").eq(lit(true)))
1307+
Box::new(lit(false))
13151308
)],
13161309
else_expr: Some(Box::new(col("c2"))),
13171310
}

0 commit comments

Comments
 (0)