From 7eb996d58467143bec5dc15564c377b8ce16e494 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Fri, 18 Jul 2025 15:45:11 -0700 Subject: [PATCH] chore: use `equals_datatype` for `BinaryExpr` (#16813) * chore: use `equals_datatype` instead of direct type comparison for `BinaryExpr` * chore: use `equals_datatype` instead of direct type comparison for `BinaryExpr` (cherry picked from commit acff1b6bdd288a15755fda36d939b0fbdae144d2) --- .../physical-expr/src/expressions/binary.rs | 65 ++++++++++++++++++- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 798e68a459ce..eff948c6a0f4 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -387,8 +387,8 @@ impl PhysicalExpr for BinaryExpr { let input_schema = schema.as_ref(); if left_data_type.is_nested() { - if right_data_type != left_data_type { - return internal_err!("type mismatch"); + if !left_data_type.equals_datatype(&right_data_type) { + return internal_err!("Cannot evaluate binary expression because of type mismatch: left {}, right {} ", left_data_type, right_data_type); } return apply_cmp_for_nested(self.op, &lhs, &rhs); } @@ -5399,4 +5399,65 @@ mod tests { Interval::make(Some(false), Some(false)).unwrap() ); } + + #[test] + fn test_evaluate_nested_type() { + let batch_schema = Arc::new(Schema::new(vec![ + Field::new( + "a", + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), + true, + ), + Field::new( + "b", + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), + true, + ), + ])); + + let mut list_builder_a = ListBuilder::new(Int32Builder::new()); + + list_builder_a.append_value([Some(1)]); + list_builder_a.append_value([Some(2)]); + list_builder_a.append_value([]); + list_builder_a.append_value([None]); + + let list_array_a: ArrayRef = Arc::new(list_builder_a.finish()); + + let mut list_builder_b = ListBuilder::new(Int32Builder::new()); + + list_builder_b.append_value([Some(1)]); + list_builder_b.append_value([Some(2)]); + list_builder_b.append_value([]); + list_builder_b.append_value([None]); + + let list_array_b: ArrayRef = Arc::new(list_builder_b.finish()); + + let batch = + RecordBatch::try_new(batch_schema, vec![list_array_a, list_array_b]).unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "a", + DataType::List(Arc::new(Field::new("foo", DataType::Int32, true))), + true, + ), + Field::new( + "b", + DataType::List(Arc::new(Field::new("bar", DataType::Int32, true))), + true, + ), + ])); + + let a = Arc::new(Column::new("a", 0)) as _; + let b = Arc::new(Column::new("b", 1)) as _; + + let eq_expr = + binary_expr(Arc::clone(&a), Operator::Eq, Arc::clone(&b), &schema).unwrap(); + + let eq_result = eq_expr.evaluate(&batch).unwrap(); + let expected = + BooleanArray::from_iter(vec![Some(true), Some(true), Some(true), Some(true)]); + assert_eq!(eq_result.into_array(4).unwrap().as_boolean(), &expected); + } }