From 8ecc215bb7fe44d8cf9dcb4b90df753f0c50afb7 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sun, 18 Apr 2021 17:31:07 -0700 Subject: [PATCH 01/25] support qualified columns in queries --- datafusion/src/dataframe.rs | 13 +- datafusion/src/datasource/csv.rs | 12 +- datafusion/src/execution/context.rs | 341 +++++----- datafusion/src/execution/dataframe_impl.rs | 15 +- datafusion/src/lib.rs | 10 +- datafusion/src/logical_plan/builder.rs | 335 +++++++--- datafusion/src/logical_plan/dfschema.rs | 139 +++- datafusion/src/logical_plan/expr.rs | 176 +++++- datafusion/src/logical_plan/mod.rs | 11 +- datafusion/src/logical_plan/plan.rs | 81 +-- datafusion/src/optimizer/constant_folding.rs | 32 +- datafusion/src/optimizer/filter_push_down.rs | 220 ++++--- .../src/optimizer/hash_build_probe_order.rs | 9 +- datafusion/src/optimizer/limit_push_down.rs | 6 +- .../src/optimizer/projection_push_down.rs | 141 +++-- datafusion/src/optimizer/utils.rs | 32 +- .../physical_optimizer/coalesce_batches.rs | 15 +- .../src/physical_plan/expressions/binary.rs | 41 +- .../src/physical_plan/expressions/case.rs | 20 +- .../src/physical_plan/expressions/cast.rs | 12 +- .../src/physical_plan/expressions/column.rs | 28 +- .../src/physical_plan/expressions/in_list.rs | 112 +++- .../physical_plan/expressions/is_not_null.rs | 2 +- .../src/physical_plan/expressions/is_null.rs | 5 +- .../src/physical_plan/expressions/min_max.rs | 2 +- .../src/physical_plan/expressions/mod.rs | 8 +- .../src/physical_plan/expressions/not.rs | 4 +- .../src/physical_plan/expressions/try_cast.rs | 9 +- datafusion/src/physical_plan/filter.rs | 4 +- datafusion/src/physical_plan/functions.rs | 4 +- .../src/physical_plan/hash_aggregate.rs | 55 +- datafusion/src/physical_plan/hash_join.rs | 384 +++++++++--- datafusion/src/physical_plan/hash_utils.rs | 90 +-- datafusion/src/physical_plan/parquet.rs | 95 +-- datafusion/src/physical_plan/planner.rs | 319 +++++++--- datafusion/src/physical_plan/projection.rs | 6 +- datafusion/src/physical_plan/repartition.rs | 8 +- datafusion/src/physical_plan/sort.rs | 10 +- datafusion/src/physical_plan/type_coercion.rs | 4 +- datafusion/src/prelude.rs | 2 +- datafusion/src/sql/planner.rs | 592 +++++++++--------- datafusion/src/sql/utils.rs | 12 +- datafusion/src/test/mod.rs | 2 +- datafusion/tests/dataframe.rs | 8 +- datafusion/tests/sql.rs | 18 +- datafusion/tests/user_defined_plan.rs | 21 +- 46 files changed, 2235 insertions(+), 1230 deletions(-) diff --git a/datafusion/src/dataframe.rs b/datafusion/src/dataframe.rs index 9c7c2ef96d6b..59f9fa23cbf8 100644 --- a/datafusion/src/dataframe.rs +++ b/datafusion/src/dataframe.rs @@ -20,7 +20,7 @@ use crate::arrow::record_batch::RecordBatch; use crate::error::Result; use crate::logical_plan::{ - DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, Partitioning, + Column, DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, Partitioning, }; use std::sync::Arc; @@ -175,7 +175,12 @@ pub trait DataFrame: Send + Sync { /// col("a").alias("a2"), /// col("b").alias("b2"), /// col("c").alias("c2")])?; - /// let join = left.join(right, JoinType::Inner, &["a", "b"], &["a2", "b2"])?; + /// let join = left.join( + /// right, + /// JoinType::Inner, + /// vec![Column::from_name("a".to_string()), Column::from_name("b".to_string())], + /// vec![Column::from_name("a2".to_string()), Column::from_name("b2".to_string())], + /// )?; /// let batches = join.collect().await?; /// # Ok(()) /// # } @@ -184,8 +189,8 @@ pub trait DataFrame: Send + Sync { &self, right: Arc, join_type: JoinType, 
- left_cols: &[&str], - right_cols: &[&str], + left_cols: Vec, + right_cols: Vec, ) -> Result>; /// Repartition a DataFrame based on a logical partitioning scheme. diff --git a/datafusion/src/datasource/csv.rs b/datafusion/src/datasource/csv.rs index 6f6c9abe0774..2a564e0d4ba7 100644 --- a/datafusion/src/datasource/csv.rs +++ b/datafusion/src/datasource/csv.rs @@ -33,7 +33,7 @@ //! let schema = csvdata.schema(); //! ``` -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{Schema, SchemaRef}; use std::any::Any; use std::string::String; use std::sync::Arc; @@ -123,10 +123,18 @@ impl TableProvider for CsvFile { _filters: &[Expr], limit: Option, ) -> Result> { + let fields = self + .schema + .fields() + .iter() + .map(|f| f.clone()) + .collect::>(); + let schema = Schema::new(fields); + Ok(Arc::new(CsvExec::try_new( &self.path, CsvReadOptions::new() - .schema(&self.schema) + .schema(&schema) .has_header(self.has_header) .delimiter(self.delimiter) .file_extension(self.file_extension.as_str()), diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index c83ca4d8de5e..cfafa116f70e 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -292,9 +292,10 @@ impl ExecutionContext { &mut self, provider: Arc, ) -> Result> { + // FIMXE: add table name method to table provider? let schema = provider.schema(); let table_scan = LogicalPlan::TableScan { - table_name: "".to_string(), + table_name: None, source: provider, projected_schema: schema.to_dfschema_ref()?, projection: None, @@ -405,22 +406,15 @@ impl ExecutionContext { ) -> Result> { let table_ref = table_ref.into(); let schema = self.state.lock().unwrap().schema_for_ref(table_ref)?; - match schema.table(table_ref.table()) { Some(ref provider) => { - let schema = provider.schema(); - let table_scan = LogicalPlan::TableScan { - table_name: table_ref.table().to_owned(), - source: Arc::clone(provider), - projected_schema: schema.to_dfschema_ref()?, - projection: None, - filters: vec![], - limit: None, - }; - Ok(Arc::new(DataFrameImpl::new( - self.state.clone(), - &LogicalPlanBuilder::from(&table_scan).build()?, - ))) + let plan = LogicalPlanBuilder::scan( + Some(table_ref.table()), + Arc::clone(provider), + None, + )? + .build()?; + Ok(Arc::new(DataFrameImpl::new(self.state.clone(), &plan))) } _ => Err(DataFusionError::Plan(format!( "No table named '{}'", @@ -971,7 +965,6 @@ mod tests { let logical_plan = ctx.optimize(&logical_plan)?; let physical_plan = ctx.create_physical_plan(&logical_plan)?; - println!("{:?}", physical_plan); let results = collect_partitioned(physical_plan).await?; @@ -1043,7 +1036,7 @@ mod tests { _ => panic!("expect optimized_plan to be projection"), } - let expected = "Projection: #c2\ + let expected = "Projection: #test.c2\ \n TableScan: test projection=Some([1])"; assert_eq!(format!("{:?}", optimized_plan), expected); @@ -1066,7 +1059,7 @@ mod tests { let schema: Schema = ctx.table("test").unwrap().schema().clone().into(); assert_eq!(schema.field_with_name("c1")?.is_nullable(), false); - let plan = LogicalPlanBuilder::scan_empty("", &schema, None)? + let plan = LogicalPlanBuilder::scan_empty(None, &schema, None)? .project(vec![col("c1")])? 
.build()?; @@ -1215,11 +1208,11 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+---------+---------+", - "| SUM(c1) | SUM(c2) |", - "+---------+---------+", - "| 60 | 220 |", - "+---------+---------+", + "+--------------+--------------+", + "| SUM(test.c1) | SUM(test.c2) |", + "+--------------+--------------+", + "| 60 | 220 |", + "+--------------+--------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1236,11 +1229,11 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+---------+---------+", - "| SUM(c1) | SUM(c2) |", - "+---------+---------+", - "| | |", - "+---------+---------+", + "+--------------+--------------+", + "| SUM(test.c1) | SUM(test.c2) |", + "+--------------+--------------+", + "| | |", + "+--------------+--------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1253,11 +1246,11 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+---------+---------+", - "| AVG(c1) | AVG(c2) |", - "+---------+---------+", - "| 1.5 | 5.5 |", - "+---------+---------+", + "+--------------+--------------+", + "| AVG(test.c1) | AVG(test.c2) |", + "+--------------+--------------+", + "| 1.5 | 5.5 |", + "+--------------+--------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1270,11 +1263,11 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+---------+---------+", - "| MAX(c1) | MAX(c2) |", - "+---------+---------+", - "| 3 | 10 |", - "+---------+---------+", + "+--------------+--------------+", + "| MAX(test.c1) | MAX(test.c2) |", + "+--------------+--------------+", + "| 3 | 10 |", + "+--------------+--------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1287,11 +1280,11 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+---------+---------+", - "| MIN(c1) | MIN(c2) |", - "+---------+---------+", - "| 0 | 1 |", - "+---------+---------+", + "+--------------+--------------+", + "| MIN(test.c1) | MIN(test.c2) |", + "+--------------+--------------+", + "| 0 | 1 |", + "+--------------+--------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1304,14 +1297,14 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+----+---------+", - "| c1 | SUM(c2) |", - "+----+---------+", - "| 0 | 55 |", - "| 1 | 55 |", - "| 2 | 55 |", - "| 3 | 55 |", - "+----+---------+", + "+----+--------------+", + "| c1 | SUM(test.c2) |", + "+----+--------------+", + "| 0 | 55 |", + "| 1 | 55 |", + "| 2 | 55 |", + "| 3 | 55 |", + "+----+--------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1324,14 +1317,14 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+----+---------+", - "| c1 | AVG(c2) |", - "+----+---------+", - "| 0 | 5.5 |", - "| 1 | 5.5 |", - "| 2 | 5.5 |", - "| 3 | 5.5 |", - "+----+---------+", + "+----+--------------+", + "| c1 | AVG(test.c2) |", + "+----+--------------+", + "| 0 | 5.5 |", + "| 1 | 5.5 |", + "| 2 | 5.5 |", + "| 3 | 5.5 |", + "+----+--------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1378,14 +1371,14 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+----+---------+", - "| c1 | MAX(c2) |", - "+----+---------+", - "| 0 | 10 |", - "| 1 | 10 |", - "| 2 | 10 |", - "| 3 | 10 |", - "+----+---------+", + "+----+--------------+", + "| c1 | MAX(test.c2) |", + "+----+--------------+", + "| 0 | 10 |", + "| 1 | 10 |", + "| 2 | 10 |", + "| 3 | 10 |", + "+----+--------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1398,14 +1391,14 @@ mod 
tests { assert_eq!(results.len(), 1); let expected = vec![ - "+----+---------+", - "| c1 | MIN(c2) |", - "+----+---------+", - "| 0 | 1 |", - "| 1 | 1 |", - "| 2 | 1 |", - "| 3 | 1 |", - "+----+---------+", + "+----+--------------+", + "| c1 | MIN(test.c2) |", + "+----+--------------+", + "| 0 | 1 |", + "| 1 | 1 |", + "| 2 | 1 |", + "| 3 | 1 |", + "+----+--------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1446,11 +1439,11 @@ mod tests { .unwrap(); let expected = vec![ - "+--------------+---------------+---------------+-------------+", - "| COUNT(nanos) | COUNT(micros) | COUNT(millis) | COUNT(secs) |", - "+--------------+---------------+---------------+-------------+", - "| 3 | 3 | 3 | 3 |", - "+--------------+---------------+---------------+-------------+", + "+----------------+-----------------+-----------------+---------------+", + "| COUNT(t.nanos) | COUNT(t.micros) | COUNT(t.millis) | COUNT(t.secs) |", + "+----------------+-----------------+-----------------+---------------+", + "| 3 | 3 | 3 | 3 |", + "+----------------+-----------------+-----------------+---------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1473,7 +1466,7 @@ mod tests { let expected = vec![ "+----------------------------+----------------------------+-------------------------+---------------------+", - "| MIN(nanos) | MIN(micros) | MIN(millis) | MIN(secs) |", + "| MIN(t.nanos) | MIN(t.micros) | MIN(t.millis) | MIN(t.secs) |", "+----------------------------+----------------------------+-------------------------+---------------------+", "| 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123 | 2011-12-13 11:13:10 |", "+----------------------------+----------------------------+-------------------------+---------------------+", @@ -1499,7 +1492,7 @@ mod tests { let expected = vec![ "+-------------------------+-------------------------+-------------------------+---------------------+", - "| MAX(nanos) | MAX(micros) | MAX(millis) | MAX(secs) |", + "| MAX(t.nanos) | MAX(t.micros) | MAX(t.millis) | MAX(t.secs) |", "+-------------------------+-------------------------+-------------------------+---------------------+", "| 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10 |", "+-------------------------+-------------------------+-------------------------+---------------------+", @@ -1550,11 +1543,11 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+-----------+-----------+", - "| COUNT(c1) | COUNT(c2) |", - "+-----------+-----------+", - "| 10 | 10 |", - "+-----------+-----------+", + "+----------------+----------------+", + "| COUNT(test.c1) | COUNT(test.c2) |", + "+----------------+----------------+", + "| 10 | 10 |", + "+----------------+----------------+", ]; assert_batches_sorted_eq!(expected, &results); Ok(()) @@ -1566,11 +1559,11 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+-----------+-----------+", - "| COUNT(c1) | COUNT(c2) |", - "+-----------+-----------+", - "| 40 | 40 |", - "+-----------+-----------+", + "+----------------+----------------+", + "| COUNT(test.c1) | COUNT(test.c2) |", + "+----------------+----------------+", + "| 40 | 40 |", + "+----------------+----------------+", ]; assert_batches_sorted_eq!(expected, &results); Ok(()) @@ -1582,14 +1575,14 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+----+-----------+", - "| c1 | COUNT(c2) |", - "+----+-----------+", - "| 0 | 10 |", - "| 1 | 10 |", - "| 2 | 10 |", - "| 3 | 10 |", - 
"+----+-----------+", + "+----+----------------+", + "| c1 | COUNT(test.c2) |", + "+----+----------------+", + "| 0 | 10 |", + "| 1 | 10 |", + "| 2 | 10 |", + "| 3 | 10 |", + "+----+----------------+", ]; assert_batches_sorted_eq!(expected, &results); Ok(()) @@ -1634,12 +1627,12 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+---------------------+---------+", - "| week | SUM(c2) |", - "+---------------------+---------+", - "| 2020-12-07 00:00:00 | 24 |", - "| 2020-12-14 00:00:00 | 156 |", - "+---------------------+---------+", + "+---------------------+--------------+", + "| week | SUM(test.c2) |", + "+---------------------+--------------+", + "| 2020-12-07 00:00:00 | 24 |", + "| 2020-12-14 00:00:00 | 156 |", + "+---------------------+--------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1685,13 +1678,13 @@ mod tests { .expect("ran plan correctly"); let expected = vec![ - "+------+------------+", - "| dict | COUNT(val) |", - "+------+------------+", - "| A | 4 |", - "| B | 1 |", - "| C | 1 |", - "+------+------------+", + "+------+--------------+", + "| dict | COUNT(t.val) |", + "+------+--------------+", + "| A | 4 |", + "| B | 1 |", + "| C | 1 |", + "+------+--------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1702,13 +1695,13 @@ mod tests { .expect("ran plan correctly"); let expected = vec![ - "+-----+-------------+", - "| val | COUNT(dict) |", - "+-----+-------------+", - "| 1 | 3 |", - "| 2 | 2 |", - "| 4 | 1 |", - "+-----+-------------+", + "+-----+---------------+", + "| val | COUNT(t.dict) |", + "+-----+---------------+", + "| 1 | 3 |", + "| 2 | 2 |", + "| 4 | 1 |", + "+-----+---------------+", ]; assert_batches_sorted_eq!(expected, &results); } @@ -1809,13 +1802,13 @@ mod tests { let expected = vec! 
[ - "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", - "| c_group | COUNT(c_uint64) | COUNT(DISTINCT c_int8) | COUNT(DISTINCT c_int16) | COUNT(DISTINCT c_int32) | COUNT(DISTINCT c_int64) | COUNT(DISTINCT c_uint8) | COUNT(DISTINCT c_uint16) | COUNT(DISTINCT c_uint32) | COUNT(DISTINCT c_uint64) |", - "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", - "| a | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |", - "| b | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |", - "| c | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |", - "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", + "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", + "| c_group | COUNT(test.c_uint64) | COUNT(DISTINCT test.c_int8) | COUNT(DISTINCT test.c_int16) | COUNT(DISTINCT test.c_int32) | COUNT(DISTINCT test.c_int64) | COUNT(DISTINCT test.c_uint8) | COUNT(DISTINCT test.c_uint16) | COUNT(DISTINCT test.c_uint32) | COUNT(DISTINCT test.c_uint64) |", + "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", + "| a | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |", + "| b | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |", + "| c | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |", + "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1836,13 +1829,13 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", - "| c_group | COUNT(c_uint64) | COUNT(DISTINCT c_int8) | COUNT(DISTINCT c_int16) | COUNT(DISTINCT c_int32) | COUNT(DISTINCT c_int64) | COUNT(DISTINCT c_uint8) | COUNT(DISTINCT c_uint16) | COUNT(DISTINCT c_uint32) | COUNT(DISTINCT c_uint64) |", - "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", - "| a | 5 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 |", - "| b | 5 | 4 | 4 | 4 | 4 | 4 | 4 | 4 | 4 |", - "| c | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |", - 
"+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", + "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", + "| c_group | COUNT(test.c_uint64) | COUNT(DISTINCT test.c_int8) | COUNT(DISTINCT test.c_int16) | COUNT(DISTINCT test.c_int32) | COUNT(DISTINCT test.c_int64) | COUNT(DISTINCT test.c_uint8) | COUNT(DISTINCT test.c_uint16) | COUNT(DISTINCT test.c_uint32) | COUNT(DISTINCT test.c_uint64) |", + "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", + "| a | 5 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 |", + "| b | 5 | 4 | 4 | 4 | 4 | 4 | 4 | 4 | 4 |", + "| c | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |", + "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1859,7 +1852,7 @@ mod tests { Field::new("c2", DataType::UInt32, false), ])); - let plan = LogicalPlanBuilder::scan_empty("", schema.as_ref(), None)? + let plan = LogicalPlanBuilder::scan_empty(None, schema.as_ref(), None)? .aggregate(vec![col("c1")], vec![sum(col("c2"))])? .project(vec![col("c1"), col("SUM(c2)").alias("total_salary")])? 
.build()?; @@ -1957,11 +1950,11 @@ mod tests { .unwrap(); let expected = vec![ - "+---------+", - "| sqrt(i) |", - "+---------+", - "| 1 |", - "+---------+", + "+-----------+", + "| sqrt(t.i) |", + "+-----------+", + "| 1 |", + "+-----------+", ]; let results = plan_and_collect(&mut ctx, "SELECT sqrt(i) FROM t") @@ -2019,11 +2012,11 @@ mod tests { let result = plan_and_collect(&mut ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?; let expected = vec![ - "+------------+", - "| MY_FUNC(i) |", - "+------------+", - "| 1 |", - "+------------+", + "+--------------+", + "| MY_FUNC(t.i) |", + "+--------------+", + "| 1 |", + "+--------------+", ]; assert_batches_eq!(expected, &result); @@ -2037,11 +2030,11 @@ mod tests { .unwrap(); let expected = vec![ - "+--------+", - "| MAX(i) |", - "+--------+", - "| 1 |", - "+--------+", + "+----------+", + "| MAX(t.i) |", + "+----------+", + "| 1 |", + "+----------+", ]; let results = plan_and_collect(&mut ctx, "SELECT max(i) FROM t") @@ -2100,11 +2093,11 @@ mod tests { let result = plan_and_collect(&mut ctx, "SELECT \"MY_AVG\"(i) FROM t").await?; let expected = vec![ - "+-----------+", - "| MY_AVG(i) |", - "+-----------+", - "| 1 |", - "+-----------+", + "+-------------+", + "| MY_AVG(t.i) |", + "+-------------+", + "| 1 |", + "+-------------+", ]; assert_batches_eq!(expected, &result); @@ -2200,11 +2193,11 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+---------+---------+-----------------+", - "| SUM(c1) | SUM(c2) | COUNT(UInt8(1)) |", - "+---------+---------+-----------------+", - "| 10 | 110 | 20 |", - "+---------+---------+-----------------+", + "+--------------+--------------+-----------------+", + "| SUM(test.c1) | SUM(test.c2) | COUNT(UInt8(1)) |", + "+--------------+--------------+-----------------+", + "| 10 | 110 | 20 |", + "+--------------+--------------+-----------------+", ]; assert_batches_eq!(expected, &results); @@ -2310,7 +2303,7 @@ mod tests { assert_eq!( format!("{:?}", plan), - "Projection: #a, #b, my_add(#a, #b)\n TableScan: t projection=None" + "Projection: #t.a, #t.b, my_add(#t.a, #t.b)\n TableScan: t projection=None" ); let plan = ctx.optimize(&plan)?; @@ -2318,14 +2311,14 @@ mod tests { let result = collect(plan).await?; let expected = vec![ - "+-----+-----+-------------+", - "| a | b | my_add(a,b) |", - "+-----+-----+-------------+", - "| 1 | 2 | 3 |", - "| 10 | 12 | 22 |", - "| 10 | 12 | 22 |", - "| 100 | 120 | 220 |", - "+-----+-----+-------------+", + "+-----+-----+-----------------+", + "| a | b | my_add(t.a,t.b) |", + "+-----+-----+-----------------+", + "| 1 | 2 | 3 |", + "| 10 | 12 | 22 |", + "| 10 | 12 | 22 |", + "| 100 | 120 | 220 |", + "+-----+-----+-----------------+", ]; assert_batches_eq!(expected, &result); @@ -2428,11 +2421,11 @@ mod tests { let result = plan_and_collect(&mut ctx, "SELECT MY_AVG(a) FROM t").await?; let expected = vec![ - "+-----------+", - "| my_avg(a) |", - "+-----------+", - "| 3 |", - "+-----------+", + "+-------------+", + "| my_avg(t.a) |", + "+-------------+", + "| 3 |", + "+-------------+", ]; assert_batches_eq!(expected, &result); diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index 2a0c39aa48eb..f178fc4776e7 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -23,8 +23,8 @@ use crate::arrow::record_batch::RecordBatch; use crate::error::Result; use crate::execution::context::{ExecutionContext, ExecutionContextState}; use crate::logical_plan::{ - col, DFSchema, 
Expr, FunctionRegistry, JoinType, LogicalPlan, LogicalPlanBuilder, - Partitioning, + col, Column, DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, + LogicalPlanBuilder, Partitioning, }; use crate::{ dataframe::*, @@ -106,8 +106,8 @@ impl DataFrame for DataFrameImpl { &self, right: Arc, join_type: JoinType, - left_cols: &[&str], - right_cols: &[&str], + left_cols: Vec, + right_cols: Vec, ) -> Result> { let plan = LogicalPlanBuilder::from(&self.plan) .join(&right.to_logical_plan(), join_type, left_cols, right_cols)? @@ -252,7 +252,12 @@ mod tests { let right = test_table()?.select_columns(&["c1", "c3"])?; let left_rows = left.collect().await?; let right_rows = right.collect().await?; - let join = left.join(right, JoinType::Inner, &["c1"], &["c1"])?; + let join = left.join( + right, + JoinType::Inner, + vec![Column::from_name("c1".to_string())], + vec![Column::from_name("c1".to_string())], + )?; let join_rows = join.collect().await?; assert_eq!(100, left_rows.iter().map(|x| x.num_rows()).sum::()); assert_eq!(100, right_rows.iter().map(|x| x.num_rows()).sum::()); diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index 44a8a686a496..5f42fc3d0e23 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -95,11 +95,11 @@ //! let pretty_results = arrow::util::pretty::pretty_format_batches(&results)?; //! //! let expected = vec![ -//! "+---+--------+", -//! "| a | MIN(b) |", -//! "+---+--------+", -//! "| 1 | 2 |", -//! "+---+--------+" +//! "+---+----------------+", +//! "| a | MIN(example.b) |", +//! "+---+----------------+", +//! "| 1 | 2 |", +//! "+---+----------------+" //! ]; //! //! assert_eq!(pretty_results.trim().lines().collect::>(), expected); diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index fed82fd23b81..8b90dae52791 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -32,12 +32,17 @@ use crate::{ }; use super::dfschema::ToDFSchema; -use super::{ - col, exprlist_to_fields, Expr, JoinType, LogicalPlan, PlanType, StringifiedPlan, +use super::{exprlist_to_fields, Expr, JoinType, LogicalPlan, PlanType, StringifiedPlan}; +use crate::logical_plan::{ + normalize_col, normalize_cols, Column, DFField, DFSchema, DFSchemaRef, Partitioning, }; -use crate::logical_plan::{DFField, DFSchema, DFSchemaRef, Partitioning}; use std::collections::HashSet; +pub enum JoinConstraint { + On, + Using, +} + /// Builder for logical plans /// /// ``` @@ -63,7 +68,7 @@ use std::collections::HashSet; /// // FROM employees /// // WHERE salary < 1000 /// let plan = LogicalPlanBuilder::scan_empty( -/// "employee.csv", +/// Some("employee"), /// &employee_schema(), /// None, /// )? 
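Worth pausing on the hunk above: `scan`/`scan_empty` now take an `Option<&str>` table name, and that name is what qualifies every column downstream. A minimal sketch of the effect (a hedged illustration reusing the `employee_schema()` helper and `col`/`lit` from this module's docs; not part of the patch itself):

```rust
use datafusion::logical_plan::{col, lit, LogicalPlanBuilder};

// With a named scan, unqualified and qualified spellings of the same
// column both resolve: `col("state")` is normalized to `#employee.state`
// against the scan's schema, while `col("employee.state")` carries the
// qualifier explicitly (parsing is done by Column::from_flat_name,
// introduced later in this patch).
let plan = LogicalPlanBuilder::scan_empty(Some("employee"), &employee_schema(), None)?
    .filter(col("employee.state").eq(lit("CO")))?
    .project(vec![col("id")])?
    .build()?;
```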
@@ -103,7 +108,7 @@ impl LogicalPlanBuilder { projection: Option>, ) -> Result { let provider = Arc::new(MemTable::try_new(schema, partitions)?); - Self::scan("", provider, projection) + Self::scan(None, provider, projection) } /// Scan a CSV data source @@ -113,7 +118,7 @@ impl LogicalPlanBuilder { projection: Option>, ) -> Result { let provider = Arc::new(CsvFile::try_new(path, options)?); - Self::scan("", provider, projection) + Self::scan(None, provider, projection) } /// Scan a Parquet data source @@ -123,12 +128,12 @@ impl LogicalPlanBuilder { max_concurrency: usize, ) -> Result { let provider = Arc::new(ParquetTable::try_new(path, max_concurrency)?); - Self::scan("", provider, projection) + Self::scan(None, provider, projection) } /// Scan an empty data source, mainly used in tests pub fn scan_empty( - name: &str, + name: Option<&str>, table_schema: &Schema, projection: Option>, ) -> Result { @@ -139,22 +144,54 @@ impl LogicalPlanBuilder { /// Convert a table provider into a builder with a TableScan pub fn scan( - name: &str, + table_name: Option<&str>, provider: Arc, projection: Option>, ) -> Result { + if let Some(name) = table_name { + if name.is_empty() { + return Err(DataFusionError::Plan( + "table_name cannot be empty".to_string(), + )); + } + } + let schema = provider.schema(); let projected_schema = projection .as_ref() - .map(|p| Schema::new(p.iter().map(|i| schema.field(*i).clone()).collect())) - .map_or(schema, SchemaRef::new) - .to_dfschema_ref()?; + .map(|p| DFSchema { + fields: p + .iter() + .map(|i| match table_name { + // FIXME: move if check outside + Some(name) => { + DFField::from_qualified(name, schema.field(*i).clone()) + } + None => DFField::from(schema.field(*i).clone()), + }) + .collect(), + }) + .unwrap_or_else(|| { + // FIXME: remove unwrap + match table_name { + Some(name) => DFSchema::try_from_qualified(name, &schema).unwrap(), + None => DFSchema::new( + schema + .fields() + .iter() + .map(|f| DFField::from(f.clone())) + .collect(), + ) + .unwrap(), + } + }); + // FIXME: check for empty table name let table_scan = LogicalPlan::TableScan { - table_name: name.to_string(), + table_name: table_name.clone().map(|s| s.to_string()), source: provider, - projected_schema, + projected_schema: Arc::new(projected_schema), projection, filters: vec![], limit: None, @@ -171,16 +208,19 @@ impl LogicalPlanBuilder { /// * An invalid expression is used (e.g. 
a `sort` expression) pub fn project(&self, expr: impl IntoIterator) -> Result { let input_schema = self.plan.schema(); + let all_schemas = self.plan.all_schemas(); let mut projected_expr = vec![]; for e in expr { - match e { + let normalized_e = normalize_col(e, &all_schemas)?; + match normalized_e { Expr::Wildcard => { (0..input_schema.fields().len()).for_each(|i| { - projected_expr.push(col(input_schema.field(i).name())) + projected_expr + .push(Expr::Column(input_schema.field(i).qualified_column())) }); } - _ => projected_expr.push(e), - }; + _ => projected_expr.push(normalized_e), + } } validate_unique_names("Projections", projected_expr.iter(), input_schema)?; @@ -196,6 +236,7 @@ impl LogicalPlanBuilder { /// Apply a filter pub fn filter(&self, expr: Expr) -> Result { + let expr = normalize_col(expr, &self.plan.all_schemas())?; Ok(Self::from(&LogicalPlan::Filter { predicate: expr, input: Arc::new(self.plan.clone()), @@ -211,64 +252,97 @@ impl LogicalPlanBuilder { } /// Apply a sort - pub fn sort(&self, expr: impl IntoIterator) -> Result { + pub fn sort(&self, exprs: impl IntoIterator) -> Result { + let schemas = self.plan.all_schemas(); Ok(Self::from(&LogicalPlan::Sort { - expr: expr.into_iter().collect(), + expr: normalize_cols(exprs, &schemas)?, input: Arc::new(self.plan.clone()), })) } /// Apply a union pub fn union(&self, plan: LogicalPlan) -> Result { - let schema = self.plan.schema(); + Ok(Self::from(&union_with_alias( + self.plan.clone(), + plan, + None, + )?)) + } - if plan.schema() != schema { + /// Apply a join with on constraint + pub fn join( + &self, + right: &LogicalPlan, + join_type: JoinType, + left_keys: Vec, + right_keys: Vec, + ) -> Result { + if left_keys.len() != right_keys.len() { return Err(DataFusionError::Plan( - "Schema's for union should be the same ".to_string(), + "left_keys and right_keys were not the same length".to_string(), )); } - // Add plan to existing union if possible - let mut inputs = match &self.plan { - LogicalPlan::Union { inputs, .. } => inputs.clone(), - _ => vec![self.plan.clone()], - }; - inputs.push(plan); - Ok(Self::from(&LogicalPlan::Union { - inputs, - schema: schema.clone(), - alias: None, + let left_keys: Vec = left_keys + .into_iter() + .map(|c| c.normalize(&self.plan.all_schemas())) + .collect::>()?; + let right_keys: Vec = right_keys + .into_iter() + // FIXME: write a test for this + .map(|c| c.normalize(&right.all_schemas())) + .collect::>()?; + let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); + let join_schema = build_join_schema( + self.plan.schema(), + right.schema(), + &on, + &join_type, + JoinConstraint::On, + )?; + + Ok(Self::from(&LogicalPlan::Join { + left: Arc::new(self.plan.clone()), + right: Arc::new(right.clone()), + on, + join_type, + schema: DFSchemaRef::new(join_schema), })) } - /// Apply a join - pub fn join( + /// Apply a join with using constraint, which duplicates all join columns in output schema. 
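    // A hedged sketch of how the two join constraints differ in the output
    // schema (table names t1/t2 are illustrative, not from this patch):
    //
    //   t1.join(t2, Inner, [(#t1.id,)], [(#t2.id,)])  -> schema keeps both
    //                                                    #t1.id and #t2.id
    //   t1.join_using(t2, Inner, [id])                -> schema keeps only
    //                                                    #t1.id
    //
    // build_join_schema below implements this: with the `On` constraint a
    // right-side key is dropped only when both sides name the same
    // qualified column, while `Using` always drops the right-side keys
    // (for inner/left joins; right joins mirror this on the left side).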
+ pub fn join_using( &self, right: &LogicalPlan, join_type: JoinType, - left_keys: &[&str], - right_keys: &[&str], + using_keys: Vec, ) -> Result { - if left_keys.len() != right_keys.len() { - Err(DataFusionError::Plan( - "left_keys and right_keys were not the same length".to_string(), - )) - } else { - let on: Vec<_> = left_keys - .iter() - .zip(right_keys.iter()) - .map(|(x, y)| (x.to_string(), y.to_string())) - .collect::>(); - let join_schema = - build_join_schema(self.plan.schema(), right.schema(), &on, &join_type)?; - Ok(Self::from(&LogicalPlan::Join { - left: Arc::new(self.plan.clone()), - right: Arc::new(right.clone()), - on, - join_type, - schema: DFSchemaRef::new(join_schema), - })) - } + let left_keys: Vec = using_keys + .clone() + .into_iter() + .map(|c| c.normalize(&self.plan.all_schemas())) + .collect::>()?; + let right_keys: Vec = using_keys + .into_iter() + .map(|c| c.normalize(&right.all_schemas())) + .collect::>()?; + + let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); + let join_schema = build_join_schema( + self.plan.schema(), + right.schema(), + &on, + &join_type, + JoinConstraint::Using, + )?; + + Ok(Self::from(&LogicalPlan::Join { + left: Arc::new(self.plan.clone()), + right: Arc::new(right.clone()), + on, + join_type, + schema: DFSchemaRef::new(join_schema), + })) } /// Repartition @@ -287,9 +361,9 @@ impl LogicalPlanBuilder { group_expr: impl IntoIterator, aggr_expr: impl IntoIterator, ) -> Result { - let group_expr = group_expr.into_iter().collect::>(); - let aggr_expr = aggr_expr.into_iter().collect::>(); - + let schemas = self.plan.all_schemas(); + let group_expr = normalize_cols(group_expr, &schemas)?; + let aggr_expr = normalize_cols(aggr_expr, &schemas)?; let all_expr = group_expr.iter().chain(aggr_expr.iter()); validate_unique_names("Aggregations", all_expr.clone(), self.plan.schema())?; @@ -333,40 +407,55 @@ impl LogicalPlanBuilder { fn build_join_schema( left: &DFSchema, right: &DFSchema, - on: &[(String, String)], + on: &[(Column, Column)], join_type: &JoinType, + join_constraint: JoinConstraint, ) -> Result { let fields: Vec = match join_type { JoinType::Inner | JoinType::Left => { - // remove right-side join keys if they have the same names as the left-side - let duplicate_keys = &on - .iter() - .filter(|(l, r)| l == r) - .map(|on| on.1.to_string()) - .collect::>(); + let duplicate_keys = match join_constraint { + JoinConstraint::On => on + .iter() + .filter(|(l, r)| l == r) + .map(|on| on.1.clone()) + .collect::>(), + // using join requires unique join columns in the output schema, so we mark all + // right join keys as duplicate + JoinConstraint::Using => { + on.iter().map(|on| on.1.clone()).collect::>() + } + }; let left_fields = left.fields().iter(); + // remove right-side join keys if they have the same names as the left-side let right_fields = right .fields() .iter() - .filter(|f| !duplicate_keys.contains(f.name())); + .filter(|f| !duplicate_keys.contains(&f.qualified_column())); // left then right left_fields.chain(right_fields).cloned().collect() } JoinType::Right => { - // remove left-side join keys if they have the same names as the right-side - let duplicate_keys = &on - .iter() - .filter(|(l, r)| l == r) - .map(|on| on.1.to_string()) - .collect::>(); + let duplicate_keys = match join_constraint { + JoinConstraint::On => on + .iter() + .filter(|(l, r)| l == r) + .map(|on| on.1.clone()) + .collect::>(), + // using join requires unique join columns in the output schema, so we mark all + // left join keys as 
duplicate + JoinConstraint::Using => { + on.iter().map(|on| on.0.clone()).collect::>() + } + }; + // remove left-side join keys if they have the same names as the right-side let left_fields = left .fields() .iter() - .filter(|f| !duplicate_keys.contains(f.name())); + .filter(|f| !duplicate_keys.contains(&f.qualified_column())); let right_fields = right.fields().iter(); @@ -374,6 +463,7 @@ fn build_join_schema( left_fields.chain(right_fields).cloned().collect() } }; + DFSchema::new(fields) } @@ -404,17 +494,56 @@ fn validate_unique_names<'a>( }) } +/// Union two logical plans with an optional alias. +pub fn union_with_alias( + left_plan: LogicalPlan, + right_plan: LogicalPlan, + alias: Option, +) -> Result { + let inputs = vec![left_plan, right_plan] + .into_iter() + .flat_map(|p| match p { + LogicalPlan::Union { inputs, .. } => inputs, + x => vec![x], + }) + .collect::>(); + if inputs.is_empty() { + return Err(DataFusionError::Plan("Empty UNION".to_string())); + } + + let union_schema = (**inputs[0].schema()).clone(); + let union_schema = Arc::new(match alias { + Some(ref alias) => union_schema.replace_qualifier(alias.as_str()), + None => union_schema.strip_qualifiers(), + }); + if !inputs.iter().skip(1).all(|input_plan| { + // union changes all qualifers in resulting schema, so we only need to + // match against arrow schema here, which doesn't include qualifiers + union_schema.matches_arrow_schema(&((**input_plan.schema()).clone().into())) + }) { + return Err(DataFusionError::Plan( + "UNION ALL schemas are expected to be the same".to_string(), + )); + } + + Ok(LogicalPlan::Union { + schema: union_schema, + inputs, + alias, + }) +} + #[cfg(test)] mod tests { use arrow::datatypes::{DataType, Field}; - use super::super::{lit, sum}; + use super::super::{col, lit, sum}; use super::*; #[test] fn plan_builder_simple() -> Result<()> { let plan = LogicalPlanBuilder::scan_empty( - "employee.csv", + Some("employee_csv"), &employee_schema(), Some(vec![0, 3]), )? @@ -422,9 +551,9 @@ mod tests { .project(vec![col("id")])? .build()?; - let expected = "Projection: #id\ - \n Filter: #state Eq Utf8(\"CO\")\ - \n TableScan: employee.csv projection=Some([0, 3])"; + let expected = "Projection: #employee_csv.id\ + \n Filter: #employee_csv.state Eq Utf8(\"CO\")\ + \n TableScan: employee_csv projection=Some([0, 3])"; assert_eq!(expected, format!("{:?}", plan)); @@ -434,7 +563,7 @@ mod tests { #[test] fn plan_builder_aggregate() -> Result<()> { let plan = LogicalPlanBuilder::scan_empty( - "employee.csv", + Some("employee_csv"), &employee_schema(), Some(vec![3, 4]), )? @@ -445,9 +574,9 @@ mod tests { .project(vec![col("state"), col("total_salary")])? .build()?; - let expected = "Projection: #state, #total_salary\ - \n Aggregate: groupBy=[[#state]], aggr=[[SUM(#salary) AS total_salary]]\ - \n TableScan: employee.csv projection=Some([3, 4])"; + let expected = "Projection: #employee_csv.state, #total_salary\ + \n Aggregate: groupBy=[[#employee_csv.state]], aggr=[[SUM(#employee_csv.salary) AS total_salary]]\ + \n TableScan: employee_csv projection=Some([3, 4])"; assert_eq!(expected, format!("{:?}", plan)); @@ -457,7 +586,7 @@ mod tests { #[test] fn plan_builder_sort() -> Result<()> { let plan = LogicalPlanBuilder::scan_empty( - "employee.csv", + Some("employee_csv"), &employee_schema(), Some(vec![3, 4]), )? @@ -468,15 +597,15 @@ mod tests { nulls_first: true, }, Expr::Sort { - expr: Box::new(col("total_salary")), + expr: Box::new(col("salary")), asc: false, nulls_first: false, }, ])? 
.build()?; - let expected = "Sort: #state ASC NULLS FIRST, #total_salary DESC NULLS LAST\ - \n TableScan: employee.csv projection=Some([3, 4])"; + let expected = "Sort: #employee_csv.state ASC NULLS FIRST, #employee_csv.salary DESC NULLS LAST\ + \n TableScan: employee_csv projection=Some([3, 4])"; assert_eq!(expected, format!("{:?}", plan)); @@ -486,7 +615,7 @@ mod tests { #[test] fn plan_builder_union_combined_single_union() -> Result<()> { let plan = LogicalPlanBuilder::scan_empty( - "employee.csv", + Some("employee_csv"), &employee_schema(), Some(vec![3, 4]), )?; @@ -499,10 +628,10 @@ mod tests { // output has only one union let expected = "Union\ - \n TableScan: employee.csv projection=Some([3, 4])\ - \n TableScan: employee.csv projection=Some([3, 4])\ - \n TableScan: employee.csv projection=Some([3, 4])\ - \n TableScan: employee.csv projection=Some([3, 4])"; + \n TableScan: employee_csv projection=Some([3, 4])\ + \n TableScan: employee_csv projection=Some([3, 4])\ + \n TableScan: employee_csv projection=Some([3, 4])\ + \n TableScan: employee_csv projection=Some([3, 4])"; assert_eq!(expected, format!("{:?}", plan)); @@ -512,9 +641,10 @@ mod tests { #[test] fn projection_non_unique_names() -> Result<()> { let plan = LogicalPlanBuilder::scan_empty( - "employee.csv", + Some("employee_csv"), &employee_schema(), - Some(vec![0, 3]), + // project id and first_name by column index + Some(vec![0, 1]), )? // two columns with the same name => error .project(vec![col("id"), col("first_name").alias("id")]); @@ -523,9 +653,8 @@ mod tests { Err(DataFusionError::Plan(e)) => { assert_eq!( e, - "Projections require unique expression names \ - but the expression \"#id\" at position 0 and \"#first_name AS id\" at \ - position 1 have the same name. Consider aliasing (\"AS\") one of them." + "Schema contains qualified field name 'employee_csv.id' \ + and unqualified field name 'id' which would be ambiguous" ); Ok(()) } @@ -538,9 +667,10 @@ mod tests { #[test] fn aggregate_non_unique_names() -> Result<()> { let plan = LogicalPlanBuilder::scan_empty( - "employee.csv", + Some("employee_csv"), &employee_schema(), - Some(vec![0, 3]), + // project state and salary by column index + Some(vec![3, 4]), )? // two columns with the same name => error .aggregate(vec![col("state")], vec![sum(col("salary")).alias("state")]); @@ -549,9 +679,8 @@ mod tests { Err(DataFusionError::Plan(e)) => { assert_eq!( e, - "Aggregations require unique expression names \ - but the expression \"#state\" at position 0 and \"SUM(#salary) AS state\" at \ - position 1 have the same name. Consider aliasing (\"AS\") one of them." + "Schema contains qualified field name 'employee_csv.state' and \ + unqualified field name 'state' which would be ambiguous" ); Ok(()) } diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs index 9adb22b43d07..d39697595559 100644 --- a/datafusion/src/logical_plan/dfschema.rs +++ b/datafusion/src/logical_plan/dfschema.rs @@ -23,6 +23,7 @@ use std::convert::TryFrom; use std::sync::Arc; use crate::error::{DataFusionError, Result}; +use crate::logical_plan::Column; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use std::fmt::{Display, Formatter}; @@ -34,7 +35,7 @@ pub type DFSchemaRef = Arc; #[derive(Debug, Clone, PartialEq, Eq)] pub struct DFSchema { /// Fields - fields: Vec, + pub(crate) fields: Vec, } impl DFSchema { @@ -88,6 +89,7 @@ impl DFSchema { } /// Create a `DFSchema` from an Arrow schema + // FIXME: change to a better name? 
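    // For example (mirroring from_qualified_schema_into_arrow_schema in the
    // tests below): qualifying a two-field Arrow schema (c0, c1) with "t1"
    // yields DFFields whose qualified names are t1.c0 and t1.c1, while the
    // underlying Arrow field names stay plain c0 and c1.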
pub fn try_from_qualified(qualifier: &str, schema: &Schema) -> Result { Self::new( schema @@ -108,6 +110,21 @@ impl DFSchema { Self::new(fields) } + /// Merge a schema into self + pub fn merge(&mut self, other_schema: &DFSchema) { + for field in other_schema.fields() { + // skip duplicate columns + let duplicated_field = match field.qualifier() { + Some(q) => self.field_with_name(Some(q.as_str()), field.name()).is_ok(), + // for unqualifed columns, check as unqualified name + None => self.field_with_unqualified_name(field.name()).is_ok(), + }; + if !duplicated_field { + self.fields.push(field.clone()); + } + } + } + /// Get a list of fields pub fn fields(&self) -> &Vec { &self.fields @@ -119,7 +136,7 @@ impl DFSchema { &self.fields[i] } - /// Find the index of the column with the given name + /// Find the index of the column with the given unqualifed name pub fn index_of(&self, name: &str) -> Result { for i in 0..self.fields.len() { if self.fields[i].name() == name { @@ -129,6 +146,20 @@ impl DFSchema { Err(DataFusionError::Plan(format!("No field named '{}'", name))) } + /// Find the index of the column with the given qualifer and name + pub fn index_of_column(&self, col: &Column) -> Result { + for i in 0..self.fields.len() { + let field = &self.fields[i]; + if field.qualifier() == col.relation.as_ref() && field.name() == &col.name { + return Ok(i); + } + } + Err(DataFusionError::Plan(format!( + "No field matches column '{}'", + col, + ))) + } + /// Find the field with the given name pub fn field_with_name( &self, @@ -150,7 +181,10 @@ impl DFSchema { .filter(|field| field.name() == name) .collect(); match matches.len() { - 0 => Err(DataFusionError::Plan(format!("No field named '{}'", name))), + 0 => Err(DataFusionError::Plan(format!( + "No field with unqualified name '{}'", + name + ))), 1 => Ok(matches[0].to_owned()), _ => Err(DataFusionError::Plan(format!( "Ambiguous reference to field named '{}'", @@ -184,6 +218,61 @@ impl DFSchema { ))), } } + + /// Find the field with the given qualified column + pub fn field_from_qualified_column(&self, column: &Column) -> Result { + match &column.relation { + Some(r) => self.field_with_qualified_name(r, &column.name), + None => self.field_with_unqualified_name(&column.name), + } + } + + pub fn matches_arrow_schema(&self, arrow_schema: &Schema) -> bool { + self.fields + .iter() + .zip(arrow_schema.fields().iter()) + .all(|(dffield, arrowfield)| dffield.name() == arrowfield.name()) + } + + /// Strip all field qualifier in schema + pub fn strip_qualifiers(self) -> Self { + DFSchema { + fields: self + .fields + .into_iter() + .map(|f| { + if f.qualifier().is_some() { + DFField::new( + None, + f.name(), + f.data_type().to_owned(), + f.is_nullable(), + ) + } else { + f + } + }) + .collect(), + } + } + + /// Replace all field qualifier with new value in schema + pub fn replace_qualifier(self, qualifer: &str) -> Self { + DFSchema { + fields: self + .fields + .into_iter() + .map(|f| { + DFField::new( + Some(qualifer.clone()), + f.name(), + f.data_type().to_owned(), + f.is_nullable(), + ) + }) + .collect(), + } + } } impl Into for DFSchema { @@ -195,7 +284,7 @@ impl Into for DFSchema { .map(|f| { if f.qualifier().is_some() { Field::new( - f.qualified_name().as_str(), + f.name().as_str(), f.data_type().to_owned(), f.is_nullable(), ) @@ -208,9 +297,32 @@ impl Into for DFSchema { } } +impl Into for &DFSchema { + /// Convert a schema into a DFSchema + fn into(self) -> Schema { + Schema::new( + self.fields + .iter() + .map(|f| { + if f.qualifier().is_some() { + 
Field::new( + f.name().as_str(), + f.data_type().to_owned(), + f.is_nullable(), + ) + } else { + f.field.clone() + } + }) + .collect(), + ) + } +} + /// Create a `DFSchema` from an Arrow schema impl TryFrom for DFSchema { type Error = DataFusionError; + // FIXME: change this to reference of schema fn try_from(schema: Schema) -> std::result::Result { Self::new( schema @@ -338,7 +450,7 @@ impl DFField { self.field.is_nullable() } - /// Returns a reference to the `DFField`'s qualified name + /// Returns a string to the `DFField`'s qualified name pub fn qualified_name(&self) -> String { if let Some(relation_name) = &self.qualifier { format!("{}.{}", relation_name, self.field.name()) @@ -347,10 +459,23 @@ impl DFField { } } + /// Builds a qualified column based on self + pub fn qualified_column(&self) -> Column { + Column { + relation: self.qualifier.clone(), + name: self.field.name().to_string(), + } + } + /// Get the optional qualifier pub fn qualifier(&self) -> Option<&String> { self.qualifier.as_ref() } + + /// Get the arrow field + pub fn field(&self) -> &Field { + &self.field + } } #[cfg(test)] @@ -392,8 +517,8 @@ mod tests { fn from_qualified_schema_into_arrow_schema() -> Result<()> { let schema = DFSchema::try_from_qualified("t1", &test_schema_1())?; let arrow_schema: Schema = schema.into(); - let expected = "Field { name: \"t1.c0\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }, \ - Field { name: \"t1.c1\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }"; + let expected = "Field { name: \"c0\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }, \ + Field { name: \"c1\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }"; assert_eq!(expected, arrow_schema.to_string()); Ok(()) } diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index fa9b9e0a2490..d69a5a123c92 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -27,7 +27,7 @@ use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; use arrow::{compute::can_cast_types, datatypes::DataType}; use crate::error::{DataFusionError, Result}; -use crate::logical_plan::{DFField, DFSchema}; +use crate::logical_plan::{DFField, DFSchema, DFSchemaRef}; use crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, }; @@ -35,6 +35,81 @@ use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; use std::collections::HashSet; +/// A named reference to a qualified filed in a schema. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Column { + /// relation/table name. + pub relation: Option, + /// field/column name. + pub name: String, +} + +impl Column { + /// Create Column from unqualified name. 
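    // A quick sketch of the constructors below (values illustrative):
    //
    //   Column::from_name("id".to_string()) -> relation: None,       name: "id"
    //   Column::from_flat_name("t1.id")     -> relation: Some("t1"), name: "id"
    //   Column::from_flat_name("MIN(c0)")   -> relation: None,       name: "MIN(c0)"
    //
    // Only an exact word-period-word token sequence is treated as
    // qualified; flat_name() round-trips these back to "id" / "t1.id".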
+ pub fn from_name(name: String) -> Self { + Self { + relation: None, + name, + } + } + + pub fn from_flat_name(flat_name: &str) -> Self { + use sqlparser::tokenizer::Token; + + let dialect = sqlparser::dialect::GenericDialect {}; + let mut tokenizer = sqlparser::tokenizer::Tokenizer::new(&dialect, flat_name); + // FIXME: remove unwrap + let tokens = tokenizer.tokenize().unwrap(); + + // any expression that's not in the form of foo.bar will be treated as unqualified + // column name + match tokens.as_slice() { + [Token::Word(relation), Token::Period, Token::Word(name)] => Column { + relation: Some(relation.value.clone()), + name: name.value.clone(), + }, + _ => Column { + relation: None, + name: String::from(flat_name), + }, + } + } + + pub fn flat_name(&self) -> String { + match &self.relation { + Some(r) => format!("{}.{}", r, self.name), + None => self.name.clone(), + } + } + + /// Normalize Column with qualifier based on provided dataframe schemas. + pub fn normalize(self, schemas: &[&DFSchemaRef]) -> Result { + if !self.relation.is_none() { + return Ok(self); + } + + for schema in schemas { + if let Ok(field) = schema.field_with_unqualified_name(&self.name) { + return Ok(field.qualified_column()); + } + } + + return Err(DataFusionError::Plan(format!( + "Column {} not found in provided schemas", + self + ))); + } +} + +impl fmt::Display for Column { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.relation { + Some(r) => write!(f, "#{}.{}", r, self.name), + None => write!(f, "#{}", self.name), + } + } +} + /// `Expr` is a central struct of DataFusion's query API, and /// represent logical expressions such as `A + 1`, or `CAST(c1 AS /// int)`. @@ -49,7 +124,7 @@ use std::collections::HashSet; /// ``` /// # use datafusion::logical_plan::*; /// let expr = col("c1"); -/// assert_eq!(expr, Expr::Column("c1".to_string())); +/// assert_eq!(expr, Expr::Column(Column::from_name("c1".to_string()))); /// ``` /// /// ## Create the expression `c1 + c2` to add columns "c1" and "c2" together @@ -83,8 +158,8 @@ use std::collections::HashSet; pub enum Expr { /// An expression with a specific name. Alias(Box, String), - /// A named reference to a field in a schema. - Column(String), + /// A named reference to a qualified filed in a schema. + Column(Column), /// A named reference to a variable in a registry. ScalarVariable(Vec), /// A constant value. @@ -221,10 +296,9 @@ impl Expr { pub fn get_type(&self, schema: &DFSchema) -> Result { match self { Expr::Alias(expr, _) => expr.get_type(schema), - Expr::Column(name) => Ok(schema - .field_with_unqualified_name(name)? - .data_type() - .clone()), + Expr::Column(c) => { + Ok(schema.field_from_qualified_column(c)?.data_type().clone()) + } Expr::ScalarVariable(_) => Ok(DataType::Utf8), Expr::Literal(l) => Ok(l.get_datatype()), Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema), @@ -289,9 +363,9 @@ impl Expr { pub fn nullable(&self, input_schema: &DFSchema) -> Result { match self { Expr::Alias(expr, _) => expr.nullable(input_schema), - Expr::Column(name) => Ok(input_schema - .field_with_unqualified_name(name)? - .is_nullable()), + Expr::Column(c) => { + Ok(input_schema.field_from_qualified_column(c)?.is_nullable()) + } Expr::Literal(value) => Ok(value.is_null()), Expr::ScalarVariable(_) => Ok(true), Expr::Case { @@ -345,12 +419,20 @@ impl Expr { /// Returns a [arrow::datatypes::Field] compatible with this expression. 
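    // Sketch of the qualifier handling below (names illustrative): a bare
    // column expression keeps its qualifier, so `#t.id` becomes a DFField
    // with qualifier Some("t") and name "id", while any other expression
    // falls through to an unqualified field named after the whole
    // expression (e.g. an aggregate becomes "SUM(t.id)").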
pub fn to_field(&self, input_schema: &DFSchema) -> Result { - Ok(DFField::new( - None, //TODO qualifier - &self.name(input_schema)?, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - )) + match self { + Expr::Column(c) => Ok(DFField::new( + c.relation.as_ref().map(|s| s.as_str()), + &c.name, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + )), + _ => Ok(DFField::new( + None, + &self.name(input_schema)?, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + )), + } } /// Wraps this expression in a cast to a target [arrow::datatypes::DataType]. @@ -521,7 +603,7 @@ impl Expr { // recurse (and cover all expression types) let visitor = match self { Expr::Alias(expr, _) => expr.accept(visitor), - Expr::Column(..) => Ok(visitor), + Expr::Column(_) => Ok(visitor), Expr::ScalarVariable(..) => Ok(visitor), Expr::Literal(..) => Ok(visitor), Expr::BinaryExpr { left, right, .. } => { @@ -632,7 +714,7 @@ impl Expr { // recurse into all sub expressions(and cover all expression types) let expr = match self { Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name), - Expr::Column(name) => Expr::Column(name), + Expr::Column(_) => self.clone(), Expr::ScalarVariable(names) => Expr::ScalarVariable(names), Expr::Literal(value) => Expr::Literal(value), Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr { @@ -936,9 +1018,51 @@ pub fn or(left: Expr, right: Expr) -> Expr { } } -/// Create a column expression based on a column name -pub fn col(name: &str) -> Expr { - Expr::Column(name.to_owned()) +/// Create a column expression based on a qualified or unqualified column name +pub fn col(ident: &str) -> Expr { + Expr::Column(Column::from_flat_name(ident)) +} + +/// Recursively normalize all Column expressions in a given expression tree +pub fn normalize_col(e: Expr, schemas: &[&DFSchemaRef]) -> Result { + struct ColumnNormalizer<'a, 'b> { + schemas: &'a [&'b DFSchemaRef], + } + + impl<'a, 'b> ExprRewriter for ColumnNormalizer<'a, 'b> { + fn mutate(&mut self, expr: Expr) -> Result { + if let Expr::Column(ref c) = expr { + // FIXME: reuse ColumnNormalizer::normalize? 
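                // e.g. with input schemas [t1: (id), t2: (name)] (hypothetical
                // tables), col("name") rewrites to #t2.name via the first
                // schema containing a matching unqualified field, while an
                // already-qualified #t1.id skips the `relation.is_none()`
                // branch below and is returned unchanged.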
+ if c.relation.is_none() { + for schema in self.schemas { + if let Ok(field) = schema.field_with_unqualified_name(&c.name) { + return Ok(Expr::Column(field.qualified_column())); + } + } + return Err(DataFusionError::Plan(format!( + "Column {} not found in provided schemas", + c.name, + ))); + } + } + + Ok(expr) + } + } + + e.rewrite(&mut ColumnNormalizer { schemas }) +} + +/// Recursively normalize all Column expressions in a list of expression trees +#[inline] +pub fn normalize_cols( + exprs: impl IntoIterator, + schemas: &[&DFSchemaRef], +) -> Result> { + exprs + .into_iter() + .map(|e| normalize_col(e.clone(), schemas)) + .collect() } /// Create an expression to represent the min() aggregate function @@ -1189,7 +1313,7 @@ impl fmt::Debug for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias), - Expr::Column(name) => write!(f, "#{}", name), + Expr::Column(c) => write!(f, "{}", c), Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")), Expr::Literal(v) => write!(f, "{:?}", v), Expr::Case { @@ -1304,7 +1428,7 @@ fn create_function_name( fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { match e { Expr::Alias(_, name) => Ok(name.clone()), - Expr::Column(name) => Ok(name.clone()), + Expr::Column(c) => Ok(c.flat_name()), Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")), Expr::Literal(value) => Ok(format!("{:?}", value)), Expr::BinaryExpr { left, op, right } => { @@ -1442,8 +1566,8 @@ mod tests { #[test] fn filter_is_null_and_is_not_null() { - let col_null = Expr::Column("col1".to_string()); - let col_not_null = Expr::Column("col2".to_string()); + let col_null = col("col1"); + let col_not_null = col("col2"); assert_eq!(format!("{:?}", col_null.is_null()), "#col1 IS NULL"); assert_eq!( format!("{:?}", col_not_null.is_not_null()), diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index f9be1ff98300..ba4426d0351c 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -29,7 +29,7 @@ mod extension; mod operators; mod plan; mod registry; -pub use builder::LogicalPlanBuilder; +pub use builder::{union_with_alias, LogicalPlanBuilder}; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ @@ -37,10 +37,11 @@ pub use expr::{ ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, - octet_length, or, regexp_match, regexp_replace, repeat, replace, reverse, right, - round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, - starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when, - Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, + normalize_col, normalize_cols, octet_length, or, regexp_match, regexp_replace, + repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, + signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, + translate, trim, trunc, upper, when, Column, Expr, ExprRewriter, ExpressionVisitor, + Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index d1b9b827a5a3..54c61c216c1b 100644 --- 
--- a/datafusion/src/logical_plan/plan.rs
+++ b/datafusion/src/logical_plan/plan.rs
@@ -18,7 +18,6 @@
 //! via a logical query plan.
 
 use std::{
-    cmp::min,
     fmt::{self, Display},
     sync::Arc,
 };
@@ -28,12 +27,9 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use crate::datasource::TableProvider;
 use crate::sql::parser::FileType;
-use super::expr::Expr;
+use super::display::{GraphvizVisitor, IndentVisitor};
+use super::expr::{Column, Expr};
 use super::extension::UserDefinedLogicalNode;
-use super::{
-    col,
-    display::{GraphvizVisitor, IndentVisitor},
-};
 use crate::logical_plan::dfschema::DFSchemaRef;
 
 /// Join type
@@ -107,7 +103,7 @@ pub enum LogicalPlan {
         /// Right input
         right: Arc<LogicalPlan>,
         /// Equijoin clause expressed as pairs of (left, right) join columns
-        on: Vec<(String, String)>,
+        on: Vec<(Column, Column)>,
         /// Join type
         join_type: JoinType,
         /// The output schema, containing fields from the left and right inputs
@@ -132,7 +128,7 @@ pub enum LogicalPlan {
     /// Produces rows from a table provider by reference or from the context
     TableScan {
         /// The name of the table
-        table_name: String,
+        table_name: Option<String>,
         /// The source of the table
         source: Arc<dyn TableProvider>,
         /// Optional column indices to use as a projection
@@ -280,9 +276,10 @@ impl LogicalPlan {
                 result.extend(aggr_expr.clone());
                 result
             }
-            LogicalPlan::Join { on, .. } => {
-                on.iter().flat_map(|(l, r)| vec![col(l), col(r)]).collect()
-            }
+            LogicalPlan::Join { on, .. } => on
+                .iter()
+                .flat_map(|(l, r)| vec![Expr::Column(l.clone()), Expr::Column(r.clone())])
+                .collect(),
             LogicalPlan::Sort { expr, .. } => expr.clone(),
             LogicalPlan::Extension { node } => node.expressions(),
             // plans without expressions
@@ -440,9 +437,9 @@ impl LogicalPlan {
     /// per node. For example:
     ///
     /// ```text
-    /// Projection: #id
-    /// Filter: #state Eq Utf8(\"CO\")\
-    /// CsvScan: employee.csv projection=Some([0, 3])
+    /// Projection: #employee.id
+    /// Filter: #employee.state Eq Utf8(\"CO\")\
+    /// CsvScan: employee projection=Some([0, 3])
     /// ```
     ///
     /// ```
     /// let schema = Schema::new(vec![
     ///     Field::new("id", DataType::Int32, false),
     /// ]);
-    /// let plan = LogicalPlanBuilder::scan_empty("foo.csv", &schema, None).unwrap()
+    /// let plan = LogicalPlanBuilder::scan_empty(Some("foo_csv"), &schema, None).unwrap()
     ///     .filter(col("id").eq(lit(5))).unwrap()
     ///     .build().unwrap();
     ///
     /// // Format using display_indent
     /// let display_string = format!("{}", plan.display_indent());
     ///
-    /// assert_eq!("Filter: #id Eq Int32(5)\
-    ///             \n  TableScan: foo.csv projection=None",
+    /// assert_eq!("Filter: #foo_csv.id Eq Int32(5)\
+    ///             \n  TableScan: foo_csv projection=None",
     ///             display_string);
     /// ```
     pub fn display_indent(&self) -> impl fmt::Display + '_ {
@@ -481,9 +478,9 @@ impl LogicalPlan {
     /// per node that includes the output schema.
For example: /// /// ```text - /// Projection: #id [id:Int32]\ - /// Filter: #state Eq Utf8(\"CO\") [id:Int32, state:Utf8]\ - /// TableScan: employee.csv projection=Some([0, 3]) [id:Int32, state:Utf8]"; + /// Projection: #employee.id [id:Int32]\ + /// Filter: #employee.state Eq Utf8(\"CO\") [id:Int32, state:Utf8]\ + /// TableScan: employee projection=Some([0, 3]) [id:Int32, state:Utf8]"; /// ``` /// /// ``` @@ -492,15 +489,15 @@ impl LogicalPlan { /// let schema = Schema::new(vec![ /// Field::new("id", DataType::Int32, false), /// ]); - /// let plan = LogicalPlanBuilder::scan_empty("foo.csv", &schema, None).unwrap() + /// let plan = LogicalPlanBuilder::scan_empty(Some("foo_csv"), &schema, None).unwrap() /// .filter(col("id").eq(lit(5))).unwrap() /// .build().unwrap(); /// /// // Format using display_indent_schema /// let display_string = format!("{}", plan.display_indent_schema()); /// - /// assert_eq!("Filter: #id Eq Int32(5) [id:Int32]\ - /// \n TableScan: foo.csv projection=None [id:Int32]", + /// assert_eq!("Filter: #foo_csv.id Eq Int32(5) [id:Int32]\ + /// \n TableScan: foo_csv projection=None [id:Int32]", /// display_string); /// ``` pub fn display_indent_schema(&self) -> impl fmt::Display + '_ { @@ -532,7 +529,7 @@ impl LogicalPlan { /// let schema = Schema::new(vec![ /// Field::new("id", DataType::Int32, false), /// ]); - /// let plan = LogicalPlanBuilder::scan_empty("foo.csv", &schema, None).unwrap() + /// let plan = LogicalPlanBuilder::scan_empty(Some("foo.csv"), &schema, None).unwrap() /// .filter(col("id").eq(lit(5))).unwrap() /// .build().unwrap(); /// @@ -591,7 +588,7 @@ impl LogicalPlan { /// let schema = Schema::new(vec![ /// Field::new("id", DataType::Int32, false), /// ]); - /// let plan = LogicalPlanBuilder::scan_empty("foo.csv", &schema, None).unwrap() + /// let plan = LogicalPlanBuilder::scan_empty(Some("foo.csv"), &schema, None).unwrap() /// .build().unwrap(); /// /// // Format using display @@ -614,11 +611,16 @@ impl LogicalPlan { ref limit, .. 
} => { - let sep = " ".repeat(min(1, table_name.len())); + let sep = match table_name { + Some(_) => " ", + None => "", + }; write!( f, "TableScan: {}{}projection={:?}", - table_name, sep, projection + table_name.as_ref().map(|s| s.as_str()).unwrap_or(""), + sep, + projection )?; if !filters.is_empty() { @@ -779,7 +781,7 @@ mod tests { fn display_plan() -> LogicalPlan { LogicalPlanBuilder::scan_empty( - "employee.csv", + Some("employee_csv"), &employee_schema(), Some(vec![0, 3]), ) @@ -796,9 +798,9 @@ mod tests { fn test_display_indent() { let plan = display_plan(); - let expected = "Projection: #id\ - \n Filter: #state Eq Utf8(\"CO\")\ - \n TableScan: employee.csv projection=Some([0, 3])"; + let expected = "Projection: #employee_csv.id\ + \n Filter: #employee_csv.state Eq Utf8(\"CO\")\ + \n TableScan: employee_csv projection=Some([0, 3])"; assert_eq!(expected, format!("{}", plan.display_indent())); } @@ -807,9 +809,9 @@ mod tests { fn test_display_indent_schema() { let plan = display_plan(); - let expected = "Projection: #id [id:Int32]\ - \n Filter: #state Eq Utf8(\"CO\") [id:Int32, state:Utf8]\ - \n TableScan: employee.csv projection=Some([0, 3]) [id:Int32, state:Utf8]"; + let expected = "Projection: #employee_csv.id [id:Int32]\ + \n Filter: #employee_csv.state Eq Utf8(\"CO\") [id:Int32, state:Utf8]\ + \n TableScan: employee_csv projection=Some([0, 3]) [id:Int32, state:Utf8]"; assert_eq!(expected, format!("{}", plan.display_indent_schema())); } @@ -831,12 +833,12 @@ mod tests { ); assert!( graphviz.contains( - r#"[shape=box label="TableScan: employee.csv projection=Some([0, 3])"]"# + r#"[shape=box label="TableScan: employee_csv projection=Some([0, 3])"]"# ), "\n{}", plan.display_graphviz() ); - assert!(graphviz.contains(r#"[shape=box label="TableScan: employee.csv projection=Some([0, 3])\nSchema: [id:Int32, state:Utf8]"]"#), + assert!(graphviz.contains(r#"[shape=box label="TableScan: employee_csv projection=Some([0, 3])\nSchema: [id:Int32, state:Utf8]"]"#), "\n{}", plan.display_graphviz()); assert!( graphviz.contains(r#"// End DataFusion GraphViz Plan"#), @@ -1081,9 +1083,12 @@ mod tests { } fn test_plan() -> LogicalPlan { - let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("state", DataType::Utf8, false), + ]); - LogicalPlanBuilder::scan_empty("", &schema, Some(vec![0])) + LogicalPlanBuilder::scan_empty(None, &schema, Some(vec![0, 1])) .unwrap() .filter(col("state").eq(lit("CO"))) .unwrap() diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index 2fa03eb5c709..a40c3e97a16d 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -224,7 +224,7 @@ mod tests { Field::new("c", DataType::Boolean, false), Field::new("d", DataType::UInt32, false), ]); - LogicalPlanBuilder::scan_empty("test", &schema, None)?.build() + LogicalPlanBuilder::scan_empty(Some("test"), &schema, None)?.build() } fn expr_test_schema() -> DFSchemaRef { @@ -473,9 +473,9 @@ mod tests { .build()?; let expected = "\ - Projection: #a\ - \n Filter: NOT #c\ - \n Filter: #b\ + Projection: #test.a\ + \n Filter: NOT #test.c\ + \n Filter: #test.b\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -493,10 +493,10 @@ mod tests { .build()?; let expected = "\ - Projection: #a\ + Projection: #test.a\ \n Limit: 1\ - \n Filter: #c\ - \n Filter: NOT #b\ + \n Filter: #test.c\ + \n Filter: 
NOT #test.b\
         \n    TableScan: test projection=None";
 
         assert_optimized_plan_eq(&plan, expected);
             .build()?;
 
         let expected = "\
-        Projection: #a\
-        \n  Filter: NOT #b And #c\
+        Projection: #test.a\
+        \n  Filter: NOT #test.b And #test.c\
         \n    TableScan: test projection=None";
 
         assert_optimized_plan_eq(&plan, expected);
             .build()?;
 
         let expected = "\
-        Projection: #a\
-        \n  Filter: NOT #b Or NOT #c\
+        Projection: #test.a\
+        \n  Filter: NOT #test.b Or NOT #test.c\
         \n    TableScan: test projection=None";
 
         assert_optimized_plan_eq(&plan, expected);
             .build()?;
 
         let expected = "\
-        Projection: #a\
-        \n  Filter: #b\
+        Projection: #test.a\
+        \n  Filter: #test.b\
         \n    TableScan: test projection=None";
 
         assert_optimized_plan_eq(&plan, expected);
             .build()?;
 
         let expected = "\
-        Projection: #a, #d, NOT #b\
+        Projection: #test.a, #test.d, NOT #test.b\
         \n  TableScan: test projection=None";
 
         assert_optimized_plan_eq(&plan, expected);
             .build()?;
 
         let expected = "\
-        Aggregate: groupBy=[[#a, #c]], aggr=[[MAX(#b), MIN(#b)]]\
-        \n  Projection: #a, #c, #b\
+        Aggregate: groupBy=[[#test.a, #test.c]], aggr=[[MAX(#test.b), MIN(#test.b)]]\
+        \n  Projection: #test.a, #test.c, #test.b\
         \n    TableScan: test projection=None";
 
         assert_optimized_plan_eq(&plan, expected);
diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs
index ec260a41dc57..56b7dab0ccd6 100644
--- a/datafusion/src/optimizer/filter_push_down.rs
+++ b/datafusion/src/optimizer/filter_push_down.rs
@@ -15,7 +15,7 @@
 //! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan
 
 use crate::datasource::datasource::TableProviderFilterPushDown;
-use crate::logical_plan::{and, LogicalPlan};
+use crate::logical_plan::{and, Column, LogicalPlan};
 use crate::logical_plan::{DFSchema, Expr};
 use crate::optimizer::optimizer::OptimizerRule;
 use crate::optimizer::utils;
@@ -55,15 +55,15 @@ pub struct FilterPushDown {}
 #[derive(Debug, Clone, Default)]
 struct State {
     // (predicate, columns on the predicate)
-    filters: Vec<(Expr, HashSet<String>)>,
+    filters: Vec<(Expr, HashSet<Column>)>,
 }
 
-type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet<String>>);
+type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet<Column>>);
 
 /// returns all predicates in `state` that depend on any of `used_columns`
 fn get_predicates<'a>(
     state: &'a State,
-    used_columns: &HashSet<String>,
+    used_columns: &HashSet<Column>,
 ) -> Predicates<'a> {
     state
         .filters
@@ -88,19 +88,19 @@ fn get_join_predicates<'a>(
     left: &DFSchema,
     right: &DFSchema,
 ) -> (
-    Vec<&'a HashSet<String>>,
-    Vec<&'a HashSet<String>>,
+    Vec<&'a HashSet<Column>>,
+    Vec<&'a HashSet<Column>>,
     Predicates<'a>,
 ) {
     let left_columns = &left
         .fields()
         .iter()
-        .map(|f| f.name().clone())
+        .map(|f| f.qualified_column())
         .collect::<HashSet<_>>();
     let right_columns = &right
         .fields()
         .iter()
-        .map(|f| f.name().clone())
+        .map(|f| f.qualified_column())
         .collect::<HashSet<_>>();
 
     let filters = state
@@ -172,9 +172,9 @@ fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan {
 
 // remove all filters from `filters` that are in `predicate_columns`
 fn remove_filters(
-    filters: &[(Expr, HashSet<String>)],
-    predicate_columns: &[&HashSet<String>],
-) -> Vec<(Expr, HashSet<String>)> {
+    filters: &[(Expr, HashSet<Column>)],
+    predicate_columns: &[&HashSet<Column>],
+) -> Vec<(Expr, HashSet<Column>)> {
     filters
         .iter()
         .filter(|(_, columns)| !predicate_columns.contains(&columns))
         .cloned()
         .collect()
 }
 
 // keeps all filters from `filters` that are in `predicate_columns`
 fn keep_filters(
-    filters: &[(Expr, HashSet<String>)],
-    predicate_columns: &[&HashSet<String>],
-) -> Vec<(Expr, HashSet<String>)> {
+    filters: &[(Expr, HashSet<Column>)],
+    predicate_columns: &[&HashSet<Column>],
+) -> Vec<(Expr, HashSet<Column>)> {
     filters
         .iter()
         .filter(|(_, columns)| predicate_columns.contains(&columns))
@@ -198,7 +198,7 @@ fn keep_filters(
 /// in `state` depend on the columns `used_columns`.
 fn issue_filters(
     mut state: State,
-    used_columns: HashSet<String>,
+    used_columns: HashSet<Column>,
     plan: &LogicalPlan,
 ) -> Result<LogicalPlan> {
     let (predicates, predicate_columns) = get_predicates(&state, &used_columns);
@@ -240,7 +240,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
             predicates
                 .into_iter()
                 .try_for_each::<_, Result<()>>(|predicate| {
-                    let mut columns: HashSet<String> = HashSet::new();
+                    let mut columns: HashSet<Column> = HashSet::new();
                     utils::expr_to_column_names(predicate, &mut columns)?;
                     // collect the predicate
                     state.filters.push((predicate.clone(), columns));
@@ -264,7 +264,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
                     expr => expr.clone(),
                 };
 
-                projection.insert(field.name().clone(), expr);
+                projection.insert(field.qualified_name().clone(), expr);
             });
 
             // re-write all filters based on this projection
@@ -294,7 +294,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
 
             let agg_columns = aggr_expr
                 .iter()
-                .map(|x| x.name(input.schema()))
+                .map(|x| Ok(Column::from_name(x.name(input.schema())?)))
                 .collect::<Result<HashSet<_>>>()?;
             used_columns.extend(agg_columns);
@@ -310,7 +310,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
                 .schema()
                 .fields()
                 .iter()
-                .map(|f| f.name().clone())
+                .map(|f| f.qualified_column())
                 .collect::<HashSet<_>>();
             issue_filters(state, used_columns, plan)
         }
@@ -387,7 +387,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
                 .schema()
                 .fields()
                 .iter()
-                .map(|f| f.name().clone())
+                .map(|f| f.qualified_column())
                 .collect::<HashSet<_>>();
             issue_filters(state, used_columns, plan)
         }
@@ -420,8 +420,8 @@ fn rewrite(expr: &Expr, projection: &HashMap<String, Expr>) -> Result<Expr> {
         .map(|e| rewrite(e, &projection))
         .collect::<Result<Vec<_>>>()?;
 
-    if let Expr::Column(name) = expr {
-        if let Some(expr) = projection.get(name) {
+    if let Expr::Column(c) = expr {
+        if let Some(expr) = projection.get(&c.flat_name()) {
             return Ok(expr.clone());
         }
     }
 mod tests {
             .build()?;
         // filter is before projection
         let expected = "\
-            Projection: #a, #b\
-            \n  Filter: #a Eq Int64(1)\
+            Projection: #test.a, #test.b\
+            \n  Filter: #test.a Eq Int64(1)\
             \n    TableScan: test projection=None";
         assert_optimized_plan_eq(&plan, expected);
         Ok(())
             .build()?;
         // filter is before single projection
         let expected = "\
-            Filter: #a Eq Int64(1)\
+            Filter: #test.a Eq Int64(1)\
             \n  Limit: 10\
-            \n    Projection: #a, #b\
+            \n    Projection: #test.a, #test.b\
             \n      TableScan: test projection=None";
         assert_optimized_plan_eq(&plan, expected);
         Ok(())
             .build()?;
         // filter is before double projection
         let expected = "\
-            Projection: #c, #b\
-            \n  Projection: #a, #b, #c\
-            \n    Filter: #a Eq Int64(1)\
+            Projection: #test.c, #test.b\
+            \n  Projection: #test.a, #test.b, #test.c\
+            \n    Filter: #test.a Eq Int64(1)\
             \n      TableScan: test projection=None";
         assert_optimized_plan_eq(&plan, expected);
         Ok(())
             .build()?;
         // filter of key aggregation is commutative
         let expected = "\
-            Aggregate: groupBy=[[#a]], aggr=[[SUM(#b) AS total_salary]]\
-            \n  Filter: #a Gt Int64(10)\
+            Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b) AS total_salary]]\
+            \n  Filter:
#test.a Gt Int64(10)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -525,7 +525,7 @@ mod tests { // filter of aggregate is after aggregation since they are non-commutative let expected = "\ Filter: #b Gt Int64(10)\ - \n Aggregate: groupBy=[[#a]], aggr=[[SUM(#b) AS b]]\ + \n Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b) AS b]]\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -541,8 +541,8 @@ mod tests { .build()?; // filter is before projection let expected = "\ - Projection: #a AS b, #c\ - \n Filter: #a Eq Int64(1)\ + Projection: #test.a AS b, #test.c\ + \n Filter: #test.a Eq Int64(1)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -581,14 +581,14 @@ mod tests { format!("{:?}", plan), "\ Filter: #b Eq Int64(1)\ - \n Projection: #a Multiply Int32(2) Plus #c AS b, #c\ + \n Projection: #test.a Multiply Int32(2) Plus #test.c AS b, #test.c\ \n TableScan: test projection=None" ); // filter is before projection let expected = "\ - Projection: #a Multiply Int32(2) Plus #c AS b, #c\ - \n Filter: #a Multiply Int32(2) Plus #c Eq Int64(1)\ + Projection: #test.a Multiply Int32(2) Plus #test.c AS b, #test.c\ + \n Filter: #test.a Multiply Int32(2) Plus #test.c Eq Int64(1)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -613,16 +613,16 @@ mod tests { format!("{:?}", plan), "\ Filter: #a Eq Int64(1)\ - \n Projection: #b Multiply Int32(3) AS a, #c\ - \n Projection: #a Multiply Int32(2) Plus #c AS b, #c\ + \n Projection: #b Multiply Int32(3) AS a, #test.c\ + \n Projection: #test.a Multiply Int32(2) Plus #test.c AS b, #test.c\ \n TableScan: test projection=None" ); // filter is before the projections let expected = "\ - Projection: #b Multiply Int32(3) AS a, #c\ - \n Projection: #a Multiply Int32(2) Plus #c AS b, #c\ - \n Filter: #a Multiply Int32(2) Plus #c Multiply Int32(3) Eq Int64(1)\ + Projection: #b Multiply Int32(3) AS a, #test.c\ + \n Projection: #test.a Multiply Int32(2) Plus #test.c AS b, #test.c\ + \n Filter: #test.a Multiply Int32(2) Plus #test.c Multiply Int32(3) Eq Int64(1)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -638,26 +638,26 @@ mod tests { .project(vec![col("a").alias("b"), col("c")])? .aggregate(vec![col("b")], vec![sum(col("c"))])? .filter(col("b").gt(lit(10i64)))? - .filter(col("SUM(c)").gt(lit(10i64)))? + .filter(col("SUM(test.c)").gt(lit(10i64)))? .build()?; // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), "\ - Filter: #SUM(c) Gt Int64(10)\ + Filter: #SUM(test.c) Gt Int64(10)\ \n Filter: #b Gt Int64(10)\ - \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\ - \n Projection: #a AS b, #c\ + \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\ + \n Projection: #test.a AS b, #test.c\ \n TableScan: test projection=None" ); // filter is before the projections let expected = "\ - Filter: #SUM(c) Gt Int64(10)\ - \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\ - \n Projection: #a AS b, #c\ - \n Filter: #a Gt Int64(10)\ + Filter: #SUM(test.c) Gt Int64(10)\ + \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\ + \n Projection: #test.a AS b, #test.c\ + \n Filter: #test.a Gt Int64(10)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -674,8 +674,8 @@ mod tests { .project(vec![col("a").alias("b"), col("c")])? .aggregate(vec![col("b")], vec![sum(col("c"))])? 
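+            // only the conjuncts that reference pre-aggregation columns can be
+            // pushed below the aggregate; those on SUM(test.c) must stay above it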
            .filter(and(
-                col("SUM(c)").gt(lit(10i64)),
-                and(col("b").gt(lit(10i64)), col("SUM(c)").lt(lit(20i64))),
+                col("SUM(test.c)").gt(lit(10i64)),
+                and(col("b").gt(lit(10i64)), col("SUM(test.c)").lt(lit(20i64))),
             ))?
             .build()?;
 
         // not part of the test, just good to know:
         assert_eq!(
             format!("{:?}", plan),
             "\
-            Filter: #SUM(c) Gt Int64(10) And #b Gt Int64(10) And #SUM(c) Lt Int64(20)\
-            \n  Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\
-            \n    Projection: #a AS b, #c\
+            Filter: #SUM(test.c) Gt Int64(10) And #b Gt Int64(10) And #SUM(test.c) Lt Int64(20)\
+            \n  Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\
+            \n    Projection: #test.a AS b, #test.c\
             \n      TableScan: test projection=None"
         );
 
         // filter is before the projections
         let expected = "\
-        Filter: #SUM(c) Gt Int64(10) And #SUM(c) Lt Int64(20)\
-        \n  Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\
-        \n    Projection: #a AS b, #c\
-        \n      Filter: #a Gt Int64(10)\
+        Filter: #SUM(test.c) Gt Int64(10) And #SUM(test.c) Lt Int64(20)\
+        \n  Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\
+        \n    Projection: #test.a AS b, #test.c\
+        \n      Filter: #test.a Gt Int64(10)\
         \n        TableScan: test projection=None";
         assert_optimized_plan_eq(&plan, expected);
             .build()?;
         // the filter is not pushed below any of the limits
         let expected = "\
-            Projection: #a, #b\
-            \n  Filter: #a Eq Int64(1)\
+            Projection: #test.a, #test.b\
+            \n  Filter: #test.a Eq Int64(1)\
             \n    Limit: 10\
             \n      Limit: 20\
-            \n        Projection: #a, #b\
+            \n        Projection: #test.a, #test.b\
             \n          TableScan: test projection=None";
         assert_optimized_plan_eq(&plan, expected);
         Ok(())
 
         // not part of the test
         assert_eq!(
             format!("{:?}", plan),
-            "Filter: #a GtEq Int64(1)\
-            \n  Projection: #a\
+            "Filter: #test.a GtEq Int64(1)\
+            \n  Projection: #test.a\
             \n    Limit: 1\
-            \n      Filter: #a LtEq Int64(1)\
-            \n        Projection: #a\
+            \n      Filter: #test.a LtEq Int64(1)\
+            \n        Projection: #test.a\
             \n          TableScan: test projection=None"
         );
 
         let expected = "\
-        Projection: #a\
-        \n  Filter: #a GtEq Int64(1)\
+        Projection: #test.a\
+        \n  Filter: #test.a GtEq Int64(1)\
         \n    Limit: 1\
-        \n      Projection: #a\
-        \n        Filter: #a LtEq Int64(1)\
+        \n      Projection: #test.a\
+        \n        Filter: #test.a LtEq Int64(1)\
         \n          TableScan: test projection=None";
 
         assert_optimized_plan_eq(&plan, expected);
 
         // not part of the test
         assert_eq!(
             format!("{:?}", plan),
-            "Projection: #a\
-            \n  Filter: #a GtEq Int64(1)\
-            \n    Filter: #a LtEq Int64(1)\
+            "Projection: #test.a\
+            \n  Filter: #test.a GtEq Int64(1)\
+            \n    Filter: #test.a LtEq Int64(1)\
             \n      Limit: 1\
             \n        TableScan: test projection=None"
         );
 
         let expected = "\
-        Projection: #a\
-        \n  Filter: #a GtEq Int64(1) And #a LtEq Int64(1)\
+        Projection: #test.a\
+        \n  Filter: #test.a GtEq Int64(1) And #test.a LtEq Int64(1)\
         \n    Limit: 1\
         \n      TableScan: test projection=None";
 
         assert_optimized_plan_eq(&plan, expected);
 
         let expected = "\
         TestUserDefined\
-        \n  Filter: #a LtEq Int64(1)\
+        \n  Filter: #test.a LtEq Int64(1)\
         \n    TableScan: test projection=None";
 
         // not part of the test
             .project(vec![col("a")])?
             .build()?;
         let plan = LogicalPlanBuilder::from(&left)
-            .join(&right, JoinType::Inner, &["a"], &["a"])?
+            .join(
+                &right,
+                JoinType::Inner,
+                vec![Column::from_name("a".to_string())],
+                vec![Column::from_name("a".to_string())],
+            )?
             .filter(col("a").lt_eq(lit(1i64)))?
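+            // the predicate references the shared join key "a", so (as the expected
+            // plan below shows) it can be duplicated to both sides of the join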
.build()?; @@ -831,20 +836,20 @@ mod tests { assert_eq!( format!("{:?}", plan), "\ - Filter: #a LtEq Int64(1)\ - \n Join: a = a\ + Filter: #test.a LtEq Int64(1)\ + \n Join: #test.a = #test.a\ \n TableScan: test projection=None\ - \n Projection: #a\ + \n Projection: #test.a\ \n TableScan: test projection=None" ); // filter sent to side before the join let expected = "\ - Join: a = a\ - \n Filter: #a LtEq Int64(1)\ + Join: #test.a = #test.a\ + \n Filter: #test.a LtEq Int64(1)\ \n TableScan: test projection=None\ - \n Projection: #a\ - \n Filter: #a LtEq Int64(1)\ + \n Projection: #test.a\ + \n Filter: #test.a LtEq Int64(1)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -861,7 +866,12 @@ mod tests { .project(vec![col("a"), col("b")])? .build()?; let plan = LogicalPlanBuilder::from(&left) - .join(&right, JoinType::Inner, &["a"], &["a"])? + .join( + &right, + JoinType::Inner, + vec![Column::from_name("a".to_string())], + vec![Column::from_name("a".to_string())], + )? // "b" and "c" are not shared by either side: they are only available together after the join .filter(col("c").lt_eq(col("b")))? .build()?; @@ -870,11 +880,11 @@ mod tests { assert_eq!( format!("{:?}", plan), "\ - Filter: #c LtEq #b\ - \n Join: a = a\ - \n Projection: #a, #c\ + Filter: #test.c LtEq #test.b\ + \n Join: #test.a = #test.a\ + \n Projection: #test.a, #test.c\ \n TableScan: test projection=None\ - \n Projection: #a, #b\ + \n Projection: #test.a, #test.b\ \n TableScan: test projection=None" ); @@ -895,7 +905,12 @@ mod tests { .project(vec![col("a"), col("c")])? .build()?; let plan = LogicalPlanBuilder::from(&left) - .join(&right, JoinType::Inner, &["a"], &["a"])? + .join( + &right, + JoinType::Inner, + vec![Column::from_name("a".to_string())], + vec![Column::from_name("a".to_string())], + )? .filter(col("b").lt_eq(lit(1i64)))? 
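+            // "b" exists only on the left input, so the filter is pushed to the
+            // left side of the join and not duplicated to the right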
            .build()?;
 
         // not part of the test
         assert_eq!(
             format!("{:?}", plan),
             "\
-            Filter: #b LtEq Int64(1)\
-            \n  Join: a = a\
-            \n    Projection: #a, #b\
+            Filter: #test.b LtEq Int64(1)\
+            \n  Join: #test.a = #test.a\
+            \n    Projection: #test.a, #test.b\
             \n      TableScan: test projection=None\
-            \n    Projection: #a, #c\
+            \n    Projection: #test.a, #test.c\
             \n      TableScan: test projection=None"
         );
 
         let expected = "\
-        Join: a = a\
-        \n  Projection: #a, #b\
-        \n    Filter: #b LtEq Int64(1)\
+        Join: #test.a = #test.a\
+        \n  Projection: #test.a, #test.b\
+        \n    Filter: #test.b LtEq Int64(1)\
         \n      TableScan: test projection=None\
-        \n  Projection: #a, #c\
+        \n  Projection: #test.a, #test.c\
         \n    TableScan: test projection=None";
         assert_optimized_plan_eq(&plan, expected);
         Ok(())
 
     fn table_scan_with_pushdown_provider(
         filter_support: TableProviderFilterPushDown,
     ) -> Result<LogicalPlan> {
+        use std::convert::TryFrom;
+
         let test_provider = PushDownProvider { filter_support };
 
         let table_scan = LogicalPlan::TableScan {
-            table_name: "".into(),
+            table_name: None,
             filters: vec![],
-            projected_schema: Arc::new(DFSchema::try_from_qualified(
-                "",
-                &*test_provider.schema(),
+            projected_schema: Arc::new(DFSchema::try_from(
+                (*test_provider.schema()).clone(),
             )?),
             projection: None,
             source: Arc::new(test_provider),
diff --git a/datafusion/src/optimizer/hash_build_probe_order.rs b/datafusion/src/optimizer/hash_build_probe_order.rs
index f44050f0b72e..1bf25c6554b4 100644
--- a/datafusion/src/optimizer/hash_build_probe_order.rs
+++ b/datafusion/src/optimizer/hash_build_probe_order.rs
@@ -120,10 +120,7 @@ impl OptimizerRule for HashBuildProbeOrder {
             Ok(LogicalPlan::Join {
                 left: Arc::new(right),
                 right: Arc::new(left),
-                on: on
-                    .iter()
-                    .map(|(l, r)| (r.to_string(), l.to_string()))
-                    .collect(),
+                on: on.iter().map(|(l, r)| (r.clone(), l.clone())).collect(),
                 join_type: swap_join_type(*join_type),
                 schema: schema.clone(),
             })
@@ -234,7 +231,7 @@ mod tests {
     #[test]
     fn test_swap_order() {
         let lp_left = LogicalPlan::TableScan {
-            table_name: "left".to_string(),
+            table_name: Some("left".to_string()),
             projection: None,
             source: Arc::new(TestTableProvider { num_rows: 1000 }),
             projected_schema: Arc::new(DFSchema::empty()),
         };
 
         let lp_right = LogicalPlan::TableScan {
-            table_name: "right".to_string(),
+            table_name: Some("right".to_string()),
             projection: None,
             source: Arc::new(TestTableProvider { num_rows: 100 }),
             projected_schema: Arc::new(DFSchema::empty()),
diff --git a/datafusion/src/optimizer/limit_push_down.rs b/datafusion/src/optimizer/limit_push_down.rs
index 73a231f2248f..da58336a2161 100644
--- a/datafusion/src/optimizer/limit_push_down.rs
+++ b/datafusion/src/optimizer/limit_push_down.rs
@@ -160,7 +160,7 @@ mod test {
         // Should push the limit down to table provider
         // When it has a select
         let expected = "Limit: 1000\
-        \n  Projection: #a\
+        \n  Projection: #test.a\
         \n    TableScan: test projection=None, limit=1000";
 
         assert_optimized_plan_eq(&plan, expected);
@@ -199,7 +199,7 @@ mod test {
         // Limit should *not* push down aggregate node
         let expected = "Limit: 1000\
-        \n  Aggregate: groupBy=[[#a]], aggr=[[MAX(#b)]]\
+        \n  Aggregate: groupBy=[[#test.a]], aggr=[[MAX(#test.b)]]\
         \n    TableScan: test projection=None";
 
         assert_optimized_plan_eq(&plan, expected);
@@ -241,7 +241,7 @@ mod test {
         // Limit should use deeper LIMIT 1000, but Limit 10 shouldn't push down aggregation
         let expected = "Limit: 10\
-        \n  Aggregate: groupBy=[[#a]], aggr=[[MAX(#b)]]\
+        \n  Aggregate: groupBy=[[#test.a]], aggr=[[MAX(#test.b)]]\
        \n    Limit: 1000\
         \n      TableScan: test projection=None, limit=1000";
 
         assert_optimized_plan_eq(&plan, expected);
diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs
index 6b1cdfe18ca7..100e4ffcbbd5 100644
--- a/datafusion/src/optimizer/projection_push_down.rs
+++ b/datafusion/src/optimizer/projection_push_down.rs
@@ -19,10 +19,12 @@
 //! loaded into memory
 
 use crate::error::Result;
-use crate::logical_plan::{DFField, DFSchema, DFSchemaRef, LogicalPlan, ToDFSchema};
+use crate::logical_plan::{
+    Column, DFField, DFSchema, DFSchemaRef, LogicalPlan, ToDFSchema,
+};
 use crate::optimizer::optimizer::OptimizerRule;
 use crate::optimizer::utils;
-use arrow::datatypes::Schema;
+use arrow::datatypes::{Field, Schema};
 use arrow::error::Result as ArrowResult;
 use std::{collections::HashSet, sync::Arc};
 use utils::optimize_explain;
@@ -38,8 +40,8 @@ impl OptimizerRule for ProjectionPushDown {
             .schema()
             .fields()
             .iter()
-            .map(|f| f.name().clone())
-            .collect::<HashSet<String>>();
+            .map(|f| f.qualified_column())
+            .collect::<HashSet<Column>>();
         optimize_plan(self, plan, &required_columns, false)
     }
@@ -56,8 +58,9 @@ impl ProjectionPushDown {
 }
 
 fn get_projected_schema(
+    table_name: Option<&String>,
     schema: &Schema,
-    required_columns: &HashSet<String>,
+    required_columns: &HashSet<Column>,
     has_projection: bool,
 ) -> Result<(Vec<usize>, DFSchemaRef)> {
     // once we reach the table scan, we can use the accumulated set of column
     // e.g. when the column derives from an aggregation
     let mut projection: Vec<usize> = required_columns
         .iter()
-        .map(|name| schema.index_of(name))
+        .filter(|c| c.relation.as_ref() == table_name)
+        .map(|c| schema.index_of(&c.name))
         .filter_map(ArrowResult::ok)
         .collect();
 
     // create the projected schema
     let mut projected_fields: Vec<DFField> = Vec::with_capacity(projection.len());
-    for i in &projection {
-        projected_fields.push(DFField::from(schema.fields()[*i].clone()));
+    match table_name {
+        Some(qualifier) => {
+            for i in &projection {
+                projected_fields.push(DFField::from_qualified(
+                    qualifier,
+                    schema.fields()[*i].clone(),
+                ));
+            }
+        }
+        None => {
+            for i in &projection {
+                projected_fields.push(DFField::from(schema.fields()[*i].clone()));
+            }
+        }
     }
 
     Ok((projection, projected_fields.to_dfschema_ref()?))
 fn optimize_plan(
     optimizer: &ProjectionPushDown,
     plan: &LogicalPlan,
-    required_columns: &HashSet<String>, // set of columns required up to this step
+    required_columns: &HashSet<Column>, // set of columns required up to this step
     has_projection: bool,
 ) -> Result<LogicalPlan> {
     let mut new_required_columns = required_columns.clone();
                 .iter()
                 .enumerate()
                 .try_for_each(|(i, field)| {
-                    if required_columns.contains(field.name()) {
+                    if required_columns.contains(&field.qualified_column()) {
                         new_expr.push(expr[i].clone());
                         new_fields.push(field.clone());
             schema,
         } => {
             for (l, r) in on {
-                new_required_columns.insert(l.to_owned());
-                new_required_columns.insert(r.to_owned());
+                new_required_columns.insert(l.clone());
+                new_required_columns.insert(r.clone());
             }
             Ok(LogicalPlan::Join {
                 left: Arc::new(optimize_plan(
             let mut new_aggr_expr = Vec::new();
             aggr_expr.iter().try_for_each(|expr| {
                 let name = &expr.name(&schema)?;
+                let column = Column::from_name(name.to_string());
 
-                if required_columns.contains(name) {
+                if required_columns.contains(&column) {
                     new_aggr_expr.push(expr.clone());
-                    new_required_columns.insert(name.clone());
+                    new_required_columns.insert(column);
 
                     // add to the new set of required columns
                     utils::expr_to_column_names(expr, &mut new_required_columns)
                 schema
                     .fields()
                     .iter()
-                    .filter(|x| new_required_columns.contains(x.name()))
+                    .filter(|x| new_required_columns.contains(&x.qualified_column()))
                     .cloned()
                     .collect(),
             )?;
             limit,
             ..
         } => {
-            let (projection, projected_schema) =
-                get_projected_schema(&source.schema(), required_columns, has_projection)?;
-
+            let (projection, projected_schema) = get_projected_schema(
+                table_name.as_ref(),
+                &source.schema(),
+                required_columns,
+                has_projection,
+            )?;
             // return the table scan with projection
             Ok(LogicalPlan::TableScan {
-                table_name: table_name.to_string(),
+                table_name: table_name.clone(),
                 source: source.clone(),
                 projection: Some(projection),
                 projected_schema,
             let schema = schema.as_ref().to_owned().into();
             optimize_explain(optimizer, *verbose, &*plan, stringified_plans, &schema)
         }
+        LogicalPlan::Union {
+            inputs,
+            schema,
+            alias,
+        } => {
+            // UNION inputs will reference the same column with different identifiers, so we need
+            // to populate new_required_columns by unqualified column name based on required fields
+            // from the resulting UNION output
+            let union_required_fields = schema
+                .fields()
+                .iter()
+                .filter(|f| new_required_columns.contains(&f.qualified_column()))
+                .map(|f| f.field())
+                .collect::<HashSet<&Field>>();
+
+            let new_inputs = inputs
+                .iter()
+                .map(|input_plan| {
+                    input_plan
+                        .schema()
+                        .fields()
+                        .iter()
+                        .filter(|f| union_required_fields.contains(f.field()))
+                        .for_each(|f| {
+                            new_required_columns.insert(f.qualified_column());
+                        });
+                    optimize_plan(
+                        optimizer,
+                        input_plan,
+                        &new_required_columns,
+                        has_projection,
+                    )
+                })
+                .collect::<Result<Vec<_>>>()?;
+
+            Ok(LogicalPlan::Union {
+                inputs: new_inputs,
+                schema: schema.clone(),
+                alias: alias.clone(),
+            })
+        }
         // all other nodes: Add any additional columns used by
         // expressions in this node to the list of required columns
         LogicalPlan::Limit { .. }
         | LogicalPlan::EmptyRelation { .. }
         | LogicalPlan::Sort { .. }
         | LogicalPlan::CreateExternalTable { .. }
-        | LogicalPlan::Union { .. }
         | LogicalPlan::Extension { .. } => {
             let expr = plan.expressions();
             // collect all required columns by this plan
             let inputs = plan.inputs();
             let new_inputs = inputs
                 .iter()
-                .map(|plan| {
-                    optimize_plan(optimizer, plan, &new_required_columns, has_projection)
+                .map(|input_plan| {
+                    optimize_plan(
+                        optimizer,
+                        input_plan,
+                        &new_required_columns,
+                        has_projection,
+                    )
                 })
                 .collect::<Result<Vec<_>>>()?;
@@ -306,7 +371,7 @@ mod tests {
             .aggregate(vec![], vec![max(col("b"))])?
             .build()?;
 
-        let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\
+        let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#test.b)]]\
         \n  TableScan: test projection=Some([1])";
 
         assert_optimized_plan_eq(&plan, expected);
@@ -322,7 +387,7 @@ mod tests {
             .aggregate(vec![col("c")], vec![max(col("b"))])?
             .build()?;
 
-        let expected = "Aggregate: groupBy=[[#c]], aggr=[[MAX(#b)]]\
+        let expected = "Aggregate: groupBy=[[#test.c]], aggr=[[MAX(#test.b)]]\
         \n  TableScan: test projection=Some([1, 2])";
 
         assert_optimized_plan_eq(&plan, expected);
@@ -339,8 +404,8 @@ mod tests {
             .aggregate(vec![], vec![max(col("b"))])?
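+            // the filter needs "c" and the aggregate needs "b", so the scan should
+            // be pruned to projection=Some([1, 2]) even without an explicit select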
.build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\ - \n Filter: #c\ + let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#test.b)]]\ + \n Filter: #test.c\ \n TableScan: test projection=Some([1, 2])"; assert_optimized_plan_eq(&plan, expected); @@ -359,7 +424,7 @@ mod tests { }])? .build()?; - let expected = "Projection: CAST(#c AS Float64)\ + let expected = "Projection: CAST(#test.c AS Float64)\ \n TableScan: test projection=Some([2])"; assert_optimized_plan_eq(&projection, expected); @@ -379,7 +444,7 @@ mod tests { assert_fields_eq(&plan, vec!["a", "b"]); - let expected = "Projection: #a, #b\ + let expected = "Projection: #test.a, #test.b\ \n TableScan: test projection=Some([0, 1])"; assert_optimized_plan_eq(&plan, expected); @@ -401,7 +466,7 @@ mod tests { assert_fields_eq(&plan, vec!["c", "a"]); let expected = "Limit: 5\ - \n Projection: #c, #a\ + \n Projection: #test.c, #test.a\ \n TableScan: test projection=Some([0, 2])"; assert_optimized_plan_eq(&plan, expected); @@ -445,12 +510,12 @@ mod tests { .aggregate(vec![col("c")], vec![max(col("a"))])? .build()?; - assert_fields_eq(&plan, vec!["c", "MAX(a)"]); + assert_fields_eq(&plan, vec!["c", "MAX(test.a)"]); let expected = "\ - Aggregate: groupBy=[[#c]], aggr=[[MAX(#a)]]\ - \n Filter: #c Gt Int32(1)\ - \n Projection: #c, #a\ + Aggregate: groupBy=[[#test.c]], aggr=[[MAX(#test.a)]]\ + \n Filter: #test.c Gt Int32(1)\ + \n Projection: #test.c, #test.a\ \n TableScan: test projection=Some([0, 2])"; assert_optimized_plan_eq(&plan, expected); @@ -513,15 +578,15 @@ mod tests { let plan = LogicalPlanBuilder::from(&table_scan) .aggregate(vec![col("a"), col("c")], vec![max(col("b")), min(col("b"))])? .filter(col("c").gt(lit(1)))? - .project(vec![col("c"), col("a"), col("MAX(b)")])? + .project(vec![col("c"), col("a"), col("MAX(test.b)")])? 
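+            // note: after qualification, aggregate outputs are addressed by the
+            // flat name of their expression, e.g. col("MAX(test.b)") above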
            .build()?;
 
         assert_fields_eq(&plan, vec!["c", "a", "MAX(test.b)"]);
 
         let expected = "\
-        Projection: #c, #a, #MAX(b)\
-        \n  Filter: #c Gt Int32(1)\
-        \n    Aggregate: groupBy=[[#a, #c]], aggr=[[MAX(#b)]]\
+        Projection: #test.c, #test.a, #MAX(test.b)\
+        \n  Filter: #test.c Gt Int32(1)\
+        \n    Aggregate: groupBy=[[#test.a, #test.c]], aggr=[[MAX(#test.b)]]\
         \n      TableScan: test projection=Some([0, 1, 2])";
 
         assert_optimized_plan_eq(&plan, expected);
diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs
index fe1d02381917..7a3b68206224 100644
--- a/datafusion/src/optimizer/utils.rs
+++ b/datafusion/src/optimizer/utils.rs
@@ -23,8 +23,8 @@ use arrow::datatypes::Schema;
 
 use super::optimizer::OptimizerRule;
 use crate::logical_plan::{
-    Expr, LogicalPlan, Operator, Partitioning, PlanType, Recursion, StringifiedPlan,
-    ToDFSchema,
+    Column, Expr, LogicalPlan, Operator, Partitioning, PlanType, Recursion,
+    StringifiedPlan, ToDFSchema,
 };
 use crate::prelude::lit;
 use crate::scalar::ScalarValue;
@@ -40,7 +40,7 @@ const CASE_ELSE_MARKER: &str = "__DATAFUSION_CASE_ELSE__";
 /// names referenced in the expression
 pub fn exprlist_to_column_names(
     expr: &[Expr],
-    accum: &mut HashSet<String>,
+    accum: &mut HashSet<Column>,
 ) -> Result<()> {
     for e in expr {
         expr_to_column_names(e, accum)?;
     }
     Ok(())
 }
 
 /// Recursively walk an expression tree, collecting the unique set of column names
 /// referenced in the expression
 struct ColumnNameVisitor<'a> {
-    accum: &'a mut HashSet<String>,
+    accum: &'a mut HashSet<Column>,
 }
 
 impl ExpressionVisitor for ColumnNameVisitor<'_> {
     fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> {
         match expr {
-            Expr::Column(name) => {
-                self.accum.insert(name.clone());
+            Expr::Column(qc) => {
+                self.accum.insert(qc.clone());
             }
             Expr::ScalarVariable(var_names) => {
-                self.accum.insert(var_names.join("."));
+                self.accum.insert(Column::from_name(var_names.join(".")));
             }
             Expr::Alias(_, _) => {}
             Expr::Literal(_) => {}
@@ -88,7 +88,7 @@
 /// Recursively walk an expression tree, collecting the unique set of column names
 /// referenced in the expression
-pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet<String>) -> Result<()> {
+pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
     expr.accept(ColumnNameVisitor { accum })?;
     Ok(())
 }
 pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
         }
         Expr::Cast { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
         Expr::TryCast { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
-        Expr::Column(_) => Ok(vec![]),
+        Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(_) => Ok(vec![]),
         Expr::Alias(expr, ..) => Ok(vec![expr.as_ref().to_owned()]),
-        Expr::Literal(_) => Ok(vec![]),
-        Expr::ScalarVariable(_) => Ok(vec![]),
         Expr::Not(expr) => Ok(vec![expr.as_ref().to_owned()]),
         Expr::Negative(expr) => Ok(vec![expr.as_ref().to_owned()]),
         Expr::Sort { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
 pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
         }
         Expr::Not(_) => Ok(Expr::Not(Box::new(expressions[0].clone()))),
         Expr::Negative(_) => Ok(Expr::Negative(Box::new(expressions[0].clone()))),
-        Expr::Column(_) => Ok(expr.clone()),
-        Expr::Literal(_) => Ok(expr.clone()),
-        Expr::ScalarVariable(_) => Ok(expr.clone()),
         Expr::Sort {
             asc,
             nulls_first,
             ..
        } => Ok(Expr::Sort {
                 Ok(expr)
             }
         }
-        Expr::InList { .. } => Ok(expr.clone()),
         Expr::Wildcard { .. } => Err(DataFusionError::Internal(
             "Wildcard expressions are not valid in a logical query plan".to_owned(),
         )),
+        Expr::InList { .. }
+        | Expr::Column(_)
+        | Expr::Literal(_)
+        | Expr::ScalarVariable(_) => Ok(expr.clone()),
     }
 }
 
     #[test]
     fn test_collect_expr() -> Result<()> {
-        let mut accum: HashSet<String> = HashSet::new();
+        let mut accum: HashSet<Column> = HashSet::new();
         expr_to_column_names(
             &Expr::Cast {
                 expr: Box::new(col("a")),
             &mut accum,
         )?;
         assert_eq!(1, accum.len());
-        assert!(accum.contains("a"));
+        assert!(accum.contains(&Column::from_name("a".to_string())));
         Ok(())
     }
diff --git a/datafusion/src/physical_optimizer/coalesce_batches.rs b/datafusion/src/physical_optimizer/coalesce_batches.rs
index 9af8911062df..a25b6e2e0ba7 100644
--- a/datafusion/src/physical_optimizer/coalesce_batches.rs
+++ b/datafusion/src/physical_optimizer/coalesce_batches.rs
@@ -23,7 +23,7 @@ use crate::{
     error::Result,
     physical_plan::{
         coalesce_batches::CoalesceBatchesExec, filter::FilterExec,
-        hash_join::HashJoinExec, repartition::RepartitionExec,
+        hash_join::HashJoinExec, repartition::RepartitionExec, Partitioning,
     },
 };
 use std::sync::Arc;
@@ -37,6 +37,7 @@ impl CoalesceBatches {
         Self {}
     }
 }
+
 impl PhysicalOptimizerRule for CoalesceBatches {
     fn optimize(
         &self,
         // See https://issues.apache.org/jira/browse/ARROW-11068
         let wrap_in_coalesce = plan_any.downcast_ref::<FilterExec>().is_some()
             || plan_any.downcast_ref::<HashJoinExec>().is_some()
-            || plan_any.downcast_ref::<RepartitionExec>().is_some();
+            || {
+                match plan_any.downcast_ref::<RepartitionExec>() {
+                    Some(p) => match p.partitioning() {
+                        // do not coalesce hash partitions, since other plans, like the
+                        // partitioned hash join, depend on its empty batches for outer joins
+                        Partitioning::Hash(_, _) => false,
+                        _ => true,
+                    },
+                    None => false,
+                }
+            };
 
         //TODO we should also do this for HashAggregateExec but we need to update tests
         // as part of this work - see https://issues.apache.org/jira/browse/ARROW-11068
diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs
index 5c2d9ce02f51..61d7849ad586 100644
--- a/datafusion/src/physical_plan/expressions/binary.rs
+++ b/datafusion/src/physical_plan/expressions/binary.rs
@@ -592,11 +592,12 @@ mod tests {
         ]);
         let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
         let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
+
+        // expression: "a < b"
+        let lt = binary_simple(col("a", &schema)?, Operator::Lt, col("b", &schema)?);
         let batch =
             RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
 
-        // expression: "a < b"
-        let lt = binary_simple(col("a"), Operator::Lt, col("b"));
         let result = lt.evaluate(&batch)?.into_array(batch.num_rows());
         assert_eq!(result.len(), 5);
@@ -620,16 +621,17 @@ mod tests {
         ]);
         let a = Int32Array::from(vec![2, 4, 6, 8, 10]);
         let b = Int32Array::from(vec![2, 5, 4, 8, 8]);
-        let batch =
-            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
 
         // expression: "a < b OR a == b"
         let expr = binary_simple(
-            binary_simple(col("a"), Operator::Lt, col("b")),
+            binary_simple(col("a", &schema)?, Operator::Lt, col("b", &schema)?),
             Operator::Or,
-            binary_simple(col("a"), Operator::Eq, col("b")),
+            binary_simple(col("a", &schema)?, Operator::Eq, col("b", &schema)?),
         );
assert_eq!("a < b OR a = b", format!("{}", expr)); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?; + + assert_eq!("a@0 < b@1 OR a@0 = b@1", format!("{}", expr)); let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); assert_eq!(result.len(), 5); @@ -661,14 +663,15 @@ mod tests { ]); let a = $A_ARRAY::from($A_VEC); let b = $B_ARRAY::from($B_VEC); + + // verify that we can construct the expression + let expression = + binary(col("a", &schema)?, $OP, col("b", &schema)?, &schema)?; let batch = RecordBatch::try_new( Arc::new(schema.clone()), vec![Arc::new(a), Arc::new(b)], )?; - // verify that we can construct the expression - let expression = binary(col("a"), $OP, col("b"), &schema)?; - // verify that the expression's type is correct assert_eq!(expression.data_type(&schema)?, $C_TYPE); @@ -844,7 +847,12 @@ mod tests { // Test 1: dict = str // verify that we can construct the expression - let expression = binary(col("dict"), Operator::Eq, col("str"), &schema)?; + let expression = binary( + col("dict", &schema)?, + Operator::Eq, + col("str", &schema)?, + &schema, + )?; assert_eq!(expression.data_type(&schema)?, DataType::Boolean); // evaluate and verify the result type matched @@ -858,7 +866,12 @@ mod tests { // str = dict // verify that we can construct the expression - let expression = binary(col("str"), Operator::Eq, col("dict"), &schema)?; + let expression = binary( + col("str", &schema)?, + Operator::Eq, + col("dict", &schema)?, + &schema, + )?; assert_eq!(expression.data_type(&schema)?, DataType::Boolean); // evaluate and verify the result type matched @@ -970,7 +983,7 @@ mod tests { op: Operator, expected: PrimitiveArray, ) -> Result<()> { - let arithmetic_op = binary_simple(col("a"), op, col("b")); + let arithmetic_op = binary_simple(col("a", &schema)?, op, col("b", &schema)?); let batch = RecordBatch::try_new(schema, data)?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); @@ -985,7 +998,7 @@ mod tests { op: Operator, expected: BooleanArray, ) -> Result<()> { - let arithmetic_op = binary_simple(col("a"), op, col("b")); + let arithmetic_op = binary_simple(col("a", &schema)?, op, col("b", &schema)?); let data: Vec = vec![Arc::new(left), Arc::new(right)]; let batch = RecordBatch::try_new(schema, data)?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); diff --git a/datafusion/src/physical_plan/expressions/case.rs b/datafusion/src/physical_plan/expressions/case.rs index e8c500e5ed62..fd603f4bb17f 100644 --- a/datafusion/src/physical_plan/expressions/case.rs +++ b/datafusion/src/physical_plan/expressions/case.rs @@ -467,6 +467,7 @@ mod tests { #[test] fn case_with_expr() -> Result<()> { let batch = case_test_batch()?; + let schema = batch.schema(); // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END let when1 = lit(ScalarValue::Utf8(Some("foo".to_string()))); @@ -474,7 +475,11 @@ mod tests { let when2 = lit(ScalarValue::Utf8(Some("bar".to_string()))); let then2 = lit(ScalarValue::Int32(Some(456))); - let expr = case(Some(col("a")), &[(when1, then1), (when2, then2)], None)?; + let expr = case( + Some(col("a", &schema)?), + &[(when1, then1), (when2, then2)], + None, + )?; let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); let result = result .as_any() @@ -491,6 +496,7 @@ mod tests { #[test] fn case_with_expr_else() -> Result<()> { let batch = case_test_batch()?; + let schema = batch.schema(); // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END let when1 = 
lit(ScalarValue::Utf8(Some("foo".to_string()))); @@ -500,7 +506,7 @@ mod tests { let else_value = lit(ScalarValue::Int32(Some(999))); let expr = case( - Some(col("a")), + Some(col("a", &schema)?), &[(when1, then1), (when2, then2)], Some(else_value), )?; @@ -521,17 +527,18 @@ mod tests { #[test] fn case_without_expr() -> Result<()> { let batch = case_test_batch()?; + let schema = batch.schema(); // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END let when1 = binary( - col("a"), + col("a", &schema)?, Operator::Eq, lit(ScalarValue::Utf8(Some("foo".to_string()))), &batch.schema(), )?; let then1 = lit(ScalarValue::Int32(Some(123))); let when2 = binary( - col("a"), + col("a", &schema)?, Operator::Eq, lit(ScalarValue::Utf8(Some("bar".to_string()))), &batch.schema(), @@ -555,17 +562,18 @@ mod tests { #[test] fn case_without_expr_else() -> Result<()> { let batch = case_test_batch()?; + let schema = batch.schema(); // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END let when1 = binary( - col("a"), + col("a", &schema)?, Operator::Eq, lit(ScalarValue::Utf8(Some("foo".to_string()))), &batch.schema(), )?; let then1 = lit(ScalarValue::Int32(Some(123))); let when2 = binary( - col("a"), + col("a", &schema)?, Operator::Eq, lit(ScalarValue::Utf8(Some("bar".to_string()))), &batch.schema(), diff --git a/datafusion/src/physical_plan/expressions/cast.rs b/datafusion/src/physical_plan/expressions/cast.rs index ba395f54d917..9b6b9af3e2c3 100644 --- a/datafusion/src/physical_plan/expressions/cast.rs +++ b/datafusion/src/physical_plan/expressions/cast.rs @@ -178,10 +178,14 @@ mod tests { RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; // verify that we can construct the expression - let expression = cast_with_options(col("a"), &schema, $TYPE, $CAST_OPTIONS)?; + let expression = + cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?; // verify that its display is correct - assert_eq!(format!("CAST(a AS {:?})", $TYPE), format!("{}", expression)); + assert_eq!( + format!("CAST(a@0 AS {:?})", $TYPE), + format!("{}", expression) + ); // verify that the expression's type is correct assert_eq!(expression.data_type(&schema)?, $TYPE); @@ -270,7 +274,7 @@ mod tests { // Ensure a useful error happens at plan time if invalid casts are used let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let result = cast(col("a"), &schema, DataType::LargeBinary); + let result = cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary); result.expect_err("expected Invalid CAST"); } @@ -281,7 +285,7 @@ mod tests { let a = StringArray::from(vec!["9.1"]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; let expression = cast_with_options( - col("a"), + col("a", &schema)?, &schema, DataType::Int32, DEFAULT_DATAFUSION_CAST_OPTIONS, diff --git a/datafusion/src/physical_plan/expressions/column.rs b/datafusion/src/physical_plan/expressions/column.rs index 7e0304e51fe7..0fd2f3e903cf 100644 --- a/datafusion/src/physical_plan/expressions/column.rs +++ b/datafusion/src/physical_plan/expressions/column.rs @@ -28,16 +28,18 @@ use crate::error::Result; use crate::physical_plan::{ColumnarValue, PhysicalExpr}; /// Represents the column at a given index in a RecordBatch -#[derive(Debug)] +#[derive(Debug, Hash, PartialEq, Eq, Clone)] pub struct Column { name: String, + index: usize, } impl Column { /// Create a new column expression - pub fn new(name: &str) -> Self { + pub fn new(name: &str, index: usize) -> Self { Self { name: 
name.to_owned(),
+            index,
         }
     }
 
     /// Get the column name
     pub fn name(&self) -> &str {
         &self.name
     }
+
+    /// Get the column index
+    pub fn index(&self) -> usize {
+        self.index
+    }
 }
 
 impl std::fmt::Display for Column {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
-        write!(f, "{}", self.name)
+        write!(f, "{}@{}", self.name, self.index)
     }
 }
@@ -61,26 +68,21 @@ impl PhysicalExpr for Column {
 
     /// Get the data type of this expression, given the schema of the input
     fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
-        Ok(input_schema
-            .field_with_name(&self.name)?
-            .data_type()
-            .clone())
+        Ok(input_schema.field(self.index).data_type().clone())
     }
 
     /// Decide whether this expression is nullable, given the schema of the input
     fn nullable(&self, input_schema: &Schema) -> Result<bool> {
-        Ok(input_schema.field_with_name(&self.name)?.is_nullable())
+        Ok(input_schema.field(self.index).is_nullable())
     }
 
     /// Evaluate the expression
     fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
-        Ok(ColumnarValue::Array(
-            batch.column(batch.schema().index_of(&self.name)?).clone(),
-        ))
+        Ok(ColumnarValue::Array(batch.column(self.index).clone()))
     }
 }
 
 /// Create a column expression
-pub fn col(name: &str) -> Arc<dyn PhysicalExpr> {
-    Arc::new(Column::new(name))
+pub fn col(name: &str, schema: &Schema) -> Result<Arc<dyn PhysicalExpr>> {
+    Ok(Arc::new(Column::new(name, schema.index_of(name)?)))
 }
diff --git a/datafusion/src/physical_plan/expressions/in_list.rs b/datafusion/src/physical_plan/expressions/in_list.rs
index 41f111006ea2..38b2b9d45b9b 100644
--- a/datafusion/src/physical_plan/expressions/in_list.rs
+++ b/datafusion/src/physical_plan/expressions/in_list.rs
@@ -296,8 +296,8 @@ mod tests {
 
     // applies the in_list expr to an input batch and list
     macro_rules! in_list {
-        ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr) => {{
-            let expr = in_list(col("a"), $LIST, $NEGATED).unwrap();
+        ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr) => {{
+            let expr = in_list($COL, $LIST, $NEGATED).unwrap();
             let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows());
             let result = result
                 .as_any()
     fn in_list_utf8() -> Result<()> {
         let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
         let a = StringArray::from(vec![Some("a"), Some("d"), None]);
+        let col_a = col("a", &schema)?;
         let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
 
         // expression: "a in ("a", "b")"
         let list = vec![
             lit(ScalarValue::Utf8(Some("a".to_string()))),
             lit(ScalarValue::Utf8(Some("b".to_string()))),
         ];
-        in_list!(batch, list, &false, vec![Some(true), Some(false), None]);
+        in_list!(
+            batch,
+            list,
+            &false,
+            vec![Some(true), Some(false), None],
+            col_a.clone()
+        );
 
         // expression: "a not in ("a", "b")"
         let list = vec![
             lit(ScalarValue::Utf8(Some("a".to_string()))),
             lit(ScalarValue::Utf8(Some("b".to_string()))),
         ];
-        in_list!(batch, list, &true, vec![Some(false), Some(true), None]);
+        in_list!(
+            batch,
+            list,
+            &true,
+            vec![Some(false), Some(true), None],
+            col_a.clone()
+        );
 
         // expression: "a in ("a", "b", NULL)"
         let list = vec![
             lit(ScalarValue::Utf8(Some("a".to_string()))),
             lit(ScalarValue::Utf8(Some("b".to_string()))),
             lit(ScalarValue::Utf8(None)),
         ];
-        in_list!(batch, list, &false, vec![Some(true), None, None]);
+        in_list!(
+            batch,
+            list,
+            &false,
+            vec![Some(true), None, None],
+            col_a.clone()
+        );
 
         // expression: "a not in ("a", "b", NULL)"
         let list = vec![
             lit(ScalarValue::Utf8(Some("a".to_string()))),
             lit(ScalarValue::Utf8(Some("b".to_string()))),
             lit(ScalarValue::Utf8(None)),
         ];
-
in_list!(batch, list, &true, vec![Some(false), None, None]); + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + col_a.clone() + ); Ok(()) } @@ -351,6 +376,7 @@ mod tests { fn in_list_int64() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); let a = Int64Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; // expression: "a in (0, 1)" @@ -358,14 +384,26 @@ mod tests { lit(ScalarValue::Int64(Some(0))), lit(ScalarValue::Int64(Some(1))), ]; - in_list!(batch, list, &false, vec![Some(true), Some(false), None]); + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + col_a.clone() + ); // expression: "a not in (0, 1)" let list = vec![ lit(ScalarValue::Int64(Some(0))), lit(ScalarValue::Int64(Some(1))), ]; - in_list!(batch, list, &true, vec![Some(false), Some(true), None]); + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + col_a.clone() + ); // expression: "a in (0, 1, NULL)" let list = vec![ @@ -373,7 +411,13 @@ mod tests { lit(ScalarValue::Int64(Some(1))), lit(ScalarValue::Utf8(None)), ]; - in_list!(batch, list, &false, vec![Some(true), None, None]); + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + col_a.clone() + ); // expression: "a not in (0, 1, NULL)" let list = vec![ @@ -381,7 +425,13 @@ mod tests { lit(ScalarValue::Int64(Some(1))), lit(ScalarValue::Utf8(None)), ]; - in_list!(batch, list, &true, vec![Some(false), None, None]); + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + col_a.clone() + ); Ok(()) } @@ -390,6 +440,7 @@ mod tests { fn in_list_float64() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); let a = Float64Array::from(vec![Some(0.0), Some(0.2), None]); + let col_a = col("a", &schema)?; let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; // expression: "a in (0.0, 0.2)" @@ -397,14 +448,26 @@ mod tests { lit(ScalarValue::Float64(Some(0.0))), lit(ScalarValue::Float64(Some(0.1))), ]; - in_list!(batch, list, &false, vec![Some(true), Some(false), None]); + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + col_a.clone() + ); // expression: "a not in (0.0, 0.2)" let list = vec![ lit(ScalarValue::Float64(Some(0.0))), lit(ScalarValue::Float64(Some(0.1))), ]; - in_list!(batch, list, &true, vec![Some(false), Some(true), None]); + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + col_a.clone() + ); // expression: "a in (0.0, 0.2, NULL)" let list = vec![ @@ -412,7 +475,13 @@ mod tests { lit(ScalarValue::Float64(Some(0.1))), lit(ScalarValue::Utf8(None)), ]; - in_list!(batch, list, &false, vec![Some(true), None, None]); + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + col_a.clone() + ); // expression: "a not in (0.0, 0.2, NULL)" let list = vec![ @@ -420,7 +489,13 @@ mod tests { lit(ScalarValue::Float64(Some(0.1))), lit(ScalarValue::Utf8(None)), ]; - in_list!(batch, list, &true, vec![Some(false), None, None]); + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + col_a.clone() + ); Ok(()) } @@ -429,29 +504,30 @@ mod tests { fn in_list_bool() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); let a = BooleanArray::from(vec![Some(true), None]); + let col_a = col("a", &schema)?; let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; // 
expression: "a in (true)" let list = vec![lit(ScalarValue::Boolean(Some(true)))]; - in_list!(batch, list, &false, vec![Some(true), None]); + in_list!(batch, list, &false, vec![Some(true), None], col_a.clone()); // expression: "a not in (true)" let list = vec![lit(ScalarValue::Boolean(Some(true)))]; - in_list!(batch, list, &true, vec![Some(false), None]); + in_list!(batch, list, &true, vec![Some(false), None], col_a.clone()); // expression: "a in (true, NULL)" let list = vec![ lit(ScalarValue::Boolean(Some(true))), lit(ScalarValue::Utf8(None)), ]; - in_list!(batch, list, &false, vec![Some(true), None]); + in_list!(batch, list, &false, vec![Some(true), None], col_a.clone()); // expression: "a not in (true, NULL)" let list = vec![ lit(ScalarValue::Boolean(Some(true))), lit(ScalarValue::Utf8(None)), ]; - in_list!(batch, list, &true, vec![Some(false), None]); + in_list!(batch, list, &true, vec![Some(false), None], col_a.clone()); Ok(()) } diff --git a/datafusion/src/physical_plan/expressions/is_not_null.rs b/datafusion/src/physical_plan/expressions/is_not_null.rs index 7ac2110b5022..cce27e36a68c 100644 --- a/datafusion/src/physical_plan/expressions/is_not_null.rs +++ b/datafusion/src/physical_plan/expressions/is_not_null.rs @@ -100,10 +100,10 @@ mod tests { fn is_not_null_op() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); let a = StringArray::from(vec![Some("foo"), None]); + let expr = is_not_null(col("a", &schema)?).unwrap(); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; // expression: "a is not null" - let expr = is_not_null(col("a")).unwrap(); let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); let result = result .as_any() diff --git a/datafusion/src/physical_plan/expressions/is_null.rs b/datafusion/src/physical_plan/expressions/is_null.rs index dfa53f3f7d26..dbb57dfa5f8b 100644 --- a/datafusion/src/physical_plan/expressions/is_null.rs +++ b/datafusion/src/physical_plan/expressions/is_null.rs @@ -100,10 +100,11 @@ mod tests { fn is_null_op() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); let a = StringArray::from(vec![Some("foo"), None]); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; // expression: "a is null" - let expr = is_null(col("a")).unwrap(); + let expr = is_null(col("a", &schema)?).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); let result = result .as_any() diff --git a/datafusion/src/physical_plan/expressions/min_max.rs b/datafusion/src/physical_plan/expressions/min_max.rs index 5ed14610ada3..cfb3a865e653 100644 --- a/datafusion/src/physical_plan/expressions/min_max.rs +++ b/datafusion/src/physical_plan/expressions/min_max.rs @@ -274,7 +274,7 @@ macro_rules! 
min_max { } e => { return Err(DataFusionError::Internal(format!( - "MIN/MAX is not expected to receive a scalar {:?}", + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", e ))) } diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 6e252205955d..e93001911ef3 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -60,6 +60,7 @@ pub use not::{not, NotExpr}; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use sum::{sum_return_type, Sum}; pub use try_cast::{try_cast, TryCastExpr}; + /// returns the name of the state pub fn format_state_name(name: &str, state_name: &str) -> String { format!("{}[{}]", name, state_name) @@ -107,8 +108,11 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - let agg = - Arc::new(<$OP>::new(col("a"), "bla".to_string(), $EXPECTED_DATATYPE)); + let agg = Arc::new(<$OP>::new( + col("a", &schema)?, + "bla".to_string(), + $EXPECTED_DATATYPE, + )); let actual = aggregate(&batch, agg)?; let expected = ScalarValue::from($EXPECTED); diff --git a/datafusion/src/physical_plan/expressions/not.rs b/datafusion/src/physical_plan/expressions/not.rs index 23a1a46651de..4b2de2266ecb 100644 --- a/datafusion/src/physical_plan/expressions/not.rs +++ b/datafusion/src/physical_plan/expressions/not.rs @@ -127,7 +127,7 @@ mod tests { fn neg_op() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); - let expr = not(col("a"), &schema)?; + let expr = not(col("a", &schema)?, &schema)?; assert_eq!(expr.data_type(&schema)?, DataType::Boolean); assert_eq!(expr.nullable(&schema)?, true); @@ -152,7 +152,7 @@ mod tests { fn neg_op_not_null() { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); - let expr = not(col("a"), &schema); + let expr = not(col("a", &schema).unwrap(), &schema); assert!(expr.is_err()); } } diff --git a/datafusion/src/physical_plan/expressions/try_cast.rs b/datafusion/src/physical_plan/expressions/try_cast.rs index 5e402fdea28a..1ba4a50260d4 100644 --- a/datafusion/src/physical_plan/expressions/try_cast.rs +++ b/datafusion/src/physical_plan/expressions/try_cast.rs @@ -139,10 +139,13 @@ mod tests { RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; // verify that we can construct the expression - let expression = try_cast(col("a"), &schema, $TYPE)?; + let expression = try_cast(col("a", &schema)?, &schema, $TYPE)?; // verify that its display is correct - assert_eq!(format!("CAST(a AS {:?})", $TYPE), format!("{}", expression)); + assert_eq!( + format!("CAST(a@0 AS {:?})", $TYPE), + format!("{}", expression) + ); // verify that the expression's type is correct assert_eq!(expression.data_type(&schema)?, $TYPE); @@ -241,7 +244,7 @@ mod tests { // Ensure a useful error happens at plan time if invalid casts are used let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let result = try_cast(col("a"), &schema, DataType::LargeBinary); + let result = try_cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary); result.expect_err("expected Invalid CAST"); } } diff --git a/datafusion/src/physical_plan/filter.rs b/datafusion/src/physical_plan/filter.rs index 61af78db8ed2..be06e61db968 100644 --- a/datafusion/src/physical_plan/filter.rs +++ b/datafusion/src/physical_plan/filter.rs @@ -209,14 +209,14 @@ mod tests { let predicate: Arc = binary( binary( - col("c2"), + col("c2", &schema)?, 
Operator::Gt, lit(ScalarValue::from(1u32)), &schema, )?, Operator::And, binary( - col("c2"), + col("c2", &schema)?, Operator::Lt, lit(ScalarValue::from(4u32)), &schema, diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 56365fec1dc8..f9b5a5891bd7 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -3638,7 +3638,7 @@ mod tests { let expr = create_physical_expr( &BuiltinScalarFunction::Array, - &[col("a"), col("b")], + &[col("a", &schema)?, col("b", &schema)?], &schema, )?; @@ -3702,7 +3702,7 @@ mod tests { let columns: Vec = vec![col_value]; let expr = create_physical_expr( &BuiltinScalarFunction::RegexpMatch, - &[col("a"), pattern], + &[col("a", &schema)?, pattern], &schema, )?; diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index 234265022ef7..33869aef48fb 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -608,9 +608,12 @@ async fn compute_grouped_hash_aggregate( aggr_expr: Vec>, mut input: SendableRecordBatchStream, ) -> ArrowResult { - // the expressions to evaluate the batch, one vec of expressions per aggregation - let aggregate_expressions = aggregate_expressions(&aggr_expr, &mode) - .map_err(DataFusionError::into_arrow_external_error)?; + // The expressions to evaluate the batch, one vec of expressions per aggregation. + // Assume create_schema() always put group columns in front of aggr columns, we set + // col_idx_base to group expression count. + let aggregate_expressions = + aggregate_expressions(&aggr_expr, &mode, group_expr.len()) + .map_err(DataFusionError::into_arrow_external_error)?; // mapping key -> (set of accumulators, indices of the key in the batch) // * the indexes are updated at each row @@ -740,14 +743,21 @@ fn evaluate_many( .collect::>>() } -/// uses `state_fields` to build a vec of expressions required to merge the AggregateExpr' accumulator's state. +/// uses `state_fields` to build a vec of physical column expressions required to merge the +/// AggregateExpr' accumulator's state. +/// +/// `index_base` is the starting physical column index for the next expanded state field. fn merge_expressions( + index_base: usize, expr: &Arc, ) -> Result>> { Ok(expr .state_fields()? .iter() - .map(|f| Arc::new(Column::new(f.name())) as Arc) + .enumerate() + .map(|(idx, f)| { + Arc::new(Column::new(f.name(), index_base + idx)) as Arc + }) .collect::>()) } @@ -755,22 +765,27 @@ fn merge_expressions( /// The expressions are different depending on `mode`: /// * Partial: AggregateExpr::expressions /// * Final: columns of `AggregateExpr::state_fields()` -/// The return value is to be understood as: -/// * index 0 is the aggregation -/// * index 1 is the expression i of the aggregation fn aggregate_expressions( aggr_expr: &[Arc], mode: &AggregateMode, + col_idx_base: usize, ) -> Result>>> { match mode { AggregateMode::Partial => { Ok(aggr_expr.iter().map(|agg| agg.expressions()).collect()) } // in this mode, we build the merge expressions of the aggregation - AggregateMode::Final => Ok(aggr_expr - .iter() - .map(|agg| merge_expressions(agg)) - .collect::>>()?), + AggregateMode::Final => { + let mut col_idx_base = col_idx_base; + Ok(aggr_expr + .iter() + .map(|agg| { + let exprs = merge_expressions(col_idx_base, agg)?; + col_idx_base += exprs.len(); + Ok(exprs) + }) + .collect::>>()?) 
+ } } } @@ -791,10 +806,8 @@ async fn compute_hash_aggregate( ) -> ArrowResult { let mut accumulators = create_accumulators(&aggr_expr) .map_err(DataFusionError::into_arrow_external_error)?; - - let expressions = aggregate_expressions(&aggr_expr, &mode) + let expressions = aggregate_expressions(&aggr_expr, &mode, 0) .map_err(DataFusionError::into_arrow_external_error)?; - let expressions = Arc::new(expressions); // 1 for each batch, update / merge accumulators with the expressions' values @@ -1215,16 +1228,17 @@ mod tests { /// build the aggregates on the data from some_data() and check the results async fn check_aggregates(input: Arc) -> Result<()> { + let input_schema = input.schema(); + let groups: Vec<(Arc, String)> = - vec![(col("a"), "a".to_string())]; + vec![(col("a", &input_schema)?, "a".to_string())]; let aggregates: Vec> = vec![Arc::new(Avg::new( - col("b"), + col("b", &input_schema)?, "AVG(b)".to_string(), DataType::Float64, ))]; - let input_schema = input.schema(); let partial_aggregate = Arc::new(HashAggregateExec::try_new( AggregateMode::Partial, groups.clone(), @@ -1248,8 +1262,9 @@ mod tests { let merge = Arc::new(MergeExec::new(partial_aggregate)); - let final_group: Vec> = - (0..groups.len()).map(|i| col(&groups[i].1)).collect(); + let final_group: Vec> = (0..groups.len()) + .map(|i| col(&groups[i].1, &input_schema)) + .collect::>()?; let merged_aggregate = Arc::new(HashAggregateExec::try_new( AggregateMode::Final, diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index 401fe6580a92..053a00f4e6f2 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -51,7 +51,7 @@ use arrow::array::{ UInt64Array, UInt8Array, }; -use super::expressions::col; +use super::expressions::Column; use super::{ hash_utils::{build_join_schema, check_join_is_valid, JoinOn, JoinType}, merge::MergeExec, @@ -60,6 +60,7 @@ use crate::error::{DataFusionError, Result}; use super::{ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream}; use crate::physical_plan::coalesce_batches::concat_batches; +use crate::physical_plan::PhysicalExpr; use log::debug; // Maps a `u64` hash value based on the left ["on" values] to a list of indices with this key's value. 
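Before the HashJoinExec hunks below, the index bookkeeping that hash_aggregate.rs introduces above is easier to see in isolation. A minimal, self-contained sketch, using a simplified stand-in for the physical Column type; the state-field names follow the name[state] pattern of format_state_name, and none of this is code from the patch:

```rust
// Sketch of the Final-mode index bookkeeping done by merge_expressions /
// aggregate_expressions: group columns occupy the first slots of the
// partial-aggregate output, so each aggregate's expanded state fields map to
// consecutive indices starting at a running base.
#[derive(Debug, PartialEq)]
struct Column {
    name: String,
    index: usize,
}

// Maps one aggregate's state fields to consecutive physical indices
// starting at `index_base`.
fn merge_columns(state_fields: &[&str], index_base: usize) -> Vec<Column> {
    state_fields
        .iter()
        .enumerate()
        .map(|(i, f)| Column {
            name: (*f).to_string(),
            index: index_base + i,
        })
        .collect()
}

fn main() {
    // Layout assumed by the patch: group columns first, then state fields,
    // so the first aggregate starts at group_expr.len().
    let group_expr_len = 1; // e.g. GROUP BY a
    let avg_state = ["AVG(b)[count]", "AVG(b)[sum]"];
    let cols = merge_columns(&avg_state, group_expr_len);
    assert_eq!(cols[0].index, 1); // AVG(b)[count] sits right after the group column
    assert_eq!(cols[1].index, 2); // AVG(b)[sum]
    // A second aggregate would start at group_expr_len + cols.len().
}
```

The running base is what lets Final-mode aggregates find their state columns after the group columns without resolving names at execution time.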
@@ -77,7 +78,7 @@ pub struct HashJoinExec { /// right (probe) side which are filtered by the hash table right: Arc, /// Set of common columns used to join on - on: Vec<(String, String)>, + on: Vec<(Column, Column)>, /// How the join is performed join_type: JoinType, /// The schema once the join is applied @@ -114,7 +115,7 @@ impl HashJoinExec { pub fn try_new( left: Arc, right: Arc, - on: &JoinOn, + on: JoinOn, join_type: &JoinType, partition_mode: PartitionMode, ) -> Result { @@ -125,15 +126,10 @@ impl HashJoinExec { let schema = Arc::new(build_join_schema( &left_schema, &right_schema, - on, + &on, &join_type, )); - let on = on - .iter() - .map(|(l, r)| (l.to_string(), r.to_string())) - .collect(); - let random_state = RandomState::with_seeds(0, 0, 0, 0); Ok(HashJoinExec { @@ -159,7 +155,7 @@ impl HashJoinExec { } /// Set of common columns used to join on - pub fn on(&self) -> &[(String, String)] { + pub fn on(&self) -> &[(Column, Column)] { &self.on } @@ -221,7 +217,7 @@ impl ExecutionPlan for HashJoinExec { 2 => Ok(Arc::new(HashJoinExec::try_new( children[0].clone(), children[1].clone(), - &self.on, + self.on.clone(), &self.join_type, self.mode, )?)), @@ -293,10 +289,10 @@ impl ExecutionPlan for HashJoinExec { *build_side = Some(left_side.clone()); debug!( - "Built build-side of hash join containing {} rows in {} ms", - num_rows, - start.elapsed().as_millis() - ); + "Built build-side of hash join containing {} rows in {} ms", + num_rows, + start.elapsed().as_millis() + ); left_side } @@ -313,9 +309,9 @@ impl ExecutionPlan for HashJoinExec { // 2. stores the batches in a vector. let initial = ( JoinHashMap::with_hasher(IdHashBuilder {}), - Vec::new(), - 0, - Vec::new(), + Vec::new(), // values + 0, // row count + Vec::new(), // batch hashes buffer ); let (hashmap, batches, num_rows, _) = stream .try_fold(initial, |mut acc, batch| async { @@ -361,32 +357,28 @@ impl ExecutionPlan for HashJoinExec { // we have the batches and the hash map with their keys. We can now create a stream
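The signature change above means callers now resolve join keys once, at planning time, against each side's own schema. A hedged sketch, with stand-in types rather than the DataFusion structs, of what that resolution looks like:

```rust
// Stand-in types: join keys become (Column, Column) pairs, each index looked
// up in its own side's schema, instead of (String, String) names passed into
// the join and re-resolved per batch.
#[derive(Clone, Debug, PartialEq)]
struct Column {
    name: String,
    index: usize,
}

type JoinOn = Vec<(Column, Column)>;

fn index_of(fields: &[&str], name: &str) -> Result<usize, String> {
    fields
        .iter()
        .position(|f| *f == name)
        .ok_or_else(|| format!("no column named {}", name))
}

fn resolve_on(
    left_fields: &[&str],
    right_fields: &[&str],
    keys: &[(&str, &str)],
) -> Result<JoinOn, String> {
    keys.iter()
        .map(|&(l, r)| {
            Ok((
                Column { name: l.to_string(), index: index_of(left_fields, l)? },
                Column { name: r.to_string(), index: index_of(right_fields, r)? },
            ))
        })
        .collect()
}

fn main() -> Result<(), String> {
    // Mirrors the join_inner_one test below: "b1" happens to be index 1 on
    // both sides, but the two lookups are independent.
    let on = resolve_on(&["a1", "b1", "c1"], &["a2", "b1", "c2"], &[("b1", "b1")])?;
    assert_eq!(on[0].0, Column { name: "b1".to_string(), index: 1 });
    assert_eq!(on[0].1, Column { name: "b1".to_string(), index: 1 });
    Ok(())
}
```

Resolving per side matters because the left and right schemas number their columns independently; a shared name like b1 can land on different indices on each side.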
- let stream = self.right.execute(partition).await?; + let right_stream = self.right.execute(partition).await?; let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); let column_indices = self.column_indices_from_schema()?; - Ok(Box::pin(HashJoinStream { - schema: self.schema.clone(), + + Ok(Box::pin(HashJoinStream::new( + self.schema.clone(), on_left, on_right, - join_type: self.join_type, + self.join_type, left_data, - right: stream, + right_stream, column_indices, - num_input_batches: 0, - num_input_rows: 0, - num_output_batches: 0, - num_output_rows: 0, - join_time: 0, - random_state: self.random_state.clone(), - })) + self.random_state.clone(), + ))) } } /// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, /// assuming that the [RecordBatch] corresponds to the `index`th fn update_hash( - on: &[String], + on: &[Column], batch: &RecordBatch, hash: &mut JoinHashMap, offset: usize, @@ -396,7 +388,7 @@ fn update_hash( // evaluate the keys let keys_values = on .iter() - .map(|name| Ok(col(name).evaluate(batch)?.into_array(batch.num_rows()))) + .map(|c| Ok(c.evaluate(batch)?.into_array(batch.num_rows()))) .collect::>>()?; // update the hash map @@ -417,9 +409,9 @@ struct HashJoinStream { /// Input schema schema: Arc, /// columns from the left - on_left: Vec, + on_left: Vec, /// columns from the right used to compute the hash - on_right: Vec, + on_right: Vec, /// type of the join join_type: JoinType, /// information from the left @@ -442,6 +434,35 @@ struct HashJoinStream { random_state: RandomState, } +impl HashJoinStream { + fn new( + schema: Arc, + on_left: Vec, + on_right: Vec, + join_type: JoinType, + left_data: JoinLeftData, + right: SendableRecordBatchStream, + column_indices: Vec, + random_state: RandomState, + ) -> Self { + HashJoinStream { + schema, + on_left, + on_right, + join_type, + left_data, + right, + column_indices, + num_input_batches: 0, + num_input_rows: 0, + num_output_batches: 0, + num_output_rows: 0, + join_time: 0, + random_state, + } + } +} + impl RecordBatchStream for HashJoinStream { fn schema(&self) -> SchemaRef { self.schema.clone() @@ -483,8 +504,8 @@ fn build_batch_from_indices( fn build_batch( batch: &RecordBatch, left_data: &JoinLeftData, - on_left: &[String], - on_right: &[String], + on_left: &[Column], + on_right: &[Column], join_type: JoinType, schema: &Schema, column_indices: &[ColumnIndex], @@ -541,21 +562,17 @@ fn build_join_indexes( left_data: &JoinLeftData, right: &RecordBatch, join_type: JoinType, - left_on: &[String], - right_on: &[String], + left_on: &[Column], + right_on: &[Column], random_state: &RandomState, ) -> Result<(UInt64Array, UInt32Array)> { let keys_values = right_on .iter() - .map(|name| Ok(col(name).evaluate(right)?.into_array(right.num_rows()))) + .map(|c| Ok(c.evaluate(right)?.into_array(right.num_rows()))) .collect::>>()?; let left_join_values = left_on .iter() - .map(|name| { - Ok(col(name) - .evaluate(&left_data.1)? 
- .into_array(left_data.1.num_rows())) - }) + .map(|c| Ok(c.evaluate(&left_data.1)?.into_array(left_data.1.num_rows()))) .collect::>>()?; let hashes_buffer = &mut vec![0; keys_values[0].len()]; let hash_values = create_hashes(&keys_values, &random_state, hashes_buffer)?; @@ -854,6 +871,7 @@ impl Stream for HashJoinStream { .map(|maybe_batch| match maybe_batch { Some(Ok(batch)) => { let start = Instant::now(); + let result = build_batch( &batch, &self.left_data, @@ -874,6 +892,7 @@ impl Stream for HashJoinStream { Some(result) } other => { + // End of right batch, print stats in debug mode debug!( "Processed {} probe-side input batches containing {} rows and \ produced {} output batches containing {} rows in {} ms", @@ -893,7 +912,9 @@ impl Stream for HashJoinStream { mod tests { use crate::{ assert_batches_sorted_eq, - physical_plan::{common, memory::MemoryExec}, + physical_plan::{ + common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, + }, test::{build_table_i32, columns}, }; @@ -913,14 +934,74 @@ mod tests { fn join( left: Arc, right: Arc, - on: &[(&str, &str)], + on: JoinOn, join_type: &JoinType, ) -> Result { - let on: Vec<_> = on + HashJoinExec::try_new(left, right, on, join_type, PartitionMode::CollectLeft) + } + + async fn join_collect( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + ) -> Result<(Vec, Vec)> { + let join = join(left, right, on, join_type)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + Ok((columns, batches)) + } + + async fn partitioned_join_collect( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + ) -> Result<(Vec, Vec)> { + let partition_count = 4; + + let (left_expr, right_expr) = on .iter() - .map(|(l, r)| (l.to_string(), r.to_string())) - .collect(); - HashJoinExec::try_new(left, right, &on, join_type, PartitionMode::CollectLeft) + .map(|(l, r)| { + ( + Arc::new(l.clone()) as Arc, + Arc::new(r.clone()) as Arc, + ) + }) + .unzip(); + + let join = HashJoinExec::try_new( + Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, partition_count), + )?), + Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), + on, + join_type, + PartitionMode::Partitioned, + )?; + + let columns = columns(&join.schema()); + + let mut batches = vec![]; + for i in 0..partition_count { + let stream = join.execute(i).await?; + let more_batches = common::collect(stream).await?; + batches.extend( + more_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(), + ); + } + + Ok((columns, batches)) } #[tokio::test] @@ -935,15 +1016,57 @@ mod tests { ("b1", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), ); - let on = &[("b1", "b1")]; + let on = vec![( + Column::new("b1", left.schema().index_of("b1")?), + Column::new("b1", right.schema().index_of("b1")?), + )]; - let join = join(left, right, on, &JoinType::Inner)?; + let (columns, batches) = + join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Inner) + .await?; - let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); - let stream = join.execute(0).await?; - let batches = common::collect(stream).await?; + let expected = vec![ + "+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | c2 |", + "+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 70 |", + "| 2 | 5 | 8 | 20 | 80 |", + "| 3 | 5 | 9 | 20 | 80 |", + "+----+----+----+----+----+", + ]; + 
assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn partitioned_join_inner_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new("b1", left.schema().index_of("b1")?), + Column::new("b1", right.schema().index_of("b1")?), + )]; + + let (columns, batches) = partitioned_join_collect( + left.clone(), + right.clone(), + on.clone(), + &JoinType::Inner, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); let expected = vec![ "+----+----+----+----+----+", @@ -971,16 +1094,15 @@ mod tests { ("b2", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), ); - let on = &[("b1", "b2")]; + let on = vec![( + Column::new("b1", left.schema().index_of("b1")?), + Column::new("b2", right.schema().index_of("b2")?), + )]; - let join = join(left, right, on, &JoinType::Inner)?; + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; - let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let stream = join.execute(0).await?; - let batches = common::collect(stream).await?; - let expected = vec![ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -1008,15 +1130,21 @@ mod tests { ("b2", &vec![1, 2, 2]), ("c2", &vec![70, 80, 90]), ); - let on = &[("a1", "a1"), ("b2", "b2")]; + let on = vec![ + ( + Column::new("a1", left.schema().index_of("a1")?), + Column::new("a1", right.schema().index_of("a1")?), + ), + ( + Column::new("b2", left.schema().index_of("b2")?), + Column::new("b2", right.schema().index_of("b2")?), + ), + ]; - let join = join(left, right, on, &JoinType::Inner)?; + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; - let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); - let stream = join.execute(0).await?; - let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); let expected = vec![ @@ -1054,15 +1182,21 @@ mod tests { ("b2", &vec![1, 2, 2]), ("c2", &vec![70, 80, 90]), ); - let on = &[("a1", "a1"), ("b2", "b2")]; + let on = vec![ + ( + Column::new("a1", left.schema().index_of("a1")?), + Column::new("a1", right.schema().index_of("a1")?), + ), + ( + Column::new("b2", left.schema().index_of("b2")?), + Column::new("b2", right.schema().index_of("b2")?), + ), + ]; - let join = join(left, right, on, &JoinType::Inner)?; + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; - let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); - let stream = join.execute(0).await?; - let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); let expected = vec![ @@ -1101,7 +1235,10 @@ mod tests { MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), ); - let on = &[("b1", "b1")]; + let on = vec![( + Column::new("b1", left.schema().index_of("b1")?), + Column::new("b1", right.schema().index_of("b1")?), + )]; let join = join(left, right, on, &JoinType::Inner)?; @@ -1152,15 +1289,55 @@ mod tests { ("b1", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), ); - let on = &[("b1", "b1")]; + let on = vec![( + Column::new("b1", left.schema().index_of("b1")?), + Column::new("b1", right.schema().index_of("b1")?), + )]; + + let (columns, batches) = + 
join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Left) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); - let join = join(left, right, on, &JoinType::Left)?; + let expected = vec![ + "+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | c2 |", + "+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 70 |", + "| 2 | 5 | 8 | 20 | 80 |", + "| 3 | 7 | 9 | | |", + "+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); - let columns = columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + Ok(()) + } - let stream = join.execute(0).await?; - let batches = common::collect(stream).await?; + #[tokio::test] + async fn partitioned_join_left_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new("b1", left.schema().index_of("b1")?), + Column::new("b1", right.schema().index_of("b1")?), + )]; + + let (columns, batches) = partitioned_join_collect( + left.clone(), + right.clone(), + on.clone(), + &JoinType::Left, + ) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); let expected = vec![ "+----+----+----+----+----+", @@ -1188,15 +1365,51 @@ mod tests { ("b1", &vec![4, 5, 6]), // 6 does not exist on the left ("c2", &vec![70, 80, 90]), ); - let on = &[("b1", "b1")]; + let on = vec![( + Column::new("b1", left.schema().index_of("b1")?), + Column::new("b1", right.schema().index_of("b1")?), + )]; - let join = join(left, right, on, &JoinType::Right)?; + let (columns, batches) = join_collect(left, right, on, &JoinType::Right).await?; - let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "c1", "a2", "b1", "c2"]); - let stream = join.execute(0).await?; - let batches = common::collect(stream).await?; + let expected = vec![ + "+----+----+----+----+----+", + "| a1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+", + "| | | 30 | 6 | 90 |", + "| 1 | 7 | 10 | 4 | 70 |", + "| 2 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn partitioned_join_right_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new("b1", left.schema().index_of("b1")?), + Column::new("b1", right.schema().index_of("b1")?), + )]; + + let (columns, batches) = + partitioned_join_collect(left, right, on, &JoinType::Right).await?; + + assert_eq!(columns, vec!["a1", "c1", "a2", "b1", "c2"]); let expected = vec![ "+----+----+----+----+----+", @@ -1242,8 +1455,8 @@ mod tests { &left_data, &right, JoinType::Inner, - &["a".to_string()], - &["a".to_string()], + &[Column::new("a", 0)], + &[Column::new("a", 0)], &random_state, )?; @@ -1252,7 +1465,6 @@ mod tests { left_ids.append_value(1)?; let mut right_ids = UInt32Builder::new(0); - right_ids.append_value(0)?; right_ids.append_value(1)?; diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index b26ff9bb5fc2..6060bd0b5bb5 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ 
-21,6 +21,8 @@ use crate::error::{DataFusionError, Result}; use arrow::datatypes::{Field, Schema}; use std::collections::HashSet; +use crate::physical_plan::expressions::Column; + /// All valid types of joins. #[derive(Clone, Copy, Debug)] pub enum JoinType { @@ -33,14 +35,23 @@ pub enum JoinType { } /// The on clause of the join, as vector of (left, right) columns. -pub type JoinOn = [(String, String)]; +pub type JoinOn = Vec<(Column, Column)>; /// Checks whether the schemas "left" and "right" and columns "on" represent a valid join. /// They are valid whenever their columns' intersection equals the set `on` pub fn check_join_is_valid(left: &Schema, right: &Schema, on: &JoinOn) -> Result<()> { - let left: HashSet = left.fields().iter().map(|f| f.name().clone()).collect(); - let right: HashSet = - right.fields().iter().map(|f| f.name().clone()).collect(); + let left: HashSet = left + .fields() + .iter() + .enumerate() + .map(|(idx, f)| Column::new(f.name(), idx)) + .collect(); + let right: HashSet = right + .fields() + .iter() + .enumerate() + .map(|(idx, f)| Column::new(f.name(), idx)) + .collect(); check_join_set_is_valid(&left, &right, on) } @@ -48,19 +59,19 @@ pub fn check_join_is_valid(left: &Schema, right: &Schema, on: &JoinOn) -> Result /// Checks whether the sets left, right and on compose a valid join. /// They are valid whenever their intersection equals the set `on` fn check_join_set_is_valid( - left: &HashSet, - right: &HashSet, - on: &JoinOn, + left: &HashSet, + right: &HashSet, + on: &[(Column, Column)], ) -> Result<()> { if on.is_empty() { return Err(DataFusionError::Plan( "The 'on' clause of a join cannot be empty".to_string(), )); } - let on_left = &on.iter().map(|on| on.0.to_string()).collect::>(); + let on_left = &on.iter().map(|on| on.0.clone()).collect::>(); let left_missing = on_left.difference(left).collect::>(); - let on_right = &on.iter().map(|on| on.1.to_string()).collect::>(); + let on_right = &on.iter().map(|on| on.1.clone()).collect::>(); let right_missing = on_right.difference(right).collect::>(); if !left_missing.is_empty() | !right_missing.is_empty() { @@ -74,7 +85,7 @@ fn check_join_set_is_valid( let remaining = right .difference(on_right) .cloned() - .collect::>(); + .collect::>(); let collisions = left.intersection(&remaining).collect::>(); @@ -101,8 +112,8 @@ pub fn build_join_schema( // remove right-side join keys if they have the same names as the left-side let duplicate_keys = &on .iter() - .filter(|(l, r)| l == r) - .map(|on| on.1.to_string()) + .filter(|(l, r)| l.name() == r.name()) + .map(|on| on.1.name()) .collect::>(); let left_fields = left.fields().iter(); @@ -110,7 +121,7 @@ pub fn build_join_schema( let right_fields = right .fields() .iter() - .filter(|f| !duplicate_keys.contains(f.name())); + .filter(|f| !duplicate_keys.contains(f.name().as_str())); // left then right left_fields.chain(right_fields).cloned().collect() @@ -119,14 +130,14 @@ pub fn build_join_schema( // remove left-side join keys if they have the same names as the right-side let duplicate_keys = &on .iter() - .filter(|(l, r)| l == r) - .map(|on| on.1.to_string()) + .filter(|(l, r)| l.name() == r.name()) + .map(|on| on.1.name()) .collect::>(); let left_fields = left .fields() .iter() - .filter(|f| !duplicate_keys.contains(f.name())); + .filter(|f| !duplicate_keys.contains(f.name().as_str())); let right_fields = right.fields().iter(); @@ -139,24 +150,25 @@ pub fn build_join_schema( #[cfg(test)] mod tests { - use super::*; - fn check(left: &[&str], right: &[&str], on: &[(&str, 
&str)]) -> Result<()> { - let left = left.iter().map(|x| x.to_string()).collect::>(); - let right = right.iter().map(|x| x.to_string()).collect::>(); - let on: Vec<_> = on + fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { + let left = left + .iter() + .map(|x| x.to_owned()) + .collect::>(); + let right = right .iter() - .map(|(l, r)| (l.to_string(), r.to_string())) - .collect(); - check_join_set_is_valid(&left, &right, &on) + .map(|x| x.to_owned()) + .collect::>(); + check_join_set_is_valid(&left, &right, on) } #[test] fn check_valid() -> Result<()> { - let left = vec!["a", "b1"]; - let right = vec!["a", "b2"]; - let on = &[("a", "a")]; + let left = vec![Column::new("a", 0), Column::new("b1", 1)]; + let right = vec![Column::new("a", 0), Column::new("b2", 1)]; + let on = &[(Column::new("a", 0), Column::new("a", 0))]; check(&left, &right, on)?; Ok(()) @@ -164,18 +176,18 @@ mod tests { #[test] fn check_not_in_right() { - let left = vec!["a", "b"]; - let right = vec!["b"]; - let on = &[("a", "a")]; + let left = vec![Column::new("a", 0), Column::new("b", 1)]; + let right = vec![Column::new("b", 0)]; + let on = &[(Column::new("a", 0), Column::new("a", 0))]; assert!(check(&left, &right, on).is_err()); } #[test] fn check_not_in_left() { - let left = vec!["b"]; - let right = vec!["a"]; - let on = &[("a", "a")]; + let left = vec![Column::new("b", 0)]; + let right = vec![Column::new("a", 0)]; + let on = &[(Column::new("a", 0), Column::new("a", 0))]; assert!(check(&left, &right, on).is_err()); } @@ -183,18 +195,18 @@ mod tests { #[test] fn check_collision() { // column "a" would appear both in left and right - let left = vec!["a", "c"]; - let right = vec!["a", "b"]; - let on = &[("a", "b")]; + let left = vec![Column::new("a", 0), Column::new("c", 1)]; + let right = vec![Column::new("a", 0), Column::new("b", 1)]; + let on = &[(Column::new("a", 0), Column::new("b", 1))]; assert!(check(&left, &right, on).is_err()); } #[test] fn check_in_right() { - let left = vec!["a", "c"]; - let right = vec!["b"]; - let on = &[("a", "b")]; + let left = vec![Column::new("a", 0), Column::new("c", 1)]; + let right = vec![Column::new("b", 0)]; + let on = &[(Column::new("a", 0), Column::new("b", 0))]; assert!(check(&left, &right, on).is_ok()); } diff --git a/datafusion/src/physical_plan/parquet.rs b/datafusion/src/physical_plan/parquet.rs index d41d6968fee0..fc94ad51f22c 100644 --- a/datafusion/src/physical_plan/parquet.rs +++ b/datafusion/src/physical_plan/parquet.rs @@ -17,6 +17,7 @@ //! Execution plan for reading Parquet files +use std::convert::TryFrom; use std::fmt; use std::fs::File; use std::sync::Arc; @@ -37,7 +38,7 @@ use crate::{ use crate::{ error::{DataFusionError, Result}, execution::context::ExecutionContextState, - logical_plan::{Expr, Operator}, + logical_plan::{Column, DFSchema, Expr, Operator}, optimizer::utils, prelude::ExecutionConfig, }; @@ -235,6 +236,7 @@ impl ParquetExec { schemas.len() ))); } + // FIXME: what if schemas has size 0? 
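The FIXME above flags the unguarded schemas[0] that follows. One possible shape for a guard, sketched here with the error types parquet.rs already imports; the helper name is hypothetical and not part of this patch:

```rust
// Hypothetical guard: fail with a plan error instead of panicking when no
// file schema was collected.
use crate::error::{DataFusionError, Result};
use arrow::datatypes::Schema;

fn first_schema(mut schemas: Vec<Schema>) -> Result<Schema> {
    if schemas.is_empty() {
        return Err(DataFusionError::Plan(
            "ParquetExec requires at least one file schema".to_string(),
        ));
    }
    Ok(schemas.remove(0))
}
```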
let schema = schemas[0].clone(); let predicate_builder = predicate.and_then(|predicate_expr| { RowGroupPredicateBuilder::try_new(&predicate_expr, schema.clone()).ok() @@ -371,7 +373,7 @@ impl ParquetPartition { pub struct RowGroupPredicateBuilder { parquet_schema: Schema, predicate_expr: Arc, - stat_column_req: Vec<(String, StatisticsType, Field)>, + stat_column_req: Vec<(Column, StatisticsType, Field)>, } impl RowGroupPredicateBuilder { @@ -381,19 +383,16 @@ impl RowGroupPredicateBuilder { /// then convert it to a DataFusion PhysicalExpression and cache it for later use by build_row_group_predicate. pub fn try_new(expr: &Expr, parquet_schema: Schema) -> Result { // build predicate expression once - let mut stat_column_req = Vec::<(String, StatisticsType, Field)>::new(); + let mut stat_column_req = Vec::<(Column, StatisticsType, Field)>::new(); let logical_predicate_expr = build_predicate_expression(expr, &parquet_schema, &mut stat_column_req)?; - // println!( - // "RowGroupPredicateBuilder::try_new, logical_predicate_expr: {:?}", - // logical_predicate_expr - // ); // build physical predicate expression let stat_fields = stat_column_req .iter() .map(|(_, _, f)| f.clone()) .collect::>(); let stat_schema = Schema::new(stat_fields); + let stat_dfschema = DFSchema::try_from(stat_schema.clone())?; let execution_context_state = ExecutionContextState { catalog_list: Arc::new(MemoryCatalogList::new()), scalar_functions: HashMap::new(), @@ -404,12 +403,9 @@ impl RowGroupPredicateBuilder { let predicate_expr = DefaultPhysicalPlanner::default().create_physical_expr( &logical_predicate_expr, &stat_schema, + &stat_dfschema, &execution_context_state, )?; - // println!( - // "RowGroupPredicateBuilder::try_new, predicate_expr: {:?}", - // predicate_expr - // ); Ok(Self { parquet_schema, predicate_expr, @@ -475,12 +471,12 @@ impl RowGroupPredicateBuilder { fn build_statistics_record_batch( row_groups: &[RowGroupMetaData], parquet_schema: &Schema, - stat_column_req: &[(String, StatisticsType, Field)], + stat_column_req: &[(Column, StatisticsType, Field)], ) -> Result { let mut fields = Vec::::new(); let mut arrays = Vec::::new(); - for (column_name, statistics_type, stat_field) in stat_column_req { - if let Some((column_index, _)) = parquet_schema.column_with_name(column_name) { + for (column, statistics_type, stat_field) in stat_column_req { + if let Some((column_index, _)) = parquet_schema.column_with_name(&column.name) { let statistics = row_groups .iter() .map(|g| g.column(column_index).statistics()) @@ -500,11 +496,11 @@ fn build_statistics_record_batch( } struct StatisticsExpressionBuilder<'a> { - column_name: String, + column: Column, column_expr: &'a Expr, scalar_expr: &'a Expr, parquet_field: &'a Field, - stat_column_req: &'a mut Vec<(String, StatisticsType, Field)>, + stat_column_req: &'a mut Vec<(Column, StatisticsType, Field)>, reverse_operator: bool, } @@ -513,14 +509,14 @@ impl<'a> StatisticsExpressionBuilder<'a> { left: &'a Expr, right: &'a Expr, parquet_schema: &'a Schema, - stat_column_req: &'a mut Vec<(String, StatisticsType, Field)>, + stat_column_req: &'a mut Vec<(Column, StatisticsType, Field)>, ) -> Result { // find column name; input could be a more complicated expression - let mut left_columns = HashSet::::new(); + let mut left_columns = HashSet::::new(); utils::expr_to_column_names(left, &mut left_columns)?; - let mut right_columns = HashSet::::new(); + let mut right_columns = HashSet::::new(); utils::expr_to_column_names(right, &mut right_columns)?; - let (column_expr, 
scalar_expr, column_names, reverse_operator) = + let (column_expr, scalar_expr, columns, reverse_operator) = match (left_columns.len(), right_columns.len()) { (1, 0) => (left, right, left_columns, false), (0, 1) => (right, left, right_columns, true), @@ -532,8 +528,8 @@ impl<'a> StatisticsExpressionBuilder<'a> { )); } }; - let column_name = column_names.iter().next().unwrap().clone(); - let field = match parquet_schema.column_with_name(&column_name) { + let column = columns.iter().next().unwrap().clone(); + let field = match parquet_schema.column_with_name(&column.flat_name()) { Some((_, f)) => f, _ => { // field not found in parquet schema @@ -544,7 +540,7 @@ impl<'a> StatisticsExpressionBuilder<'a> { }; Ok(Self { - column_name, + column, column_expr, scalar_expr, parquet_field: field, @@ -582,7 +578,7 @@ impl<'a> StatisticsExpressionBuilder<'a> { fn is_stat_column_missing(&self, statistics_type: StatisticsType) -> bool { self.stat_column_req .iter() - .filter(|(c, t, _f)| c == &self.column_name && t == &statistics_type) + .filter(|(c, t, _f)| c == &self.column && t == &statistics_type) .count() == 0 } @@ -592,22 +588,21 @@ impl<'a> StatisticsExpressionBuilder<'a> { stat_type: StatisticsType, suffix: &str, ) -> Result { - let stat_column_name = format!("{}_{}", self.column_name, suffix); + let stat_column = Column { + relation: self.column.relation.clone(), + name: format!("{}_{}", self.column.flat_name(), suffix), + }; let stat_field = Field::new( - stat_column_name.as_str(), + stat_column.flat_name().as_str(), self.parquet_field.data_type().clone(), self.parquet_field.is_nullable(), ); if self.is_stat_column_missing(stat_type) { // only add statistics column if not previously added self.stat_column_req - .push((self.column_name.clone(), stat_type, stat_field)); + .push((self.column.clone(), stat_type, stat_field)); } - rewrite_column_expr( - self.column_expr, - self.column_name.as_str(), - stat_column_name.as_str(), - ) + rewrite_column_expr(self.column_expr, &self.column, &stat_column) } fn min_column_expr(&mut self) -> Result { @@ -622,18 +617,18 @@ impl<'a> StatisticsExpressionBuilder<'a> { /// replaces a column with an old name with a new name in an expression fn rewrite_column_expr( expr: &Expr, - column_old_name: &str, - column_new_name: &str, + column_old: &Column, + column_new: &Column, ) -> Result { let expressions = utils::expr_sub_expressions(&expr)?; let expressions = expressions .iter() - .map(|e| rewrite_column_expr(e, column_old_name, column_new_name)) + .map(|e| rewrite_column_expr(e, column_old, column_new)) .collect::>>()?; - if let Expr::Column(name) = expr { - if name == column_old_name { - return Ok(Expr::Column(column_new_name.to_string())); + if let Expr::Column(c) = expr { + if c == column_old { + return Ok(Expr::Column(column_new.clone())); } } utils::rewrite_expression(&expr, &expressions) @@ -643,7 +638,7 @@ fn rewrite_column_expr( fn build_predicate_expression( expr: &Expr, parquet_schema: &Schema, - stat_column_req: &mut Vec<(String, StatisticsType, Field)>, + stat_column_req: &mut Vec<(Column, StatisticsType, Field)>, ) -> Result { use crate::logical_plan; // predicate expression can only be a binary expression @@ -914,7 +909,6 @@ fn read_files( loop { match batch_reader.next() { Some(Ok(batch)) => { - //println!("ParquetExec got new batch from {}", filename); total_rows += batch.num_rows(); send_result(&response_tx, Ok(batch))?; if limit.map(|l| total_rows >= l).unwrap_or(false) { @@ -1267,8 +1261,9 @@ mod tests { ]); // test AND operator joining supported 
c1 < 1 expression and unsupported c2 > c3 expression let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3"))); - let expected_expr = "#c1_min Lt Int32(1) And Boolean(true)"; let predicate_expr = build_predicate_expression(&expr, &schema, &mut vec![])?; + dbg!(&predicate_expr); + let expected_expr = "#c1_min Lt Int32(1) And Boolean(true)"; assert_eq!(format!("{:?}", predicate_expr), expected_expr); Ok(()) @@ -1310,18 +1305,30 @@ mod tests { let c1_min_field = Field::new("c1_min", DataType::Int32, false); assert_eq!( stat_column_req[0], - ("c1".to_owned(), StatisticsType::Min, c1_min_field) + ( + Column::from_name("c1".to_string()), + StatisticsType::Min, + c1_min_field + ) ); // c2 = 2 should add c2_min and c2_max let c2_min_field = Field::new("c2_min", DataType::Int32, false); assert_eq!( stat_column_req[1], - ("c2".to_owned(), StatisticsType::Min, c2_min_field) + ( + Column::from_name("c2".to_string()), + StatisticsType::Min, + c2_min_field + ) ); let c2_max_field = Field::new("c2_max", DataType::Int32, false); assert_eq!( stat_column_req[2], - ("c2".to_owned(), StatisticsType::Max, c2_max_field) + ( + Column::from_name("c2".to_string()), + StatisticsType::Max, + c2_max_field + ) ); // c2 = 3 shouldn't add any new statistics fields assert_eq!(stat_column_req.len(), 3); diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index f9279ae48f0c..54962de96060 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -52,6 +52,14 @@ use arrow::datatypes::{Schema, SchemaRef}; use expressions::col; use log::debug; +fn physical_name(e: &Expr, input_schema: &DFSchema) -> Result { + // FIXME: finish this + match e { + Expr::Column(c) => Ok(c.name.clone()), + _ => e.name(&input_schema), + } +} + /// This trait exposes the ability to plan an [`ExecutionPlan`] out of a [`LogicalPlan`]. pub trait ExtensionPlanner { /// Create a physical plan for a [`UserDefinedLogicalNode`]. @@ -144,8 +152,7 @@ impl DefaultPhysicalPlanner { } => { // Initially need to perform the aggregate and then merge the partitions let input_exec = self.create_initial_plan(input, ctx_state)?; - let input_schema = input_exec.schema(); - let physical_input_schema = input_exec.as_ref().schema(); + let physical_input_schema = input_exec.schema(); let logical_input_schema = input.as_ref().schema(); let groups = group_expr @@ -155,9 +162,10 @@ impl DefaultPhysicalPlanner { self.create_physical_expr( e, &physical_input_schema, + &logical_input_schema, ctx_state, ), - e.name(&logical_input_schema), + physical_name(e, &logical_input_schema), )) }) .collect::>>()?; @@ -178,11 +186,13 @@ impl DefaultPhysicalPlanner { groups.clone(), aggregates.clone(), input_exec, - input_schema.clone(), + physical_input_schema.clone(), )?); - let final_group: Vec> = - (0..groups.len()).map(|i| col(&groups[i].1)).collect(); + // update group column indices based on partial aggregate plan evaluation + let final_group: Vec> = (0..groups.len()) + .map(|i| col(&groups[i].1, &initial_aggr.schema())) + .collect::>()?; // construct a second aggregation, keeping the final column name equal to the first aggregation // and the expressions corresponding to the respective aggregate @@ -195,7 +205,7 @@ impl DefaultPhysicalPlanner { .collect(), aggregates, initial_aggr, - input_schema, + physical_input_schema, )?)) } LogicalPlan::Projection { input, expr, .. 
} => { @@ -208,9 +218,10 @@ impl DefaultPhysicalPlanner { self.create_physical_expr( e, &input_exec.schema(), + input_schema, &ctx_state, ), - e.name(&input_schema), + physical_name(e, &input_schema), )) }) .collect::>>()?; @@ -219,11 +230,16 @@ impl DefaultPhysicalPlanner { LogicalPlan::Filter { input, predicate, .. } => { - let input = self.create_initial_plan(input, ctx_state)?; - let input_schema = input.as_ref().schema(); - let runtime_expr = - self.create_physical_expr(predicate, &input_schema, ctx_state)?; - Ok(Arc::new(FilterExec::try_new(runtime_expr, input)?)) + let physical_input = self.create_initial_plan(input, ctx_state)?; + let input_schema = physical_input.as_ref().schema(); + let input_dfschema = input.as_ref().schema(); + let runtime_expr = self.create_physical_expr( + predicate, + &input_schema, + input_dfschema, + ctx_state, + )?; + Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?)) } LogicalPlan::Union { inputs, .. } => { let physical_plans = inputs @@ -236,8 +252,9 @@ impl DefaultPhysicalPlanner { input, partitioning_scheme, } => { - let input = self.create_initial_plan(input, ctx_state)?; - let input_schema = input.schema(); + let physical_input = self.create_initial_plan(input, ctx_state)?; + let input_schema = physical_input.schema(); + let input_dfschema = input.as_ref().schema(); let physical_partitioning = match partitioning_scheme { LogicalPartitioning::RoundRobinBatch(n) => { Partitioning::RoundRobinBatch(*n) @@ -246,20 +263,26 @@ impl DefaultPhysicalPlanner { let runtime_expr = expr .iter() .map(|e| { - self.create_physical_expr(e, &input_schema, &ctx_state) + self.create_physical_expr( + e, + &input_schema, + &input_dfschema, + &ctx_state, + ) }) .collect::>>()?; Partitioning::Hash(runtime_expr, *n) } }; Ok(Arc::new(RepartitionExec::try_new( - input, + physical_input, physical_partitioning, )?)) } LogicalPlan::Sort { expr, input, .. } => { - let input = self.create_initial_plan(input, ctx_state)?; - let input_schema = input.as_ref().schema(); + let physical_input = self.create_initial_plan(input, ctx_state)?; + let input_schema = physical_input.as_ref().schema(); + let input_dfschema = input.as_ref().schema(); let sort_expr = expr .iter() @@ -271,6 +294,7 @@ impl DefaultPhysicalPlanner { } => self.create_physical_sort_expr( expr, &input_schema, + &input_dfschema, SortOptions { descending: !*asc, nulls_first: *nulls_first, @@ -283,7 +307,7 @@ impl DefaultPhysicalPlanner { }) .collect::>>()?; - Ok(Arc::new(SortExec::try_new(sort_expr, input)?)) + Ok(Arc::new(SortExec::try_new(sort_expr, physical_input)?)) } LogicalPlan::Join { left, @@ -292,37 +316,56 @@ impl DefaultPhysicalPlanner { join_type, .. 
} => { - let left = self.create_initial_plan(left, ctx_state)?; - let right = self.create_initial_plan(right, ctx_state)?; + let left_df_schema = left.schema(); + let physical_left = self.create_initial_plan(left, ctx_state)?; + let right_df_schema = right.schema(); + let physical_right = self.create_initial_plan(right, ctx_state)?; let physical_join_type = match join_type { JoinType::Inner => hash_utils::JoinType::Inner, JoinType::Left => hash_utils::JoinType::Left, JoinType::Right => hash_utils::JoinType::Right, }; + let join_on = keys + .iter() + .map(|(l, r)| { + Ok(( + Column::new(&l.name, left_df_schema.index_of_column(&l)?), + Column::new(&r.name, right_df_schema.index_of_column(&r)?), + )) + }) + .collect::>()?; + if ctx_state.config.concurrency > 1 && ctx_state.config.repartition_joins { - let left_expr = keys.iter().map(|x| col(&x.0)).collect(); - let right_expr = keys.iter().map(|x| col(&x.1)).collect(); + let (left_expr, right_expr) = join_on + .iter() + .map(|(l, r)| { + ( + Arc::new(l.clone()) as Arc, + Arc::new(r.clone()) as Arc, + ) + }) + .unzip(); // Use hash partitioning by default to parallelize hash joins Ok(Arc::new(HashJoinExec::try_new( Arc::new(RepartitionExec::try_new( - left, + physical_left, Partitioning::Hash(left_expr, ctx_state.config.concurrency), )?), Arc::new(RepartitionExec::try_new( - right, + physical_right, Partitioning::Hash(right_expr, ctx_state.config.concurrency), )?), - &keys, + join_on, &physical_join_type, PartitionMode::Partitioned, )?)) } else { Ok(Arc::new(HashJoinExec::try_new( - left, - right, - &keys, + physical_left, + physical_right, + join_on, &physical_join_type, PartitionMode::CollectLeft, )?)) @@ -406,10 +449,10 @@ impl DefaultPhysicalPlanner { "No installed planner was able to convert the custom node to an execution plan: {:?}", node )))?; - // Ensure the ExecutionPlan's schema matches the + // Ensure the ExecutionPlan's schema matches the // declared logical schema to catch and warn about // logic errors when creating user defined plans. - if plan.schema() != node.schema().as_ref().to_owned().into() { + if !node.schema().matches_arrow_schema(&plan.schema()) { Err(DataFusionError::Plan(format!( "Extension planner for {:?} created an ExecutionPlan with mismatched schema. \ LogicalPlan schema: {:?}, ExecutionPlan schema: {:?}", @@ -427,16 +470,19 @@ impl DefaultPhysicalPlanner { &self, e: &Expr, input_schema: &Schema, + input_dfschema: &DFSchema, ctx_state: &ExecutionContextState, ) -> Result> { match e { - Expr::Alias(expr, ..) => { - Ok(self.create_physical_expr(expr, input_schema, ctx_state)?) - } - Expr::Column(name) => { - // check that name exists - input_schema.field_with_name(&name)?; - Ok(Arc::new(Column::new(name))) + Expr::Alias(expr, ..)
=> Ok(self.create_physical_expr( + expr, + input_schema, + input_dfschema, + ctx_state, + )?), + Expr::Column(c) => { + let idx = input_dfschema.index_of_column(c)?; + Ok(Arc::new(Column::new(&c.name, idx))) } Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), Expr::ScalarVariable(variable_names) => { @@ -465,8 +511,18 @@ impl DefaultPhysicalPlanner { } } Expr::BinaryExpr { left, op, right } => { - let lhs = self.create_physical_expr(left, input_schema, ctx_state)?; - let rhs = self.create_physical_expr(right, input_schema, ctx_state)?; + let lhs = self.create_physical_expr( + left, + input_schema, + input_dfschema, + ctx_state, + )?; + let rhs = self.create_physical_expr( + right, + input_schema, + input_dfschema, + ctx_state, + )?; binary(lhs, *op, rhs, input_schema) } Expr::Case { @@ -479,6 +535,7 @@ impl DefaultPhysicalPlanner { Some(self.create_physical_expr( e.as_ref(), input_schema, + input_dfschema, ctx_state, )?) } else { @@ -487,13 +544,23 @@ impl DefaultPhysicalPlanner { let when_expr = when_then_expr .iter() .map(|(w, _)| { - self.create_physical_expr(w.as_ref(), input_schema, ctx_state) + self.create_physical_expr( + w.as_ref(), + input_schema, + input_dfschema, + ctx_state, + ) }) .collect::>>()?; let then_expr = when_then_expr .iter() .map(|(_, t)| { - self.create_physical_expr(t.as_ref(), input_schema, ctx_state) + self.create_physical_expr( + t.as_ref(), + input_schema, + input_dfschema, + ctx_state, + ) }) .collect::>>()?; let when_then_expr: Vec<(Arc, Arc)> = @@ -507,6 +574,7 @@ impl DefaultPhysicalPlanner { Some(self.create_physical_expr( e.as_ref(), input_schema, + input_dfschema, ctx_state, )?) } else { @@ -519,35 +587,43 @@ impl DefaultPhysicalPlanner { )?)) } Expr::Cast { expr, data_type } => expressions::cast( - self.create_physical_expr(expr, input_schema, ctx_state)?, + self.create_physical_expr(expr, input_schema, input_dfschema, ctx_state)?, input_schema, data_type.clone(), ), Expr::TryCast { expr, data_type } => expressions::try_cast( - self.create_physical_expr(expr, input_schema, ctx_state)?, + self.create_physical_expr(expr, input_schema, input_dfschema, ctx_state)?, input_schema, data_type.clone(), ), Expr::Not(expr) => expressions::not( - self.create_physical_expr(expr, input_schema, ctx_state)?, + self.create_physical_expr(expr, input_schema, input_dfschema, ctx_state)?, input_schema, ), Expr::Negative(expr) => expressions::negative( - self.create_physical_expr(expr, input_schema, ctx_state)?, + self.create_physical_expr(expr, input_schema, input_dfschema, ctx_state)?, input_schema, ), Expr::IsNull(expr) => expressions::is_null(self.create_physical_expr( expr, input_schema, + input_dfschema, ctx_state, )?), Expr::IsNotNull(expr) => expressions::is_not_null( - self.create_physical_expr(expr, input_schema, ctx_state)?, + self.create_physical_expr(expr, input_schema, input_dfschema, ctx_state)?, ), Expr::ScalarFunction { fun, args } => { let physical_args = args .iter() - .map(|e| self.create_physical_expr(e, input_schema, ctx_state)) + .map(|e| { + self.create_physical_expr( + e, + input_schema, + input_dfschema, + ctx_state, + ) + }) .collect::>>()?; functions::create_physical_expr(fun, &physical_args, input_schema) } @@ -557,6 +633,7 @@ impl DefaultPhysicalPlanner { physical_args.push(self.create_physical_expr( e, input_schema, + input_dfschema, ctx_state, )?); } @@ -573,11 +650,24 @@ impl DefaultPhysicalPlanner { low, high, } => { - let value_expr = - self.create_physical_expr(expr, input_schema, ctx_state)?; - let low_expr = 
self.create_physical_expr(low, input_schema, ctx_state)?; - let high_expr = - self.create_physical_expr(high, input_schema, ctx_state)?; + let value_expr = self.create_physical_expr( + expr, + input_schema, + input_dfschema, + ctx_state, + )?; + let low_expr = self.create_physical_expr( + low, + input_schema, + input_dfschema, + ctx_state, + )?; + let high_expr = self.create_physical_expr( + high, + input_schema, + input_dfschema, + ctx_state, + )?; // rewrite the between into the two binary operators let binary_expr = binary( @@ -602,44 +692,54 @@ impl DefaultPhysicalPlanner { Ok(expressions::lit(ScalarValue::Boolean(None))) } _ => { - let value_expr = - self.create_physical_expr(expr, input_schema, ctx_state)?; + let value_expr = self.create_physical_expr( + expr, + input_schema, + input_dfschema, + ctx_state, + )?; let value_expr_data_type = value_expr.data_type(input_schema)?; - let list_exprs = - list.iter() - .map(|expr| match expr { - Expr::Literal(ScalarValue::Utf8(None)) => self - .create_physical_expr(expr, input_schema, ctx_state), - _ => { - let list_expr = self.create_physical_expr( - expr, + let list_exprs = list + .iter() + .map(|expr| match expr { + Expr::Literal(ScalarValue::Utf8(None)) => self + .create_physical_expr( + expr, + input_schema, + input_dfschema, + ctx_state, + ), + _ => { + let list_expr = self.create_physical_expr( + expr, + input_schema, + input_dfschema, + ctx_state, + )?; + let list_expr_data_type = + list_expr.data_type(input_schema)?; + + if list_expr_data_type == value_expr_data_type { + Ok(list_expr) + } else if can_cast_types( + &list_expr_data_type, + &value_expr_data_type, + ) { + expressions::cast( + list_expr, input_schema, - ctx_state, - )?; - let list_expr_data_type = - list_expr.data_type(input_schema)?; - - if list_expr_data_type == value_expr_data_type { - Ok(list_expr) - } else if can_cast_types( - &list_expr_data_type, - &value_expr_data_type, - ) { - expressions::cast( - list_expr, - input_schema, - value_expr.data_type(input_schema)?, - ) - } else { - Err(DataFusionError::Plan(format!( - "Unsupported CAST from {:?} to {:?}", - list_expr_data_type, value_expr_data_type - ))) - } + value_expr.data_type(input_schema)?, + ) + } else { + Err(DataFusionError::Plan(format!( + "Unsupported CAST from {:?} to {:?}", + list_expr_data_type, value_expr_data_type + ))) } - }) - .collect::>>()?; + } + }) + .collect::>>()?; expressions::in_list(value_expr, list_exprs, negated) } @@ -662,7 +762,7 @@ impl DefaultPhysicalPlanner { // unpack aliased logical expressions, e.g. 
"sum(col) as total" let (name, e) = match e { Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), - _ => (e.name(logical_input_schema)?, e), + _ => (physical_name(e, logical_input_schema)?, e), }; match e { @@ -675,7 +775,12 @@ impl DefaultPhysicalPlanner { let args = args .iter() .map(|e| { - self.create_physical_expr(e, physical_input_schema, ctx_state) + self.create_physical_expr( + e, + physical_input_schema, + logical_input_schema, + ctx_state, + ) }) .collect::>>()?; aggregates::create_aggregate_expr( @@ -690,7 +795,12 @@ impl DefaultPhysicalPlanner { let args = args .iter() .map(|e| { - self.create_physical_expr(e, physical_input_schema, ctx_state) + self.create_physical_expr( + e, + physical_input_schema, + logical_input_schema, + ctx_state, + ) }) .collect::>>()?; @@ -708,11 +818,17 @@ impl DefaultPhysicalPlanner { &self, e: &Expr, input_schema: &Schema, + input_dfschema: &DFSchema, options: SortOptions, ctx_state: &ExecutionContextState, ) -> Result { Ok(PhysicalSortExpr { - expr: self.create_physical_expr(e, input_schema, ctx_state)?, + expr: self.create_physical_expr( + e, + input_schema, + input_dfschema, + ctx_state, + )?, options, }) } @@ -744,6 +860,7 @@ mod tests { use arrow::datatypes::{DataType, Field, SchemaRef}; use async_trait::async_trait; use fmt::Debug; + use std::convert::TryFrom; use std::{any::Any, collections::HashMap, fmt}; fn make_ctx_state() -> ExecutionContextState { @@ -781,7 +898,7 @@ mod tests { // verify that the plan correctly casts u8 to i64 // the cast here is implicit so has CastOptions with safe=true - let expected = "BinaryExpr { left: Column { name: \"c7\" }, op: Lt, right: TryCastExpr { expr: Literal { value: UInt8(5) }, cast_type: Int64 } }"; + let expected = "BinaryExpr { left: Column { name: \"c7\", index: 6 }, op: Lt, right: TryCastExpr { expr: Literal { value: UInt8(5) }, cast_type: Int64 } }"; assert!(format!("{:?}", plan).contains(expected)); Ok(()) @@ -790,12 +907,17 @@ mod tests { #[test] fn test_create_not() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); + let dfschema = DFSchema::try_from(schema.clone())?; let planner = DefaultPhysicalPlanner::default(); - let expr = - planner.create_physical_expr(&col("a").not(), &schema, &make_ctx_state())?; - let expected = expressions::not(expressions::col("a"), &schema)?; + let expr = planner.create_physical_expr( + &col("a").not(), + &schema, + &dfschema, + &make_ctx_state(), + )?; + let expected = expressions::not(expressions::col("a", &schema)?, &schema)?; assert_eq!(format!("{:?}", expr), format!("{:?}", expected)); @@ -816,7 +938,7 @@ mod tests { // c12 is f64, c7 is u8 -> cast c7 to f64 // the cast here is implicit so has CastOptions with safe=true - let expected = "predicate: BinaryExpr { left: TryCastExpr { expr: Column { name: \"c7\" }, cast_type: Float64 }, op: Lt, right: Column { name: \"c12\" } }"; + let expected = "predicate: BinaryExpr { left: TryCastExpr { expr: Column { name: \"c7\", index: 6 }, cast_type: Float64 }, op: Lt, right: Column { name: \"c12\", index: 11 } }"; assert!(format!("{:?}", plan).contains(expected)); Ok(()) } @@ -941,8 +1063,7 @@ mod tests { .build()?; let execution_plan = plan(&logical_plan)?; // verify that the plan correctly adds cast from Int64(1) to Utf8 - let expected = "InListExpr { expr: Column { name: \"c1\" }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], negated: false }"; - println!("{:?}", 
execution_plan); + let expected = "InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], negated: false }"; assert!(format!("{:?}", execution_plan).contains(expected)); // expression: "a in (true, 'a')" diff --git a/datafusion/src/physical_plan/projection.rs b/datafusion/src/physical_plan/projection.rs index a881beb453a0..cfe8ebe7b5f2 100644 --- a/datafusion/src/physical_plan/projection.rs +++ b/datafusion/src/physical_plan/projection.rs @@ -206,8 +206,10 @@ mod tests { )?; // pick column c1 and name it column c1 in the output schema - let projection = - ProjectionExec::try_new(vec![(col("c1"), "c1".to_string())], Arc::new(csv))?; + let projection = ProjectionExec::try_new( + vec![(col("c1", &schema)?, "c1".to_string())], + Arc::new(csv), + )?; let mut partition_count = 0; let mut row_count = 0; diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs index 7243550127bd..68dca57c8f32 100644 --- a/datafusion/src/physical_plan/repartition.rs +++ b/datafusion/src/physical_plan/repartition.rs @@ -297,6 +297,7 @@ impl RecordBatchStream for RepartitionStream { #[cfg(test)] mod tests { use super::*; + use crate::physical_plan::expressions::col; use crate::physical_plan::memory::MemoryExec; use arrow::array::UInt32Array; use arrow::datatypes::{DataType, Field, Schema}; @@ -370,12 +371,7 @@ mod tests { let output_partitions = repartition( &schema, partitions, - Partitioning::Hash( - vec![Arc::new(crate::physical_plan::expressions::Column::new( - &"c0", - ))], - 8, - ), + Partitioning::Hash(vec![col("c0", &schema)?], 8), ) .await?; diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs index 26855b354db0..4ee53bc0e22a 100644 --- a/datafusion/src/physical_plan/sort.rs +++ b/datafusion/src/physical_plan/sort.rs @@ -336,17 +336,17 @@ mod tests { vec![ // c1 string column PhysicalSortExpr { - expr: col("c1"), + expr: col("c1", &schema)?, options: SortOptions::default(), }, // c2 uin32 column PhysicalSortExpr { - expr: col("c2"), + expr: col("c2", &schema)?, options: SortOptions::default(), }, // c7 uin8 column PhysicalSortExpr { - expr: col("c7"), + expr: col("c7", &schema)?, options: SortOptions::default(), }, ], @@ -410,14 +410,14 @@ mod tests { let sort_exec = Arc::new(SortExec::try_new( vec![ PhysicalSortExpr { - expr: col("a"), + expr: col("a", &schema)?, options: SortOptions { descending: true, nulls_first: true, }, }, PhysicalSortExpr { - expr: col("b"), + expr: col("b", &schema)?, options: SortOptions { descending: false, nulls_first: false, diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs index 24b51ba60695..41088da5814d 100644 --- a/datafusion/src/physical_plan/type_coercion.rs +++ b/datafusion/src/physical_plan/type_coercion.rs @@ -260,7 +260,9 @@ mod tests { let expressions = |t: Vec, schema| -> Result> { t.iter() .enumerate() - .map(|(i, t)| try_cast(col(&format!("c{}", i)), &schema, t.clone())) + .map(|(i, t)| { + try_cast(col(&format!("c{}", i), &schema)?, &schema, t.clone()) + }) .collect::>>() }; diff --git a/datafusion/src/prelude.rs b/datafusion/src/prelude.rs index 0edc82a98afb..cb358f7165d4 100644 --- a/datafusion/src/prelude.rs +++ b/datafusion/src/prelude.rs @@ -32,6 +32,6 @@ pub use crate::logical_plan::{ count, create_udf, in_list, initcap, left, length, lit, lower, lpad, ltrim, max, md5, min, 
octet_length, regexp_replace, repeat, replace, reverse, right, rpad, rtrim, sha224, sha256, sha384, sha512, split_part, starts_with, strpos, substr, sum, to_hex, - translate, trim, upper, JoinType, Partitioning, + translate, trim, upper, Column, JoinType, Partitioning, }; pub use crate::physical_plan::csv::CsvReadOptions; diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index f3cba232a23a..1b2146fa2075 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -17,6 +17,7 @@ //! SQL Query Planner (produces logical plan from SQL AST) +use std::collections::HashSet; use std::convert::TryInto; use std::str::FromStr; use std::sync::Arc; @@ -25,8 +26,8 @@ use crate::catalog::TableReference; use crate::datasource::TableProvider; use crate::logical_plan::Expr::Alias; use crate::logical_plan::{ - and, lit, DFSchema, Expr, LogicalPlan, LogicalPlanBuilder, Operator, PlanType, - StringifiedPlan, ToDFSchema, + and, lit, union_with_alias, Column, DFSchema, Expr, LogicalPlan, LogicalPlanBuilder, + Operator, PlanType, StringifiedPlan, ToDFSchema, }; use crate::scalar::ScalarValue; use crate::{ @@ -163,29 +164,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { (SetOperator::Union, true) => { let left_plan = self.set_expr_to_plan(left.as_ref(), None, ctes)?; let right_plan = self.set_expr_to_plan(right.as_ref(), None, ctes)?; - let inputs = vec![left_plan, right_plan] - .into_iter() - .flat_map(|p| match p { - LogicalPlan::Union { inputs, .. } => inputs, - x => vec![x], - }) - .collect::>(); - if inputs.is_empty() { - return Err(DataFusionError::Plan(format!( - "Empty UNION: {}", - set_expr - ))); - } - if !inputs.iter().all(|s| s.schema() == inputs[0].schema()) { - return Err(DataFusionError::Plan( - "UNION ALL schemas are expected to be the same".to_string(), - )); - } - Ok(LogicalPlan::Union { - schema: inputs[0].schema().clone(), - inputs, - alias, - }) + union_with_alias(left_plan, right_plan, alias) } _ => Err(DataFusionError::NotImplemented(format!( "Only UNION ALL is supported, found {}", @@ -371,7 +350,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { match constraint { JoinConstraint::On(sql_expr) => { - let mut keys: Vec<(String, String)> = vec![]; + let mut keys: Vec<(Column, Column)> = vec![]; let join_schema = left.schema().join(&right.schema())?; // parse ON expression @@ -379,20 +358,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // extract join keys extract_join_keys(&expr, &mut keys)?; - let left_keys: Vec<&str> = - keys.iter().map(|pair| pair.0.as_str()).collect(); - let right_keys: Vec<&str> = - keys.iter().map(|pair| pair.1.as_str()).collect(); + // TODO: avoid two iterations + let left_keys: Vec = + keys.iter().map(|pair| pair.0.clone()).collect(); + let right_keys: Vec = + keys.iter().map(|pair| pair.1.clone()).collect(); // return the logical plan representing the join LogicalPlanBuilder::from(&left) - .join(&right, join_type, &left_keys, &right_keys)? + .join(&right, join_type, left_keys, right_keys)? .build() } JoinConstraint::Using(idents) => { - let keys: Vec<&str> = idents.iter().map(|x| x.value.as_str()).collect(); + let keys: Vec = idents + .iter() + .map(|x| Column::from_name(x.value.clone())) + .collect(); LogicalPlanBuilder::from(&left) - .join(&right, join_type, &keys, &keys)? + .join_using(&right, join_type, keys)? 
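
For reference, the shape of the new join-key plumbing in one place. This is an illustrative sketch, not part of the patch: the `person`/`orders` names and the standalone `builder` are hypothetical, while the `Column` constructors and builder methods are the ones used in the hunk above.

    use datafusion::logical_plan::Column;

    // Unqualified key, what `USING (id)` produces via Column::from_name;
    // it is resolved against the joined schemas during planning.
    let using_key = Column::from_name("id".to_string());

    // Qualified keys, what `ON person.id = orders.customer_id` parses into.
    let left_key = Column {
        relation: Some("person".to_string()),
        name: "id".to_string(),
    };
    let right_key = Column {
        relation: Some("orders".to_string()),
        name: "customer_id".to_string(),
    };

    // builder.join(&right, join_type, vec![left_key], vec![right_key])?;
    // builder.join_using(&right, join_type, vec![using_key])?;
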
.build() } JoinConstraint::Natural => { @@ -413,7 +396,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ctes: &mut HashMap, ) -> Result { match relation { - TableFactor::Table { name, .. } => { + TableFactor::Table { name, alias, .. } => { let table_name = name.to_string(); let cte = ctes.get(&table_name); match ( @@ -421,9 +404,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.schema_provider.get_table_provider(name.try_into()?), ) { (Some(cte_plan), _) => Ok(cte_plan.clone()), - (_, Some(provider)) => { - LogicalPlanBuilder::scan(&table_name, provider, None)?.build() - } + (_, Some(provider)) => LogicalPlanBuilder::scan( + // take alias into account to support `JOIN table1 as table2` + alias + .as_ref() + .map(|a| a.name.value.as_str()) + .or(Some(&table_name)), + provider, + None, + )? + .build(), (_, None) => Err(DataFusionError::Plan(format!( "Table or CTE with name '{}' not found", name @@ -471,21 +461,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut possible_join_keys = vec![]; extract_possible_join_keys(&filter_expr, &mut possible_join_keys)?; - let mut all_join_keys = vec![]; + let mut all_join_keys = HashSet::new(); let mut left = plans[0].clone(); for right in plans.iter().skip(1) { let left_schema = left.schema(); let right_schema = right.schema(); let mut join_keys = vec![]; for (l, r) in &possible_join_keys { - if left_schema.field_with_unqualified_name(l).is_ok() - && right_schema.field_with_unqualified_name(r).is_ok() + if left_schema.field_from_qualified_column(l).is_ok() + && right_schema.field_from_qualified_column(r).is_ok() { - join_keys.push((l.as_str(), r.as_str())); - } else if left_schema.field_with_unqualified_name(r).is_ok() - && right_schema.field_with_unqualified_name(l).is_ok() + // TODO: avoid clone here + join_keys.push((l.clone(), r.clone())); + } else if left_schema.field_from_qualified_column(r).is_ok() + && right_schema.field_from_qualified_column(l).is_ok() { - join_keys.push((r.as_str(), l.as_str())); + // TODO: avoid clone here + join_keys.push((r.clone(), l.clone())); } } if join_keys.is_empty() { @@ -493,16 +485,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { "Cartesian joins are not supported".to_string(), )); } else { - let left_keys: Vec<_> = - join_keys.iter().map(|(l, _)| *l).collect(); - let right_keys: Vec<_> = - join_keys.iter().map(|(_, r)| *r).collect(); + let left_keys: Vec = + // TODO: avoid clone here + join_keys.iter().map(|(l, _)| l.clone()).collect(); + let right_keys: Vec = + // TODO: avoid clone here + join_keys.iter().map(|(_, r)| r.clone()).collect(); let builder = LogicalPlanBuilder::from(&left); left = builder - .join(right, JoinType::Inner, &left_keys, &right_keys)? + .join(right, JoinType::Inner, left_keys, right_keys)? .build()?; } - all_join_keys.extend_from_slice(&join_keys); + + all_join_keys.extend(join_keys); } // remove join expressions from filter @@ -533,7 +528,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .having .as_ref() .map::, _>(|having_expr| { - let having_expr = self.sql_expr_to_logical_expr(having_expr)?; + // having clause may reference aliases defined in select projection + let projected_plan = self.project(&plan, select_exprs.clone(), false)?; + let mut combined_schema = (**projected_plan.schema()).clone(); + combined_schema.merge(plan.schema()); + let having_expr = + self.sql_expr_to_logical_expr(having_expr, &combined_schema)?; // This step "dereferences" any aliases in the HAVING clause. 
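
To make the effect of the merged schema concrete, here is a hand-written test in the style of the ones later in this file (an assumed example, mirroring the existing alias-in-HAVING cases, not a test added by the patch): because HAVING is lowered against the projected schema merged with the aggregate input schema, both the alias and the underlying column spellings resolve.

    #[test]
    fn select_aggregate_with_having_referencing_alias_and_group_by() {
        let sql = "SELECT state, MAX(age) AS max_age \
                   FROM person \
                   GROUP BY state \
                   HAVING max_age < 30";
        // The alias `max_age` dereferences to the aggregate expression.
        let expected = "Projection: #person.state, #MAX(person.age) AS max_age\
                        \n  Filter: #MAX(person.age) Lt Int64(30)\
                        \n  Aggregate: groupBy=[[#person.state]], aggr=[[MAX(#person.age)]]\
                        \n  TableScan: person projection=None";
        quick_test(sql, expected);
    }
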
// @@ -562,7 +562,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // The outer expressions we will search through for // aggregates. Aggregates may be sourced from the SELECT... let mut aggr_expr_haystack = select_exprs.clone(); - // ... or from the HAVING. if let Some(having_expr) = &having_expr_opt { aggr_expr_haystack.push(having_expr.clone()); @@ -780,11 +779,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { find_column_exprs(exprs) .iter() .try_for_each(|col| match col { - Expr::Column(name) => { - schema.field_with_unqualified_name(&name).map_err(|_| { + Expr::Column(col) => { + match &col.relation { + Some(r) => schema.field_with_qualified_name(r, &col.name), + None => schema.field_with_unqualified_name(&col.name), + } + .map_err(|_| { DataFusionError::Plan(format!( "Invalid identifier '{}' for schema {}", - name, + col, schema.to_string() )) })?; @@ -811,19 +814,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Generate a relational expression from a SQL expression pub fn sql_to_rex(&self, sql: &SQLExpr, schema: &DFSchema) -> Result { - let expr = self.sql_expr_to_logical_expr(sql)?; + let expr = self.sql_expr_to_logical_expr(sql, schema)?; self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?; Ok(expr) } - fn sql_fn_arg_to_logical_expr(&self, sql: &FunctionArg) -> Result { + fn sql_fn_arg_to_logical_expr( + &self, + sql: &FunctionArg, + schema: &DFSchema, + ) -> Result { match sql { - FunctionArg::Named { name: _, arg } => self.sql_expr_to_logical_expr(arg), - FunctionArg::Unnamed(value) => self.sql_expr_to_logical_expr(value), + FunctionArg::Named { name: _, arg } => { + self.sql_expr_to_logical_expr(arg, schema) + } + FunctionArg::Unnamed(value) => self.sql_expr_to_logical_expr(value, schema), } } - fn sql_expr_to_logical_expr(&self, sql: &SQLExpr) -> Result { + fn sql_expr_to_logical_expr(&self, sql: &SQLExpr, schema: &DFSchema) -> Result { match sql { SQLExpr::Value(Value::Number(n, _)) => match n.parse::() { Ok(n) => Ok(lit(n)), @@ -838,7 +847,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun: functions::BuiltinScalarFunction::DatePart, args: vec![ Expr::Literal(ScalarValue::Utf8(Some(format!("{}", field)))), - self.sql_expr_to_logical_expr(expr)?, + self.sql_expr_to_logical_expr(expr, schema)?, ], }), @@ -861,7 +870,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let var_names = vec![id.value.clone()]; Ok(Expr::ScalarVariable(var_names)) } else { - Ok(Expr::Column(id.value.to_string())) + Ok(Expr::Column( + schema + .field_with_unqualified_name(&id.value)? 
+ .qualified_column(), + )) } } @@ -872,6 +885,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } if &var_names[0][0..1] == "@" { Ok(Expr::ScalarVariable(var_names)) + } else if var_names.len() == 2 { + // table.column identifier + let name = var_names.pop().unwrap(); + let relation = Some(var_names.pop().unwrap()); + Ok(Expr::Column(Column { relation, name })) } else { Err(DataFusionError::NotImplemented(format!( "Unsupported compound identifier '{:?}'", @@ -889,20 +907,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { else_result, } => { let expr = if let Some(e) = operand { - Some(Box::new(self.sql_expr_to_logical_expr(e)?)) + Some(Box::new(self.sql_expr_to_logical_expr(e, schema)?)) } else { None }; let when_expr = conditions .iter() - .map(|e| self.sql_expr_to_logical_expr(e)) + .map(|e| self.sql_expr_to_logical_expr(e, schema)) .collect::>>()?; let then_expr = results .iter() - .map(|e| self.sql_expr_to_logical_expr(e)) + .map(|e| self.sql_expr_to_logical_expr(e, schema)) .collect::>>()?; let else_expr = if let Some(e) = else_result { - Some(Box::new(self.sql_expr_to_logical_expr(e)?)) + Some(Box::new(self.sql_expr_to_logical_expr(e, schema)?)) } else { None }; @@ -922,7 +940,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ref expr, ref data_type, } => Ok(Expr::Cast { - expr: Box::new(self.sql_expr_to_logical_expr(&expr)?), + expr: Box::new(self.sql_expr_to_logical_expr(&expr, schema)?), data_type: convert_data_type(data_type)?, }), @@ -930,7 +948,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ref expr, ref data_type, } => Ok(Expr::TryCast { - expr: Box::new(self.sql_expr_to_logical_expr(&expr)?), + expr: Box::new(self.sql_expr_to_logical_expr(&expr, schema)?), data_type: convert_data_type(data_type)?, }), @@ -942,19 +960,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { data_type: convert_data_type(data_type)?, }), - SQLExpr::IsNull(ref expr) => { - Ok(Expr::IsNull(Box::new(self.sql_expr_to_logical_expr(expr)?))) - } + SQLExpr::IsNull(ref expr) => Ok(Expr::IsNull(Box::new( + self.sql_expr_to_logical_expr(expr, schema)?, + ))), SQLExpr::IsNotNull(ref expr) => Ok(Expr::IsNotNull(Box::new( - self.sql_expr_to_logical_expr(expr)?, + self.sql_expr_to_logical_expr(expr, schema)?, ))), SQLExpr::UnaryOp { ref op, ref expr } => match op { - UnaryOperator::Not => { - Ok(Expr::Not(Box::new(self.sql_expr_to_logical_expr(expr)?))) - } - UnaryOperator::Plus => Ok(self.sql_expr_to_logical_expr(expr)?), + UnaryOperator::Not => Ok(Expr::Not(Box::new( + self.sql_expr_to_logical_expr(expr, schema)?, + ))), + UnaryOperator::Plus => Ok(self.sql_expr_to_logical_expr(expr, schema)?), UnaryOperator::Minus => { match expr.as_ref() { // optimization: if it's a number literal, we applly the negative operator @@ -970,7 +988,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { })?)), }, // not a literal, apply negative operator on expression - _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(expr)?))), + _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(expr, schema)?))), } } _ => Err(DataFusionError::NotImplemented(format!( @@ -985,10 +1003,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ref low, ref high, } => Ok(Expr::Between { - expr: Box::new(self.sql_expr_to_logical_expr(&expr)?), + expr: Box::new(self.sql_expr_to_logical_expr(&expr, schema)?), negated: *negated, - low: Box::new(self.sql_expr_to_logical_expr(&low)?), - high: Box::new(self.sql_expr_to_logical_expr(&high)?), + low: Box::new(self.sql_expr_to_logical_expr(&low, schema)?), + high: 
Box::new(self.sql_expr_to_logical_expr(&high, schema)?), }), SQLExpr::InList { @@ -998,11 +1016,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } => { let list_expr = list .iter() - .map(|e| self.sql_expr_to_logical_expr(e)) + .map(|e| self.sql_expr_to_logical_expr(e, schema)) .collect::>>()?; Ok(Expr::InList { - expr: Box::new(self.sql_expr_to_logical_expr(&expr)?), + expr: Box::new(self.sql_expr_to_logical_expr(&expr, schema)?), list: list_expr, negated: *negated, }) @@ -1036,9 +1054,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }?; Ok(Expr::BinaryExpr { - left: Box::new(self.sql_expr_to_logical_expr(&left)?), + left: Box::new(self.sql_expr_to_logical_expr(&left, schema)?), op: operator, - right: Box::new(self.sql_expr_to_logical_expr(&right)?), + right: Box::new(self.sql_expr_to_logical_expr(&right, schema)?), }) } @@ -1062,7 +1080,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let args = function .args .iter() - .map(|a| self.sql_fn_arg_to_logical_expr(a)) + .map(|a| self.sql_fn_arg_to_logical_expr(a, schema)) .collect::>>()?; return Ok(Expr::ScalarFunction { fun, args }); @@ -1080,14 +1098,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { _, ))) => Ok(lit(1_u8)), FunctionArg::Unnamed(SQLExpr::Wildcard) => Ok(lit(1_u8)), - _ => self.sql_fn_arg_to_logical_expr(a), + _ => self.sql_fn_arg_to_logical_expr(a, schema), }) .collect::>>()? } else { function .args .iter() - .map(|a| self.sql_fn_arg_to_logical_expr(a)) + .map(|a| self.sql_fn_arg_to_logical_expr(a, schema)) .collect::>>()? }; @@ -1104,7 +1122,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let args = function .args .iter() - .map(|a| self.sql_fn_arg_to_logical_expr(a)) + .map(|a| self.sql_fn_arg_to_logical_expr(a, schema)) .collect::>>()?; Ok(Expr::ScalarUDF { fun: fm, args }) @@ -1114,7 +1132,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let args = function .args .iter() - .map(|a| self.sql_fn_arg_to_logical_expr(a)) + .map(|a| self.sql_fn_arg_to_logical_expr(a, schema)) .collect::>>()?; Ok(Expr::AggregateUDF { fun: fm, args }) @@ -1127,7 +1145,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(&e), + SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(&e, schema), _ => Err(DataFusionError::NotImplemented(format!( "Unsupported ast node {:?} in sqltorel", @@ -1397,13 +1415,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Remove join expressions from a filter expression fn remove_join_expressions( expr: &Expr, - join_columns: &[(&str, &str)], + join_columns: &HashSet<(Column, Column)>, ) -> Result> { match expr { Expr::BinaryExpr { left, op, right } => match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { + // TODO: avoid clones (Expr::Column(l), Expr::Column(r)) => { - if join_columns.contains(&(l, r)) || join_columns.contains(&(r, l)) { + if join_columns.contains(&(l.clone(), r.clone())) + || join_columns.contains(&(r.clone(), l.clone())) + { Ok(None) } else { Ok(Some(expr.clone())) @@ -1434,12 +1455,12 @@ fn remove_join_expressions( /// foo = bar /// foo = bar AND bar = baz AND ... 
/// -fn extract_join_keys(expr: &Expr, accum: &mut Vec<(String, String)>) -> Result<()> { +fn extract_join_keys(expr: &Expr, accum: &mut Vec<(Column, Column)>) -> Result<()> { match expr { Expr::BinaryExpr { left, op, right } => match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { (Expr::Column(l), Expr::Column(r)) => { - accum.push((l.to_owned(), r.to_owned())); + accum.push((l.clone(), r.clone())); Ok(()) } other => Err(DataFusionError::SQL(ParserError(format!( @@ -1466,13 +1487,13 @@ fn extract_join_keys(expr: &Expr, accum: &mut Vec<(String, String)>) -> Result<( /// Extract join keys from a WHERE clause fn extract_possible_join_keys( expr: &Expr, - accum: &mut Vec<(String, String)>, + accum: &mut Vec<(Column, Column)>, ) -> Result<()> { match expr { Expr::BinaryExpr { left, op, right } => match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { (Expr::Column(l), Expr::Column(r)) => { - accum.push((l.to_owned(), r.to_owned())); + accum.push((l.clone(), r.clone())); Ok(()) } _ => Ok(()), @@ -1513,9 +1534,6 @@ mod tests { use crate::{logical_plan::create_udf, sql::parser::DFParser}; use functions::ScalarFunctionImplementation; - const PERSON_COLUMN_NAMES: &str = - "id, first_name, last_name, age, state, salary, birth_date"; - #[test] fn select_no_relation() { quick_test( @@ -1530,10 +1548,7 @@ mod tests { let sql = "SELECT doesnotexist FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - format!( - "Plan(\"Invalid identifier \\\'doesnotexist\\\' for schema {}\")", - PERSON_COLUMN_NAMES - ), + "Plan(\"No field with unqualified name 'doesnotexist'\")", format!("{:?}", err) ); } @@ -1543,7 +1558,7 @@ mod tests { let sql = "SELECT age, age FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"Projections require unique expression names but the expression \\\"#age\\\" at position 0 and \\\"#age\\\" at position 1 have the same name. Consider aliasing (\\\"AS\\\") one of them.\")", + "Plan(\"Projections require unique expression names but the expression \\\"#person.age\\\" at position 0 and \\\"#person.age\\\" at position 1 have the same name. Consider aliasing (\\\"AS\\\") one of them.\")", format!("{:?}", err) ); } @@ -1553,7 +1568,7 @@ mod tests { let sql = "SELECT *, age FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"Projections require unique expression names but the expression \\\"#age\\\" at position 3 and \\\"#age\\\" at position 7 have the same name. Consider aliasing (\\\"AS\\\") one of them.\")", + "Plan(\"Projections require unique expression names but the expression \\\"#person.age\\\" at position 3 and \\\"#person.age\\\" at position 7 have the same name. 
Consider aliasing (\\\"AS\\\") one of them.\")", format!("{:?}", err) ); } @@ -1562,7 +1577,7 @@ mod tests { fn select_wildcard_with_repeated_column_but_is_aliased() { quick_test( "SELECT *, first_name AS fn from person", - "Projection: #id, #first_name, #last_name, #age, #state, #salary, #birth_date, #first_name AS fn\ + "Projection: #person.id, #person.first_name, #person.last_name, #person.age, #person.state, #person.salary, #person.birth_date, #person.first_name AS fn\ \n TableScan: person projection=None", ); } @@ -1580,8 +1595,8 @@ mod tests { fn select_simple_filter() { let sql = "SELECT id, first_name, last_name \ FROM person WHERE state = 'CO'"; - let expected = "Projection: #id, #first_name, #last_name\ - \n Filter: #state Eq Utf8(\"CO\")\ + let expected = "Projection: #person.id, #person.first_name, #person.last_name\ + \n Filter: #person.state Eq Utf8(\"CO\")\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1591,10 +1606,7 @@ mod tests { let sql = "SELECT first_name FROM person WHERE doesnotexist = 'A'"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - format!( - "Plan(\"Invalid identifier \\\'doesnotexist\\\' for schema {}\")", - PERSON_COLUMN_NAMES - ), + "Plan(\"No field with unqualified name 'doesnotexist'\")", format!("{:?}", err) ); } @@ -1604,10 +1616,7 @@ mod tests { let sql = "SELECT first_name AS x FROM person WHERE x = 'A'"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - format!( - "Plan(\"Invalid identifier \\\'x\\\' for schema {}\")", - PERSON_COLUMN_NAMES - ), + "Plan(\"No field with unqualified name 'x'\")", format!("{:?}", err) ); } @@ -1616,8 +1625,8 @@ mod tests { fn select_neg_filter() { let sql = "SELECT id, first_name, last_name \ FROM person WHERE NOT state"; - let expected = "Projection: #id, #first_name, #last_name\ - \n Filter: NOT #state\ + let expected = "Projection: #person.id, #person.first_name, #person.last_name\ + \n Filter: NOT #person.state\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1626,8 +1635,8 @@ mod tests { fn select_compound_filter() { let sql = "SELECT id, first_name, last_name \ FROM person WHERE state = 'CO' AND age >= 21 AND age <= 65"; - let expected = "Projection: #id, #first_name, #last_name\ - \n Filter: #state Eq Utf8(\"CO\") And #age GtEq Int64(21) And #age LtEq Int64(65)\ + let expected = "Projection: #person.id, #person.first_name, #person.last_name\ + \n Filter: #person.state Eq Utf8(\"CO\") And #person.age GtEq Int64(21) And #person.age LtEq Int64(65)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1637,8 +1646,8 @@ mod tests { let sql = "SELECT state FROM person WHERE birth_date < CAST (158412331400600000 as timestamp)"; - let expected = "Projection: #state\ - \n Filter: #birth_date Lt CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\ + let expected = "Projection: #person.state\ + \n Filter: #person.birth_date Lt CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -1649,8 +1658,8 @@ mod tests { let sql = "SELECT state FROM person WHERE birth_date < CAST ('2020-01-01' as date)"; - let expected = "Projection: #state\ - \n Filter: #birth_date Lt CAST(Utf8(\"2020-01-01\") AS Date32)\ + let expected = "Projection: #person.state\ + \n Filter: #person.birth_date Lt CAST(Utf8(\"2020-01-01\") AS Date32)\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -1666,13 +1675,13 @@ 
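
The new error text in these tests comes straight from the schema lookup rather than from the planner formatting its own "Invalid identifier" message. Roughly (a sketch, assuming the `DFSchema` lookup surfaces the `Plan` error quoted in the expectations; `schema` here stands for the `person` table's DFSchema):

    let err = schema.field_with_unqualified_name("doesnotexist").unwrap_err();
    // -> Plan("No field with unqualified name 'doesnotexist'")
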
mod tests { AND age >= 21 \ AND age < 65 \ AND age <= 65"; - let expected = "Projection: #age, #first_name, #last_name\ - \n Filter: #age Eq Int64(21) \ - And #age NotEq Int64(21) \ - And #age Gt Int64(21) \ - And #age GtEq Int64(21) \ - And #age Lt Int64(65) \ - And #age LtEq Int64(65)\ + let expected = "Projection: #person.age, #person.first_name, #person.last_name\ + \n Filter: #person.age Eq Int64(21) \ + And #person.age NotEq Int64(21) \ + And #person.age Gt Int64(21) \ + And #person.age GtEq Int64(21) \ + And #person.age Lt Int64(65) \ + And #person.age LtEq Int64(65)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1680,8 +1689,8 @@ mod tests { #[test] fn select_between() { let sql = "SELECT state FROM person WHERE age BETWEEN 21 AND 65"; - let expected = "Projection: #state\ - \n Filter: #age BETWEEN Int64(21) AND Int64(65)\ + let expected = "Projection: #person.state\ + \n Filter: #person.age BETWEEN Int64(21) AND Int64(65)\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -1690,8 +1699,8 @@ mod tests { #[test] fn select_between_negated() { let sql = "SELECT state FROM person WHERE age NOT BETWEEN 21 AND 65"; - let expected = "Projection: #state\ - \n Filter: #age NOT BETWEEN Int64(21) AND Int64(65)\ + let expected = "Projection: #person.state\ + \n Filter: #person.age NOT BETWEEN Int64(21) AND Int64(65)\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -1707,9 +1716,9 @@ mod tests { FROM person ) )"; - let expected = "Projection: #fn2, #last_name\ - \n Projection: #fn1 AS fn2, #last_name, #birth_date\ - \n Projection: #first_name AS fn1, #last_name, #birth_date, #age\ + let expected = "Projection: #fn2, #person.last_name\ + \n Projection: #fn1 AS fn2, #person.last_name, #person.birth_date\ + \n Projection: #person.first_name AS fn1, #person.last_name, #person.birth_date, #person.age\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1724,9 +1733,9 @@ mod tests { ) WHERE fn1 = 'X' AND age < 30"; - let expected = "Filter: #fn1 Eq Utf8(\"X\") And #age Lt Int64(30)\ - \n Projection: #first_name AS fn1, #age\ - \n Filter: #age Gt Int64(20)\ + let expected = "Filter: #fn1 Eq Utf8(\"X\") And #person.age Lt Int64(30)\ + \n Projection: #person.first_name AS fn1, #person.age\ + \n Filter: #person.age Gt Int64(20)\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -1737,8 +1746,8 @@ mod tests { let sql = "SELECT id, age FROM person HAVING age > 100 AND age < 200"; - let expected = "Projection: #id, #age\ - \n Filter: #age Gt Int64(100) And #age Lt Int64(200)\ + let expected = "Projection: #person.id, #person.age\ + \n Filter: #person.age Gt Int64(100) And #person.age Lt Int64(200)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1784,8 +1793,8 @@ mod tests { let sql = "SELECT MAX(age) FROM person HAVING MAX(age) < 30"; - let expected = "Filter: #MAX(age) Lt Int64(30)\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(#age)]]\ + let expected = "Filter: #MAX(person.age) Lt Int64(30)\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(#person.age)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1795,9 +1804,9 @@ mod tests { let sql = "SELECT MAX(age) FROM person HAVING MAX(first_name) > 'M'"; - let expected = "Projection: #MAX(age)\ - \n Filter: #MAX(first_name) Gt Utf8(\"M\")\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(#age), MAX(#first_name)]]\ + let expected = "Projection: #MAX(person.age)\ + \n Filter: #MAX(person.first_name) Gt Utf8(\"M\")\ + \n Aggregate: 
groupBy=[[]], aggr=[[MAX(#person.age), MAX(#person.first_name)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1819,9 +1828,10 @@ mod tests { let sql = "SELECT MAX(age) as max_age FROM person HAVING max_age < 30"; - let expected = "Projection: #MAX(age) AS max_age\ - \n Filter: #MAX(age) Lt Int64(30)\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(#age)]]\ + // FIXME: add test for having in execution + let expected = "Projection: #MAX(person.age) AS max_age\ + \n Filter: #MAX(person.age) Lt Int64(30)\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(#person.age)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1831,9 +1841,9 @@ mod tests { let sql = "SELECT MAX(age) as max_age FROM person HAVING MAX(age) < 30"; - let expected = "Projection: #MAX(age) AS max_age\ - \n Filter: #MAX(age) Lt Int64(30)\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(#age)]]\ + let expected = "Projection: #MAX(person.age) AS max_age\ + \n Filter: #MAX(person.age) Lt Int64(30)\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(#person.age)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1844,8 +1854,8 @@ mod tests { FROM person GROUP BY first_name HAVING first_name = 'M'"; - let expected = "Filter: #first_name Eq Utf8(\"M\")\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ + let expected = "Filter: #person.first_name Eq Utf8(\"M\")\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1857,9 +1867,9 @@ mod tests { WHERE id > 5 GROUP BY first_name HAVING MAX(age) < 100"; - let expected = "Filter: #MAX(age) Lt Int64(100)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ - \n Filter: #id Gt Int64(5)\ + let expected = "Filter: #MAX(person.age) Lt Int64(100)\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\ + \n Filter: #person.id Gt Int64(5)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1872,9 +1882,9 @@ mod tests { WHERE id > 5 AND age > 18 GROUP BY first_name HAVING MAX(age) < 100"; - let expected = "Filter: #MAX(age) Lt Int64(100)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ - \n Filter: #id Gt Int64(5) And #age Gt Int64(18)\ + let expected = "Filter: #MAX(person.age) Lt Int64(100)\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\ + \n Filter: #person.id Gt Int64(5) And #person.age Gt Int64(18)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1885,9 +1895,9 @@ mod tests { FROM person GROUP BY first_name HAVING MAX(age) > 2 AND fn = 'M'"; - let expected = "Projection: #first_name AS fn, #MAX(age)\ - \n Filter: #MAX(age) Gt Int64(2) And #first_name Eq Utf8(\"M\")\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ + let expected = "Projection: #person.first_name AS fn, #MAX(person.age)\ + \n Filter: #MAX(person.age) Gt Int64(2) And #person.first_name Eq Utf8(\"M\")\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1899,9 +1909,9 @@ mod tests { FROM person GROUP BY first_name HAVING MAX(age) > 2 AND max_age < 5 AND first_name = 'M' AND fn = 'N'"; - let expected = "Projection: #first_name AS fn, #MAX(age) AS max_age\ - \n Filter: #MAX(age) Gt Int64(2) And #MAX(age) Lt Int64(5) And #first_name Eq Utf8(\"M\") And #first_name Eq Utf8(\"N\")\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ + let expected = "Projection: 
#person.first_name AS fn, #MAX(person.age) AS max_age\ + \n Filter: #MAX(person.age) Gt Int64(2) And #MAX(person.age) Lt Int64(5) And #person.first_name Eq Utf8(\"M\") And #person.first_name Eq Utf8(\"N\")\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1912,8 +1922,8 @@ mod tests { FROM person GROUP BY first_name HAVING MAX(age) > 100"; - let expected = "Filter: #MAX(age) Gt Int64(100)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ + let expected = "Filter: #MAX(person.age) Gt Int64(100)\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1937,8 +1947,8 @@ mod tests { FROM person GROUP BY first_name HAVING MAX(age) > 100 AND MAX(age) < 200"; - let expected = "Filter: #MAX(age) Gt Int64(100) And #MAX(age) Lt Int64(200)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ + let expected = "Filter: #MAX(person.age) Gt Int64(100) And #MAX(person.age) Lt Int64(200)\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1949,9 +1959,9 @@ mod tests { FROM person GROUP BY first_name HAVING MAX(age) > 100 AND MIN(id) < 50"; - let expected = "Projection: #first_name, #MAX(age)\ - \n Filter: #MAX(age) Gt Int64(100) And #MIN(id) Lt Int64(50)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age), MIN(#id)]]\ + let expected = "Projection: #person.first_name, #MAX(person.age)\ + \n Filter: #MAX(person.age) Gt Int64(100) And #MIN(person.id) Lt Int64(50)\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age), MIN(#person.id)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1963,9 +1973,9 @@ mod tests { FROM person GROUP BY first_name HAVING max_age > 100"; - let expected = "Projection: #first_name, #MAX(age) AS max_age\ - \n Filter: #MAX(age) Gt Int64(100)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ + let expected = "Projection: #person.first_name, #MAX(person.age) AS max_age\ + \n Filter: #MAX(person.age) Gt Int64(100)\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1978,9 +1988,9 @@ mod tests { GROUP BY first_name HAVING max_age_plus_one > 100"; let expected = - "Projection: #first_name, #MAX(age) Plus Int64(1) AS max_age_plus_one\ - \n Filter: #MAX(age) Plus Int64(1) Gt Int64(100)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ + "Projection: #person.first_name, #MAX(person.age) Plus Int64(1) AS max_age_plus_one\ + \n Filter: #MAX(person.age) Plus Int64(1) Gt Int64(100)\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1992,9 +2002,9 @@ mod tests { FROM person GROUP BY first_name HAVING MAX(age) > 100 AND MIN(id - 2) < 50"; - let expected = "Projection: #first_name, #MAX(age)\ - \n Filter: #MAX(age) Gt Int64(100) And #MIN(id Minus Int64(2)) Lt Int64(50)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age), MIN(#id Minus Int64(2))]]\ + let expected = "Projection: #person.first_name, #MAX(person.age)\ + \n Filter: #MAX(person.age) Gt Int64(100) And #MIN(person.id Minus Int64(2)) Lt Int64(50)\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age), MIN(#person.id Minus Int64(2))]]\ \n TableScan: person 
projection=None"; quick_test(sql, expected); } @@ -2005,9 +2015,9 @@ mod tests { FROM person GROUP BY first_name HAVING MAX(age) > 100 AND COUNT(*) < 50"; - let expected = "Projection: #first_name, #MAX(age)\ - \n Filter: #MAX(age) Gt Int64(100) And #COUNT(UInt8(1)) Lt Int64(50)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age), COUNT(UInt8(1))]]\ + let expected = "Projection: #person.first_name, #MAX(person.age)\ + \n Filter: #MAX(person.age) Gt Int64(100) And #COUNT(UInt8(1)) Lt Int64(50)\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age), COUNT(UInt8(1))]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -2015,7 +2025,7 @@ mod tests { #[test] fn select_binary_expr() { let sql = "SELECT age + salary from person"; - let expected = "Projection: #age Plus #salary\ + let expected = "Projection: #person.age Plus #person.salary\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -2023,7 +2033,7 @@ mod tests { #[test] fn select_binary_expr_nested() { let sql = "SELECT (age + salary)/2 from person"; - let expected = "Projection: #age Plus #salary Divide Int64(2)\ + let expected = "Projection: #person.age Plus #person.salary Divide Int64(2)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -2032,13 +2042,13 @@ mod tests { fn select_wildcard_with_groupby() { quick_test( "SELECT * FROM person GROUP BY id, first_name, last_name, age, state, salary, birth_date", - "Aggregate: groupBy=[[#id, #first_name, #last_name, #age, #state, #salary, #birth_date]], aggr=[[]]\ + "Aggregate: groupBy=[[#person.id, #person.first_name, #person.last_name, #person.age, #person.state, #person.salary, #person.birth_date]], aggr=[[]]\ \n TableScan: person projection=None", ); quick_test( "SELECT * FROM (SELECT first_name, last_name FROM person) GROUP BY first_name, last_name", - "Aggregate: groupBy=[[#first_name, #last_name]], aggr=[[]]\ - \n Projection: #first_name, #last_name\ + "Aggregate: groupBy=[[#person.first_name, #person.last_name]], aggr=[[]]\ + \n Projection: #person.first_name, #person.last_name\ \n TableScan: person projection=None", ); } @@ -2047,7 +2057,7 @@ mod tests { fn select_simple_aggregate() { quick_test( "SELECT MIN(age) FROM person", - "Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\ + "Aggregate: groupBy=[[]], aggr=[[MIN(#person.age)]]\ \n TableScan: person projection=None", ); } @@ -2056,7 +2066,7 @@ mod tests { fn test_sum_aggregate() { quick_test( "SELECT SUM(age) from person", - "Aggregate: groupBy=[[]], aggr=[[SUM(#age)]]\ + "Aggregate: groupBy=[[]], aggr=[[SUM(#person.age)]]\ \n TableScan: person projection=None", ); } @@ -2066,10 +2076,7 @@ mod tests { let sql = "SELECT MIN(doesnotexist) FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - format!( - "Plan(\"Invalid identifier \\\'doesnotexist\\\' for schema {}\")", - PERSON_COLUMN_NAMES - ), + "Plan(\"No field with unqualified name 'doesnotexist'\")", format!("{:?}", err) ); } @@ -2079,7 +2086,7 @@ mod tests { let sql = "SELECT MIN(age), MIN(age) FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"Projections require unique expression names but the expression \\\"#MIN(age)\\\" at position 0 and \\\"#MIN(age)\\\" at position 1 have the same name. 
Consider aliasing (\\\"AS\\\") one of them.\")", + "Plan(\"Projections require unique expression names but the expression \\\"#MIN(person.age)\\\" at position 0 and \\\"#MIN(person.age)\\\" at position 1 have the same name. Consider aliasing (\\\"AS\\\") one of them.\")", format!("{:?}", err) ); } @@ -2088,8 +2095,8 @@ mod tests { fn select_simple_aggregate_repeated_aggregate_with_single_alias() { quick_test( "SELECT MIN(age), MIN(age) AS a FROM person", - "Projection: #MIN(age), #MIN(age) AS a\ - \n Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\ + "Projection: #MIN(person.age), #MIN(person.age) AS a\ + \n Aggregate: groupBy=[[]], aggr=[[MIN(#person.age)]]\ \n TableScan: person projection=None", ); } @@ -2098,8 +2105,8 @@ mod tests { fn select_simple_aggregate_repeated_aggregate_with_unique_aliases() { quick_test( "SELECT MIN(age) AS a, MIN(age) AS b FROM person", - "Projection: #MIN(age) AS a, #MIN(age) AS b\ - \n Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\ + "Projection: #MIN(person.age) AS a, #MIN(person.age) AS b\ + \n Aggregate: groupBy=[[]], aggr=[[MIN(#person.age)]]\ \n TableScan: person projection=None", ); } @@ -2109,7 +2116,7 @@ mod tests { let sql = "SELECT MIN(age) AS a, MIN(age) AS a FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"Projections require unique expression names but the expression \\\"#MIN(age) AS a\\\" at position 0 and \\\"#MIN(age) AS a\\\" at position 1 have the same name. Consider aliasing (\\\"AS\\\") one of them.\")", + "Plan(\"Projections require unique expression names but the expression \\\"#MIN(person.age) AS a\\\" at position 0 and \\\"#MIN(person.age) AS a\\\" at position 1 have the same name. Consider aliasing (\\\"AS\\\") one of them.\")", format!("{:?}", err) ); } @@ -2118,7 +2125,7 @@ mod tests { fn select_simple_aggregate_with_groupby() { quick_test( "SELECT state, MIN(age), MAX(age) FROM person GROUP BY state", - "Aggregate: groupBy=[[#state]], aggr=[[MIN(#age), MAX(#age)]]\ + "Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age), MAX(#person.age)]]\ \n TableScan: person projection=None", ); } @@ -2127,8 +2134,8 @@ mod tests { fn select_simple_aggregate_with_groupby_with_aliases() { quick_test( "SELECT state AS a, MIN(age) AS b FROM person GROUP BY state", - "Projection: #state AS a, #MIN(age) AS b\ - \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age)]]\ + "Projection: #person.state AS a, #MIN(person.age) AS b\ + \n Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age)]]\ \n TableScan: person projection=None", ); } @@ -2138,7 +2145,7 @@ mod tests { let sql = "SELECT state AS a, MIN(age) AS a FROM person GROUP BY state"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"Projections require unique expression names but the expression \\\"#state AS a\\\" at position 0 and \\\"#MIN(age) AS a\\\" at position 1 have the same name. Consider aliasing (\\\"AS\\\") one of them.\")", + "Plan(\"Projections require unique expression names but the expression \\\"#person.state AS a\\\" at position 0 and \\\"#MIN(person.age) AS a\\\" at position 1 have the same name. 
Consider aliasing (\\\"AS\\\") one of them.\")", format!("{:?}", err) ); } @@ -2147,8 +2154,8 @@ mod tests { fn select_simple_aggregate_with_groupby_column_unselected() { quick_test( "SELECT MIN(age), MAX(age) FROM person GROUP BY state", - "Projection: #MIN(age), #MAX(age)\ - \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age), MAX(#age)]]\ + "Projection: #MIN(person.age), #MAX(person.age)\ + \n Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age), MAX(#person.age)]]\ \n TableScan: person projection=None", ); } @@ -2158,10 +2165,7 @@ mod tests { let sql = "SELECT SUM(age) FROM person GROUP BY doesnotexist"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - format!( - "Plan(\"Invalid identifier \\\'doesnotexist\\\' for schema {}\")", - PERSON_COLUMN_NAMES - ), + "Plan(\"No field with unqualified name 'doesnotexist'\")", format!("{:?}", err) ); } @@ -2171,10 +2175,7 @@ mod tests { let sql = "SELECT SUM(doesnotexist) FROM person GROUP BY first_name"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - format!( - "Plan(\"Invalid identifier \\\'doesnotexist\\\' for schema {}\")", - PERSON_COLUMN_NAMES - ), + "Plan(\"No field with unqualified name 'doesnotexist'\")", format!("{:?}", err) ); } @@ -2194,7 +2195,7 @@ mod tests { let sql = "SELECT INTERVAL '1 year 1 day'"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "NotImplemented(\"DF does not support intervals that have both a Year/Month part as well as Days/Hours/Mins/Seconds: \\\"1 year 1 day\\\". Hint: try breaking the interval into two parts, one with Year/Month and the other with Days/Hours/Mins/Seconds - e.g. (NOW() + INTERVAL \\\'1 year\\\') + INTERVAL \\\'1 day\\\'\")", + "NotImplemented(\"DF does not support intervals that have both a Year/Month part as well as Days/Hours/Mins/Seconds: \\\"1 year 1 day\\\". Hint: try breaking the interval into two parts, one with Year/Month and the other with Days/Hours/Mins/Seconds - e.g. (NOW() + INTERVAL '1 year') + INTERVAL '1 day'\")", format!("{:?}", err) ); } @@ -2203,8 +2204,8 @@ mod tests { fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() { quick_test( "SELECT MAX(first_name) FROM person GROUP BY first_name", - "Projection: #MAX(first_name)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#first_name)]]\ + "Projection: #MAX(person.first_name)\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.first_name)]]\ \n TableScan: person projection=None", ); } @@ -2214,10 +2215,7 @@ mod tests { let sql = "SELECT state AS x, MAX(age) FROM person GROUP BY x"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - format!( - "Plan(\"Invalid identifier \\\'x\\\' for schema {}\")", - PERSON_COLUMN_NAMES - ), + "Plan(\"No field with unqualified name 'x'\")", format!("{:?}", err) ); } @@ -2227,7 +2225,7 @@ mod tests { let sql = "SELECT state, MIN(age), MIN(age) FROM person GROUP BY state"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"Projections require unique expression names but the expression \\\"#MIN(age)\\\" at position 1 and \\\"#MIN(age)\\\" at position 2 have the same name. Consider aliasing (\\\"AS\\\") one of them.\")", + "Plan(\"Projections require unique expression names but the expression \\\"#MIN(person.age)\\\" at position 1 and \\\"#MIN(person.age)\\\" at position 2 have the same name. 
Consider aliasing (\\\"AS\\\") one of them.\")", format!("{:?}", err) ); } @@ -2236,8 +2234,8 @@ mod tests { fn select_simple_aggregate_with_groupby_aggregate_repeated_and_one_has_alias() { quick_test( "SELECT state, MIN(age), MIN(age) AS ma FROM person GROUP BY state", - "Projection: #state, #MIN(age), #MIN(age) AS ma\ - \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age)]]\ + "Projection: #person.state, #MIN(person.age), #MIN(person.age) AS ma\ + \n Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age)]]\ \n TableScan: person projection=None", ) } @@ -2245,8 +2243,8 @@ mod tests { fn select_simple_aggregate_with_groupby_non_column_expression_unselected() { quick_test( "SELECT MIN(first_name) FROM person GROUP BY age + 1", - "Projection: #MIN(first_name)\ - \n Aggregate: groupBy=[[#age Plus Int64(1)]], aggr=[[MIN(#first_name)]]\ + "Projection: #MIN(person.first_name)\ + \n Aggregate: groupBy=[[#person.age Plus Int64(1)]], aggr=[[MIN(#person.first_name)]]\ \n TableScan: person projection=None", ); } @@ -2256,13 +2254,13 @@ mod tests { ) { quick_test( "SELECT age + 1, MIN(first_name) FROM person GROUP BY age + 1", - "Aggregate: groupBy=[[#age Plus Int64(1)]], aggr=[[MIN(#first_name)]]\ + "Aggregate: groupBy=[[#person.age Plus Int64(1)]], aggr=[[MIN(#person.first_name)]]\ \n TableScan: person projection=None", ); quick_test( "SELECT MIN(first_name), age + 1 FROM person GROUP BY age + 1", - "Projection: #MIN(first_name), #age Plus Int64(1)\ - \n Aggregate: groupBy=[[#age Plus Int64(1)]], aggr=[[MIN(#first_name)]]\ + "Projection: #MIN(person.first_name), #person.age Plus Int64(1)\ + \n Aggregate: groupBy=[[#person.age Plus Int64(1)]], aggr=[[MIN(#person.first_name)]]\ \n TableScan: person projection=None", ); } @@ -2272,8 +2270,8 @@ mod tests { { quick_test( "SELECT ((age + 1) / 2) * (age + 1), MIN(first_name) FROM person GROUP BY age + 1", - "Projection: #age Plus Int64(1) Divide Int64(2) Multiply #age Plus Int64(1), #MIN(first_name)\ - \n Aggregate: groupBy=[[#age Plus Int64(1)]], aggr=[[MIN(#first_name)]]\ + "Projection: #person.age Plus Int64(1) Divide Int64(2) Multiply #person.age Plus Int64(1), #MIN(person.first_name)\ + \n Aggregate: groupBy=[[#person.age Plus Int64(1)]], aggr=[[MIN(#person.first_name)]]\ \n TableScan: person projection=None", ); } @@ -2306,8 +2304,8 @@ mod tests { fn select_simple_aggregate_nested_in_binary_expr_with_groupby() { quick_test( "SELECT state, MIN(age) < 10 FROM person GROUP BY state", - "Projection: #state, #MIN(age) Lt Int64(10)\ - \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age)]]\ + "Projection: #person.state, #MIN(person.age) Lt Int64(10)\ + \n Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age)]]\ \n TableScan: person projection=None", ); } @@ -2316,8 +2314,8 @@ mod tests { fn select_simple_aggregate_and_nested_groupby_column() { quick_test( "SELECT age + 1, MAX(first_name) FROM person GROUP BY age", - "Projection: #age Plus Int64(1), #MAX(first_name)\ - \n Aggregate: groupBy=[[#age]], aggr=[[MAX(#first_name)]]\ + "Projection: #person.age Plus Int64(1), #MAX(person.first_name)\ + \n Aggregate: groupBy=[[#person.age]], aggr=[[MAX(#person.first_name)]]\ \n TableScan: person projection=None", ); } @@ -2326,8 +2324,8 @@ mod tests { fn select_aggregate_compounded_with_groupby_column() { quick_test( "SELECT age + MIN(salary) FROM person GROUP BY age", - "Projection: #age Plus #MIN(salary)\ - \n Aggregate: groupBy=[[#age]], aggr=[[MIN(#salary)]]\ + "Projection: #person.age Plus #MIN(person.salary)\ + \n Aggregate: groupBy=[[#person.age]], 
aggr=[[MIN(#person.salary)]]\ \n TableScan: person projection=None", ); } @@ -2336,7 +2334,7 @@ mod tests { fn select_aggregate_with_non_column_inner_expression_with_groupby() { quick_test( "SELECT state, MIN(age + 1) FROM person GROUP BY state", - "Aggregate: groupBy=[[#state]], aggr=[[MIN(#age Plus Int64(1))]]\ + "Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age Plus Int64(1))]]\ \n TableScan: person projection=None", ); } @@ -2345,7 +2343,7 @@ mod tests { fn test_wildcard() { quick_test( "SELECT * from person", - "Projection: #id, #first_name, #last_name, #age, #state, #salary, #birth_date\ + "Projection: #person.id, #person.first_name, #person.last_name, #person.age, #person.state, #person.salary, #person.birth_date\ \n TableScan: person projection=None", ); } @@ -2361,7 +2359,7 @@ mod tests { #[test] fn select_count_column() { let sql = "SELECT COUNT(id) FROM person"; - let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(#id)]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(#person.id)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -2369,15 +2367,15 @@ mod tests { #[test] fn select_scalar_func() { let sql = "SELECT sqrt(age) FROM person"; - let expected = "Projection: sqrt(#age)\ + let expected = "Projection: sqrt(#person.age)\ \n TableScan: person projection=None"; quick_test(sql, expected); } #[test] fn select_aliased_scalar_func() { - let sql = "SELECT sqrt(age) AS square_people FROM person"; - let expected = "Projection: sqrt(#age) AS square_people\ + let sql = "SELECT sqrt(person.age) AS square_people FROM person"; + let expected = "Projection: sqrt(#person.age) AS square_people\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -2386,8 +2384,8 @@ mod tests { fn select_where_nullif_division() { let sql = "SELECT c3/(c4+c5) \ FROM aggregate_test_100 WHERE c3/nullif(c4+c5, 0) > 0.1"; - let expected = "Projection: #c3 Divide #c4 Plus #c5\ - \n Filter: #c3 Divide nullif(#c4 Plus #c5, Int64(0)) Gt Float64(0.1)\ + let expected = "Projection: #aggregate_test_100.c3 Divide #aggregate_test_100.c4 Plus #aggregate_test_100.c5\ + \n Filter: #aggregate_test_100.c3 Divide nullif(#aggregate_test_100.c4 Plus #aggregate_test_100.c5, Int64(0)) Gt Float64(0.1)\ \n TableScan: aggregate_test_100 projection=None"; quick_test(sql, expected); } @@ -2395,8 +2393,8 @@ mod tests { #[test] fn select_where_with_negative_operator() { let sql = "SELECT c3 FROM aggregate_test_100 WHERE c3 > -0.1 AND -c4 > 0"; - let expected = "Projection: #c3\ - \n Filter: #c3 Gt Float64(-0.1) And (- #c4) Gt Int64(0)\ + let expected = "Projection: #aggregate_test_100.c3\ + \n Filter: #aggregate_test_100.c3 Gt Float64(-0.1) And (- #aggregate_test_100.c4) Gt Int64(0)\ \n TableScan: aggregate_test_100 projection=None"; quick_test(sql, expected); } @@ -2404,8 +2402,8 @@ mod tests { #[test] fn select_where_with_positive_operator() { let sql = "SELECT c3 FROM aggregate_test_100 WHERE c3 > +0.1 AND +c4 > 0"; - let expected = "Projection: #c3\ - \n Filter: #c3 Gt Float64(0.1) And #c4 Gt Int64(0)\ + let expected = "Projection: #aggregate_test_100.c3\ + \n Filter: #aggregate_test_100.c3 Gt Float64(0.1) And #aggregate_test_100.c4 Gt Int64(0)\ \n TableScan: aggregate_test_100 projection=None"; quick_test(sql, expected); } @@ -2413,8 +2411,8 @@ mod tests { #[test] fn select_order_by() { let sql = "SELECT id FROM person ORDER BY id"; - let expected = "Sort: #id ASC NULLS FIRST\ - \n Projection: #id\ + let expected = "Sort: #person.id ASC NULLS FIRST\ + \n Projection: 
#person.id\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -2422,8 +2420,8 @@ mod tests { #[test] fn select_order_by_desc() { let sql = "SELECT id FROM person ORDER BY id DESC"; - let expected = "Sort: #id DESC NULLS FIRST\ - \n Projection: #id\ + let expected = "Sort: #person.id DESC NULLS FIRST\ + \n Projection: #person.id\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -2432,15 +2430,15 @@ mod tests { fn select_order_by_nulls_last() { quick_test( "SELECT id FROM person ORDER BY id DESC NULLS LAST", - "Sort: #id DESC NULLS LAST\ - \n Projection: #id\ + "Sort: #person.id DESC NULLS LAST\ + \n Projection: #person.id\ \n TableScan: person projection=None", ); quick_test( "SELECT id FROM person ORDER BY id NULLS LAST", - "Sort: #id ASC NULLS LAST\ - \n Projection: #id\ + "Sort: #person.id ASC NULLS LAST\ + \n Projection: #person.id\ \n TableScan: person projection=None", ); } @@ -2448,7 +2446,7 @@ mod tests { #[test] fn select_group_by() { let sql = "SELECT state FROM person GROUP BY state"; - let expected = "Aggregate: groupBy=[[#state]], aggr=[[]]\ + let expected = "Aggregate: groupBy=[[#person.state]], aggr=[[]]\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -2457,8 +2455,8 @@ mod tests { #[test] fn select_group_by_columns_not_in_select() { let sql = "SELECT MAX(age) FROM person GROUP BY state"; - let expected = "Projection: #MAX(age)\ - \n Aggregate: groupBy=[[#state]], aggr=[[MAX(#age)]]\ + let expected = "Projection: #MAX(person.age)\ + \n Aggregate: groupBy=[[#person.state]], aggr=[[MAX(#person.age)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -2467,7 +2465,7 @@ mod tests { #[test] fn select_group_by_count_star() { let sql = "SELECT state, COUNT(*) FROM person GROUP BY state"; - let expected = "Aggregate: groupBy=[[#state]], aggr=[[COUNT(UInt8(1))]]\ + let expected = "Aggregate: groupBy=[[#person.state]], aggr=[[COUNT(UInt8(1))]]\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -2477,8 +2475,8 @@ mod tests { fn select_group_by_needs_projection() { let sql = "SELECT COUNT(state), state FROM person GROUP BY state"; let expected = "\ - Projection: #COUNT(state), #state\ - \n Aggregate: groupBy=[[#state]], aggr=[[COUNT(#state)]]\ + Projection: #COUNT(person.state), #person.state\ + \n Aggregate: groupBy=[[#person.state]], aggr=[[COUNT(#person.state)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -2487,8 +2485,8 @@ mod tests { #[test] fn select_7480_1() { let sql = "SELECT c1, MIN(c12) FROM aggregate_test_100 GROUP BY c1, c13"; - let expected = "Projection: #c1, #MIN(c12)\ - \n Aggregate: groupBy=[[#c1, #c13]], aggr=[[MIN(#c12)]]\ + let expected = "Projection: #aggregate_test_100.c1, #MIN(aggregate_test_100.c12)\ + \n Aggregate: groupBy=[[#aggregate_test_100.c1, #aggregate_test_100.c13]], aggr=[[MIN(#aggregate_test_100.c12)]]\ \n TableScan: aggregate_test_100 projection=None"; quick_test(sql, expected); } @@ -2544,22 +2542,49 @@ mod tests { FROM person \ JOIN orders \ ON id = customer_id"; - let expected = "Projection: #id, #order_id\ - \n Join: id = customer_id\ + let expected = "Projection: #person.id, #orders.order_id\ + \n Join: #person.id = #orders.customer_id\ \n TableScan: person projection=None\ \n TableScan: orders projection=None"; quick_test(sql, expected); } + #[test] + fn join_with_table_name() { + let sql = "SELECT id, order_id \ + FROM person \ + JOIN orders \ + ON person.id = orders.customer_id"; + let expected = "Projection: 
#person.id, #orders.order_id\ + \n Join: #person.id = #orders.customer_id\ + \n TableScan: person projection=None\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + #[test] + fn join_with_using() { + let sql = "SELECT person.first_name, id \ + FROM person \ + JOIN person as person2 \ + USING (id)"; + let expected = "Projection: #person.first_name, #person.id\ + \n Join: #person.id = #person2.id\ + \n TableScan: person projection=None\ + \n TableScan: person2 projection=None"; + quick_test(sql, expected); + } + #[test] fn equijoin_explicit_syntax_3_tables() { let sql = "SELECT id, order_id, l_description \ FROM person \ JOIN orders ON id = customer_id \ JOIN lineitem ON o_item_id = l_item_id"; - let expected = "Projection: #id, #order_id, #l_description\ - \n Join: o_item_id = l_item_id\ - \n Join: id = customer_id\ + let expected = + "Projection: #person.id, #orders.order_id, #lineitem.l_description\ + \n Join: #orders.o_item_id = #lineitem.l_item_id\ + \n Join: #person.id = #orders.customer_id\ \n TableScan: person projection=None\ \n TableScan: orders projection=None\ \n TableScan: lineitem projection=None"; @@ -2571,8 +2596,8 @@ mod tests { let sql = "SELECT order_id \ FROM orders \ WHERE delivered = false OR delivered = true"; - let expected = "Projection: #order_id\ - \n Filter: #delivered Eq Boolean(false) Or #delivered Eq Boolean(true)\ + let expected = "Projection: #orders.order_id\ + \n Filter: #orders.delivered Eq Boolean(false) Or #orders.delivered Eq Boolean(true)\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2581,9 +2606,9 @@ mod tests { fn union() { let sql = "SELECT order_id from orders UNION ALL SELECT order_id FROM orders"; let expected = "Union\ - \n Projection: #order_id\ + \n Projection: #orders.order_id\ \n TableScan: orders projection=None\ - \n Projection: #order_id\ + \n Projection: #orders.order_id\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2595,13 +2620,13 @@ mod tests { UNION ALL SELECT order_id FROM orders UNION ALL SELECT order_id FROM orders"; let expected = "Union\ - \n Projection: #order_id\ + \n Projection: #orders.order_id\ \n TableScan: orders projection=None\ - \n Projection: #order_id\ + \n Projection: #orders.order_id\ \n TableScan: orders projection=None\ - \n Projection: #order_id\ + \n Projection: #orders.order_id\ \n TableScan: orders projection=None\ - \n Projection: #order_id\ + \n Projection: #orders.order_id\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2644,7 +2669,7 @@ mod tests { /// Create logical plan, write with formatter, compare to expected output fn quick_test(sql: &str, expected: &str) { let plan = logical_plan(sql).unwrap(); - assert_eq!(expected, format!("{:?}", plan)); + assert_eq!(format!("{:?}", plan), expected); } struct MockContextProvider {} @@ -2679,6 +2704,7 @@ mod tests { "lineitem" => Some(Schema::new(vec![ Field::new("l_item_id", DataType::UInt32, false), Field::new("l_description", DataType::Utf8, false), + Field::new("price", DataType::Float64, false), ])), "aggregate_test_100" => Some(Schema::new(vec![ Field::new("c1", DataType::Utf8, false), diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index f41643d2ab44..b6f245e0fb57 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -18,7 +18,7 @@ use crate::logical_plan::{DFSchema, Expr, LogicalPlan}; use crate::{ error::{DataFusionError, Result}, - logical_plan::{ExpressionVisitor, Recursion}, + logical_plan::{Column, 
ExpressionVisitor, Recursion},
 };
 use std::collections::HashMap;
 
@@ -28,7 +28,7 @@ pub(crate) fn expand_wildcard(expr: &Expr, schema: &DFSchema) -> Vec<Expr> {
         Expr::Wildcard => schema
             .fields()
             .iter()
-            .map(|f| Expr::Column(f.name().to_string()))
+            .map(|f| Expr::Column(f.qualified_column()))
             .collect::<Vec<Expr>>(),
         _ => vec![expr.clone()],
     }
@@ -127,7 +127,7 @@ where
 pub(crate) fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
     match expr {
         Expr::Column(_) => Ok(expr.clone()),
-        _ => Ok(Expr::Column(expr.name(&plan.schema())?)),
+        _ => Ok(Expr::Column(Column::from_name(expr.name(&plan.schema())?))),
     }
 }
@@ -335,7 +335,7 @@ where
             asc: *asc,
             nulls_first: *nulls_first,
         }),
-        Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(_) => {
+        Expr::Column { .. } | Expr::Literal(_) | Expr::ScalarVariable(_) => {
             Ok(expr.clone())
         }
         Expr::Wildcard => Ok(Expr::Wildcard),
@@ -364,8 +364,8 @@ pub(crate) fn resolve_aliases_to_exprs(
     aliases: &HashMap<String, Expr>,
 ) -> Result<Expr> {
     clone_with_replacement(expr, &|nested_expr| match nested_expr {
-        Expr::Column(name) => {
-            if let Some(aliased_expr) = aliases.get(name) {
+        Expr::Column(c) if c.relation.is_none() => {
+            if let Some(aliased_expr) = aliases.get(&c.name) {
                 Ok(Some(aliased_expr.clone()))
             } else {
                 Ok(None)
diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs
index 926a69226169..8b7dc9e20970 100644
--- a/datafusion/src/test/mod.rs
+++ b/datafusion/src/test/mod.rs
@@ -117,7 +117,7 @@ pub fn test_table_scan() -> Result<LogicalPlan> {
         Field::new("b", DataType::UInt32, false),
         Field::new("c", DataType::UInt32, false),
     ]);
-    LogicalPlanBuilder::scan_empty("test", &schema, None)?.build()
+    LogicalPlanBuilder::scan_empty(Some("test"), &schema, None)?.build()
 }
 
 pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) {
diff --git a/datafusion/tests/dataframe.rs b/datafusion/tests/dataframe.rs
index b93e21f4abab..c56ac5a29ff9 100644
--- a/datafusion/tests/dataframe.rs
+++ b/datafusion/tests/dataframe.rs
@@ -24,6 +24,7 @@ use arrow::{
 };
 
 use datafusion::error::Result;
+use datafusion::logical_plan::Column;
 use datafusion::{datasource::MemTable, prelude::JoinType};
 
 use datafusion::execution::context::ExecutionContext;
@@ -69,7 +70,12 @@ async fn join() -> Result<()> {
 
     let df2 = ctx.table("aaa")?;
 
-    let a = df1.join(df2, JoinType::Inner, &["a"], &["a"])?;
+    let a = df1.join(
+        df2,
+        JoinType::Inner,
+        vec![Column::from_name("a".to_string())],
+        vec![Column::from_name("a".to_string())],
+    )?;
 
     let batches = a.collect().await?;
 
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index f4d4e65f3a4e..8185ba97f2c2 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -1391,12 +1391,12 @@ async fn csv_explain() {
     register_aggregate_csv_by_sql(&mut ctx).await;
     let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10";
     let actual = execute(&mut ctx, sql).await;
-    let expected = vec![
-        vec![
-            "logical_plan",
-            "Projection: #c1\n  Filter: #c2 Gt Int64(10)\n    TableScan: aggregate_test_100 projection=None"
-        ]
-    ];
+    let expected = vec![vec![
+        "logical_plan",
+        "Projection: #aggregate_test_100.c1\
+        \n  Filter: #aggregate_test_100.c2 Gt Int64(10)\
+        \n    TableScan: aggregate_test_100 projection=None",
+    ]];
     assert_eq!(expected, actual);
 
     // Also, expect same result with lowercase explain
@@ -1420,7 +1420,11 @@ async fn csv_explain_verbose() {
     // pain). Instead just check for a few key pieces.
     assert!(actual.contains("logical_plan"), "Actual: '{}'", actual);
     assert!(actual.contains("physical_plan"), "Actual: '{}'", actual);
-    assert!(actual.contains("#c2 Gt Int64(10)"), "Actual: '{}'", actual);
+    assert!(
+        actual.contains("#aggregate_test_100.c2 Gt Int64(10)"),
+        "Actual: '{}'",
+        actual
+    );
 }
 
 fn aggr_test_schema() -> SchemaRef {
diff --git a/datafusion/tests/user_defined_plan.rs b/datafusion/tests/user_defined_plan.rs
index f9f24430104c..094fefa5a599 100644
--- a/datafusion/tests/user_defined_plan.rs
+++ b/datafusion/tests/user_defined_plan.rs
@@ -161,11 +161,9 @@ async fn topk_query() -> Result<()> {
 async fn topk_plan() -> Result<()> {
     let mut ctx = setup_table(make_topk_context()).await?;
 
-    let expected = vec![
-        "| logical_plan after topk | TopK: k=3 |",
-        "| | Projection: #customer_id, #revenue |",
-        "| | TableScan: sales projection=Some([0, 1]) |",
-    ].join("\n");
+    let expected = "| logical_plan after topk | TopK: k=3 |\
+    \n| | Projection: #sales.customer_id, #sales.revenue |\
+    \n| | TableScan: sales projection=Some([0, 1]) |";
 
     let explain_query = format!("EXPLAIN VERBOSE {}", QUERY);
     let actual_output = exec_sql(&mut ctx, &explain_query).await?;
@@ -173,7 +171,18 @@ async fn topk_plan() -> Result<()> {
     // normalize newlines (output on windows uses \r\n)
     let actual_output = actual_output.replace("\r\n", "\n");
-    assert!(actual_output.contains(&expected) , "Expected output not present in actual output\nExpected:\n---------\n{}\nActual:\n--------\n{}", expected, actual_output);
+    assert!(
+        actual_output.contains(&expected),
+        "Expected output not present in actual output\
+        \nExpected:\
+        \n---------\
+        \n{}\
+        \nActual:\
+        \n--------\
+        \n{}",
+        expected,
+        actual_output
+    );
     Ok(())
 }
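[Reviewer note on patch 01: expressions now carry a structured column reference instead of a bare string, which is why every expected plan string above gained a `table.` qualifier. A minimal, self-contained sketch of the flat-name round-trip this enables, assuming a simplified dot-split parser; the real `Column::from_flat_name`, visible in patch 03 below, tokenizes the name with sqlparser so quoted identifiers are handled correctly:

    #[derive(Debug, PartialEq)]
    struct Column {
        relation: Option<String>,
        name: String,
    }

    impl Column {
        // "person.id" -> qualified, "id" -> unqualified; illustrative only,
        // a plain split cannot cope with quoted identifiers like "a.b"
        fn from_flat_name(flat_name: &str) -> Self {
            match flat_name.split_once('.') {
                Some((relation, name)) => Column {
                    relation: Some(relation.to_string()),
                    name: name.to_string(),
                },
                None => Column {
                    relation: None,
                    name: flat_name.to_string(),
                },
            }
        }

        fn flat_name(&self) -> String {
            match &self.relation {
                Some(r) => format!("{}.{}", r, self.name),
                None => self.name.clone(),
            }
        }
    }

    fn main() {
        assert_eq!(Column::from_flat_name("person.id").flat_name(), "person.id");
        assert_eq!(Column::from_flat_name("id").relation, None);
    }

The real type additionally has a `normalize` step that resolves an unqualified name against the plan's schemas, as the builder hunks in the following patches show.]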
From 696f8e04178fef877a658d0ef08c3c4ceebfcfdf Mon Sep 17 00:00:00 2001
From: Qingping Hou
Date: Sun, 25 Apr 2021 13:17:08 -0700
Subject: [PATCH 02/25] handle coalesced hash join partition in HashJoinStream

---
 .../physical_optimizer/coalesce_batches.rs | 14 +------
 datafusion/src/physical_plan/hash_join.rs  | 39 ++++++++++++++++++-
 2 files changed, 40 insertions(+), 13 deletions(-)

diff --git a/datafusion/src/physical_optimizer/coalesce_batches.rs b/datafusion/src/physical_optimizer/coalesce_batches.rs
index a25b6e2e0ba7..9adee1ce2f2c 100644
--- a/datafusion/src/physical_optimizer/coalesce_batches.rs
+++ b/datafusion/src/physical_optimizer/coalesce_batches.rs
@@ -23,7 +23,7 @@ use crate::{
     error::Result,
     physical_plan::{
         coalesce_batches::CoalesceBatchesExec, filter::FilterExec,
-        hash_join::HashJoinExec, repartition::RepartitionExec, Partitioning,
+        hash_join::HashJoinExec, repartition::RepartitionExec,
     },
 };
 use std::sync::Arc;
@@ -59,17 +59,7 @@ impl PhysicalOptimizerRule for CoalesceBatches {
         // See https://issues.apache.org/jira/browse/ARROW-11068
         let wrap_in_coalesce = plan_any.downcast_ref::<FilterExec>().is_some()
             || plan_any.downcast_ref::<HashJoinExec>().is_some()
-            || {
-                match plan_any.downcast_ref::<RepartitionExec>() {
-                    Some(p) => match p.partitioning() {
-                        // do not coalesce hash partitions since other plans like partitioned hash
-                        // join depends on it empty batches for outter joins
-                        Partitioning::Hash(_, _) => false,
-                        _ => true,
-                    },
-                    None => false,
-                }
-            };
+            || plan_any.downcast_ref::<RepartitionExec>().is_some();
 
         //TODO we should also do this for HashAggregateExec but we need to update tests
         // as part of this work - see https://issues.apache.org/jira/browse/ARROW-11068
diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index 053a00f4e6f2..4318a80e0b3d 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -432,6 +432,9 @@ struct HashJoinStream {
     join_time: usize,
     /// Random state used for hashing initialization
     random_state: RandomState,
+    /// Whether we should create rows for left side data when right side partition is coalesced to
+    /// None
+    preserve_left: bool,
 }
 
 impl HashJoinStream {
@@ -445,6 +448,10 @@ impl HashJoinStream {
         column_indices: Vec<ColumnIndex>,
         random_state: RandomState,
     ) -> Self {
+        let preserve_left = match join_type {
+            JoinType::Left => left_data.0.len() > 0,
+            JoinType::Inner | JoinType::Right => false,
+        };
         HashJoinStream {
             schema,
             on_left,
@@ -459,6 +466,7 @@ impl HashJoinStream {
             num_output_rows: 0,
             join_time: 0,
             random_state,
+            preserve_left,
         }
     }
 }
@@ -871,7 +879,6 @@ impl Stream for HashJoinStream {
             .map(|maybe_batch| match maybe_batch {
                 Some(Ok(batch)) => {
                     let start = Instant::now();
-
                     let result = build_batch(
                         &batch,
                         &self.left_data,
@@ -891,6 +898,36 @@ impl Stream for HashJoinStream {
                     }
                     Some(result)
                 }
+                // If maybe_batch is None and num_output_rows is 0, that means right side batch was
+                // empty and has been coalesced to None. Fill right side with Null if preserve_left
+                // is true.
+                None if self.preserve_left && self.num_output_rows == 0 => {
+                    let start = Instant::now();
+                    let num_rows = self.left_data.1.num_rows();
+                    let mut columns: Vec<Arc<dyn Array>> =
+                        Vec::with_capacity(self.schema.fields().len());
+                    for (idx, column_index) in self.column_indices.iter().enumerate() {
+                        let array = if column_index.is_left {
+                            let array = self.left_data.1.column(column_index.index);
+                            array.clone()
+                        } else {
+                            let datatype = self.schema.field(idx).data_type();
+                            arrow::array::new_null_array(datatype, num_rows)
+                        };
+
+                        columns.push(array);
+                    }
+                    let result = RecordBatch::try_new(self.schema.clone(), columns);
+
+                    self.num_input_batches += 1;
+                    self.num_input_rows += num_rows;
+                    if let Ok(ref batch) = result {
+                        self.join_time += start.elapsed().as_millis() as usize;
+                        self.num_output_batches += 1;
+                        self.num_output_rows += batch.num_rows();
+                    }
+                    Some(result)
+                }
                 other => {
                     // End of right batch, print stats in debug mode
                     debug!(
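[Note on the hunk above: for a LEFT join whose probe side produced no batches at all (CoalesceBatchesExec can collapse an empty hash partition to `None`, which is also why the partition-type exception in coalesce_batches.rs could be dropped), the stream now emits every build-side row once, padding probe-side columns with nulls. A self-contained sketch of that padding step, assuming left columns precede right columns in the output schema and the right-side fields are nullable; the real stream uses its `column_indices` mapping instead:

    use std::sync::Arc;
    use arrow::array::{new_null_array, ArrayRef};
    use arrow::datatypes::SchemaRef;
    use arrow::error::ArrowError;
    use arrow::record_batch::RecordBatch;

    // Hypothetical helper mirroring the new code path: emit every left row
    // once, with all right-side columns filled with nulls.
    fn left_rows_with_null_right(
        output_schema: SchemaRef,
        left: &RecordBatch,
    ) -> Result<RecordBatch, ArrowError> {
        let num_rows = left.num_rows();
        let mut columns: Vec<ArrayRef> =
            Vec::with_capacity(output_schema.fields().len());
        // pass the build-side data through unchanged
        for i in 0..left.num_columns() {
            columns.push(left.column(i).clone());
        }
        // pad the probe side with null arrays of the right types
        for field in &output_schema.fields()[left.num_columns()..] {
            columns.push(new_null_array(field.data_type(), num_rows));
        }
        RecordBatch::try_new(output_schema, columns)
    }

Keeping the stats counters updated in the same shape as the normal path means the debug summary at the end of the stream stays truthful for this synthetic batch.]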
From cdc5fb76291acf6148699fd39da1455b47960338 Mon Sep 17 00:00:00 2001
From: Qingping Hou
Date: Sun, 25 Apr 2021 15:47:24 -0700
Subject: [PATCH 03/25] implement Into<Column> for &str

---
 datafusion/src/dataframe.rs                | 15 ++++++---------
 datafusion/src/execution/dataframe_impl.rs | 22 +++++++++++-----------
 datafusion/src/logical_plan/builder.rs     | 14 +++++++-------
 datafusion/src/logical_plan/expr.rs        | 10 +++++++++-
 datafusion/tests/dataframe.rs              |  8 +-------
 5 files changed, 34 insertions(+), 35 deletions(-)

diff --git a/datafusion/src/dataframe.rs b/datafusion/src/dataframe.rs
index 59f9fa23cbf8..507a79861cd5 100644
--- a/datafusion/src/dataframe.rs
+++ b/datafusion/src/dataframe.rs
@@ -20,7 +20,7 @@ use crate::arrow::record_batch::RecordBatch;
 use crate::error::Result;
 use crate::logical_plan::{
-    Column, DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, Partitioning,
+    DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, Partitioning,
 };
 use std::sync::Arc;
 
@@ -175,12 +175,7 @@ pub trait DataFrame: Send + Sync {
     ///             col("a").alias("a2"),
     ///             col("b").alias("b2"),
     ///             col("c").alias("c2")])?;
-    /// let join = left.join(
-    ///     right,
-    ///     JoinType::Inner,
-    ///     vec![Column::from_name("a".to_string()), Column::from_name("b".to_string())],
-    ///     vec![Column::from_name("a2".to_string()), Column::from_name("b2".to_string())],
-    /// )?;
+    /// let join = left.join(right, JoinType::Inner, &["a", "b"], &["a2", "b2"])?;
     /// let batches = join.collect().await?;
     /// # Ok(())
     /// # }
@@ -189,10 +184,12 @@ pub trait DataFrame: Send + Sync {
         &self,
         right: Arc<dyn DataFrame>,
         join_type: JoinType,
-        left_cols: Vec<Column>,
-        right_cols: Vec<Column>,
+        left_cols: &[&str],
+        right_cols: &[&str],
     ) -> Result<Arc<dyn DataFrame>>;
 
+    // TODO: add join_using
+
     /// Repartition a DataFrame based on a logical partitioning scheme.
     ///
     /// ```
diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs
index f178fc4776e7..ccd103e94d02 100644
--- a/datafusion/src/execution/dataframe_impl.rs
+++ b/datafusion/src/execution/dataframe_impl.rs
@@ -23,8 +23,8 @@ use crate::arrow::record_batch::RecordBatch;
 use crate::error::Result;
 use crate::execution::context::{ExecutionContext, ExecutionContextState};
 use crate::logical_plan::{
-    col, Column, DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan,
-    LogicalPlanBuilder, Partitioning,
+    col, DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, LogicalPlanBuilder,
+    Partitioning,
 };
 use crate::{
     dataframe::*,
@@ -106,11 +106,16 @@ impl DataFrame for DataFrameImpl {
         &self,
         right: Arc<dyn DataFrame>,
         join_type: JoinType,
-        left_cols: Vec<Column>,
-        right_cols: Vec<Column>,
+        left_cols: &[&str],
+        right_cols: &[&str],
     ) -> Result<Arc<dyn DataFrame>> {
         let plan = LogicalPlanBuilder::from(&self.plan)
-            .join(&right.to_logical_plan(), join_type, left_cols, right_cols)?
+            .join(
+                &right.to_logical_plan(),
+                join_type,
+                left_cols.to_vec(),
+                right_cols.to_vec(),
+            )?
             .build()?;
         Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
     }
@@ -252,12 +257,7 @@ mod tests {
         let right = test_table()?.select_columns(&["c1", "c3"])?;
         let left_rows = left.collect().await?;
         let right_rows = right.collect().await?;
-        let join = left.join(
-            right,
-            JoinType::Inner,
-            vec![Column::from_name("c1".to_string())],
-            vec![Column::from_name("c1".to_string())],
-        )?;
+        let join = left.join(right, JoinType::Inner, &["c1"], &["c1"])?;
         let join_rows = join.collect().await?;
         assert_eq!(100, left_rows.iter().map(|x| x.num_rows()).sum::<usize>());
         assert_eq!(100, right_rows.iter().map(|x| x.num_rows()).sum::<usize>());
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index 8b90dae52791..abe82e49ea36 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -274,8 +274,8 @@ impl LogicalPlanBuilder {
         &self,
         right: &LogicalPlan,
         join_type: JoinType,
-        left_keys: Vec<Column>,
-        right_keys: Vec<Column>,
+        left_keys: Vec<impl Into<Column>>,
+        right_keys: Vec<impl Into<Column>>,
     ) -> Result<Self> {
         if left_keys.len() != right_keys.len() {
             return Err(DataFusionError::Plan(
@@ -285,12 +285,12 @@ impl LogicalPlanBuilder {
 
         let left_keys: Vec<Column> = left_keys
             .into_iter()
-            .map(|c| c.normalize(&self.plan.all_schemas()))
+            .map(|c| c.into().normalize(&self.plan.all_schemas()))
             .collect::<Result<_>>()?;
         let right_keys: Vec<Column> = right_keys
             .into_iter()
             // FIXME: write a test for this
-            .map(|c| c.normalize(&right.all_schemas()))
+            .map(|c| c.into().normalize(&right.all_schemas()))
             .collect::<Result<_>>()?;
         let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect();
         let join_schema = build_join_schema(
@@ -315,16 +315,16 @@ impl LogicalPlanBuilder {
         &self,
         right: &LogicalPlan,
         join_type: JoinType,
-        using_keys: Vec<Column>,
+        using_keys: Vec<impl Into<Column> + Clone>,
     ) -> Result<Self> {
         let left_keys: Vec<Column> = using_keys
             .clone()
             .into_iter()
-            .map(|c| c.normalize(&self.plan.all_schemas()))
+            .map(|c| c.into().normalize(&self.plan.all_schemas()))
             .collect::<Result<_>>()?;
         let right_keys: Vec<Column> = using_keys
             .into_iter()
-            .map(|c| c.normalize(&right.all_schemas()))
+            .map(|c| c.into().normalize(&right.all_schemas()))
            .collect::<Result<_>>()?;
 
         let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect();
diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index d69a5a123c92..fe8e9af2857c 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -53,6 +53,7 @@ impl Column {
         }
     }
 
+    /// Deserialize a flat name string into a column
     pub fn from_flat_name(flat_name: &str) -> Self {
         use sqlparser::tokenizer::Token;
 
@@ -75,6 +76,7 @@ impl Column {
         }
     }
 
+    /// Serialize column into a flat name string
     pub fn flat_name(&self) -> String {
         match &self.relation {
             Some(r) => format!("{}.{}", r, self.name),
@@ -101,6 +103,12 @@ impl Column {
     }
 }
 
+impl From<&str> for Column {
+    fn from(c: &str) -> Self {
+        Self::from_flat_name(c)
+    }
+}
+
 impl fmt::Display for Column {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match &self.relation {
@@ -1020,7 +1028,7 @@ pub fn or(left: Expr, right: Expr) -> Expr {
 
 /// Create a column expression based on a qualified or unqualified column name
 pub fn col(ident: &str) -> Expr {
-    Expr::Column(Column::from_flat_name(ident))
+    Expr::Column(ident.into())
 }
 
 /// Recursively normalize all Column expressions in a given expression tree
diff --git a/datafusion/tests/dataframe.rs b/datafusion/tests/dataframe.rs
index c56ac5a29ff9..b93e21f4abab 100644
--- a/datafusion/tests/dataframe.rs
+++ b/datafusion/tests/dataframe.rs
@@ -24,7 +24,6 @@ use arrow::{
 };
 
 use datafusion::error::Result;
-use datafusion::logical_plan::Column;
 use datafusion::{datasource::MemTable, prelude::JoinType};
 
 use datafusion::execution::context::ExecutionContext;
@@ -70,12 +69,7 @@ async fn join() -> Result<()> {
 
     let df2 = ctx.table("aaa")?;
 
-    let a = df1.join(
-        df2,
-        JoinType::Inner,
-        vec![Column::from_name("a".to_string())],
-        vec![Column::from_name("a".to_string())],
-    )?;
+    let a = df1.join(df2, JoinType::Inner, &["a"], &["a"])?;
 
     let batches = a.collect().await?;
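[With the conversion above in place, anything that is `Into<Column>` can name a join key, and `col` accepts dotted names again. A small usage sketch against the API in this patch, assuming the public field layout shown in the `Column` struct:

    use datafusion::logical_plan::{col, Column, Expr};

    fn main() {
        // a dotted flat name parses into a qualified column
        assert_eq!(
            col("person.id"),
            Expr::Column(Column {
                relation: Some("person".to_string()),
                name: "id".to_string(),
            })
        );
        // a bare name stays unqualified until normalize() resolves it
        assert_eq!(col("id"), Expr::Column(Column::from_name("id".to_string())));
    }
]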
From 723ee5d56b38b8a1e6880964f0b0762ed6e1eef3 Mon Sep 17 00:00:00 2001
From: Qingping Hou
Date: Sun, 25 Apr 2021 15:52:30 -0700
Subject: [PATCH 04/25] add todo for ARROW-10971

---
 datafusion/src/physical_plan/hash_join.rs | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index 4318a80e0b3d..9d9c1b3b1f45 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -901,6 +901,9 @@ impl Stream for HashJoinStream {
                 // If maybe_batch is None and num_output_rows is 0, that means right side batch was
                 // empty and has been coalesced to None. Fill right side with Null if preserve_left
                 // is true.
+                //
+                // TODO: generalize this to keep track of unmatched left rows across batches, see
+                // https://issues.apache.org/jira/browse/ARROW-10971
                 None if self.preserve_left && self.num_output_rows == 0 => {
                     let start = Instant::now();
                     let num_rows = self.left_data.1.num_rows();

From 9cf494f23dfcaf67a6dca0822216088492a4d5e0 Mon Sep 17 00:00:00 2001
From: Qingping Hou
Date: Mon, 26 Apr 2021 22:46:21 -0700
Subject: [PATCH 05/25] fix cross join handling in projection push down
 optimizer

When a projection is pushed down to cross join inputs, fields from
resulting plan's schema need to be trimmed to only contain projected
fields.
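[To make the failure mode concrete: if the optimizer clones the old CrossJoin schema while its inputs have been narrowed, the field count and order no longer match what the inputs produce. Rebuilding through the builder derives the schema from the (now projected) inputs. A compilable sketch against the builder API as of this patch; the table names t1/t2 and field names are made up:

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::error::Result;
    use datafusion::logical_plan::LogicalPlanBuilder;

    fn main() -> Result<()> {
        let t1 = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);
        let t2 = Schema::new(vec![Field::new("x", DataType::Int32, false)]);
        // project t1 down to column "a" (index 0) before the cross join
        let left = LogicalPlanBuilder::scan_empty(Some("t1"), &t1, Some(vec![0]))?;
        let right = LogicalPlanBuilder::scan_empty(Some("t2"), &t2, None)?.build()?;
        let plan = left.cross_join(&right)?.build()?;
        // the join schema is derived from the projected inputs: just a, x
        assert_eq!(plan.schema().fields().len(), 2);
        Ok(())
    }
]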
---
 datafusion/src/execution/context.rs        | 14 +++++++-------
 datafusion/src/logical_plan/builder.rs     |  2 +-
 datafusion/src/optimizer/utils.rs          | 14 +++++++-------
 datafusion/src/physical_plan/cross_join.rs |  2 +-
 datafusion/src/physical_plan/hash_utils.rs |  5 -----
 datafusion/tests/sql.rs                    |  2 --
 6 files changed, 16 insertions(+), 23 deletions(-)

diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index 9ead784b08ff..2e70f576bd26 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -1678,13 +1678,13 @@ mod tests {
             .expect("ran plan correctly");
 
         let expected = vec![
-            "+-----+------------+",
-            "| str | COUNT(val) |",
-            "+-----+------------+",
-            "| A | 4 |",
-            "| B | 1 |",
-            "| C | 1 |",
-            "+-----+------------+",
+            "+-----+--------------+",
+            "| str | COUNT(t.val) |",
+            "+-----+--------------+",
+            "| A | 4 |",
+            "| B | 1 |",
+            "| C | 1 |",
+            "+-----+--------------+",
         ];
         assert_batches_sorted_eq!(expected, &results);
     }
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index e3cb665e2182..2f54f198d1e9 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -344,10 +344,10 @@ impl LogicalPlanBuilder {
             schema: DFSchemaRef::new(join_schema),
         }))
     }
+    /// Apply a cross join
     pub fn cross_join(&self, right: &LogicalPlan) -> Result<Self> {
         let schema = self.plan.schema().join(right.schema())?;
-
         Ok(Self::from(&LogicalPlan::CrossJoin {
             left: Arc::new(self.plan.clone()),
             right: Arc::new(right.clone()),
diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs
index 72fa584b9b44..ae5a9847e7d5 100644
--- a/datafusion/src/optimizer/utils.rs
+++ b/datafusion/src/optimizer/utils.rs
@@ -23,8 +23,8 @@ use arrow::datatypes::Schema;
 
 use super::optimizer::OptimizerRule;
 use crate::logical_plan::{
-    Column, Expr, LogicalPlan, Operator, Partitioning, PlanType, Recursion,
-    StringifiedPlan, ToDFSchema,
+    Column, Expr, LogicalPlan, LogicalPlanBuilder, Operator, Partitioning, PlanType,
+    Recursion, StringifiedPlan, ToDFSchema,
 };
 use crate::prelude::lit;
 use crate::scalar::ScalarValue;
@@ -208,11 +208,11 @@ pub fn from_plan(
             on: on.clone(),
             schema: schema.clone(),
         }),
-        LogicalPlan::CrossJoin { schema, .. } => Ok(LogicalPlan::CrossJoin {
-            left: Arc::new(inputs[0].clone()),
-            right: Arc::new(inputs[1].clone()),
-            schema: schema.clone(),
-        }),
+        LogicalPlan::CrossJoin { .. } => {
+            let left = &inputs[0];
+            let right = &inputs[1];
+            LogicalPlanBuilder::from(left).cross_join(right)?.build()
+        }
        LogicalPlan::Limit { n, ..
        } => Ok(LogicalPlan::Limit {
            n: *n,
            input: Arc::new(inputs[0].clone()),
diff --git a/datafusion/src/physical_plan/cross_join.rs b/datafusion/src/physical_plan/cross_join.rs
index 4372352d6ecf..c22fca845852 100644
--- a/datafusion/src/physical_plan/cross_join.rs
+++ b/datafusion/src/physical_plan/cross_join.rs
@@ -67,7 +67,7 @@ impl CrossJoinExec {
     ) -> Result<Self> {
         let left_schema = left.schema();
         let right_schema = right.schema();
-        check_join_is_valid(&left_schema, &right_schema, &[])?;
+        check_join_is_valid(&left_schema, &right_schema, &vec![])?;
         let left_schema = left.schema();
         let left_fields = left_schema.fields().iter();
diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs
index 6060bd0b5bb5..57be7f8217dd 100644
--- a/datafusion/src/physical_plan/hash_utils.rs
+++ b/datafusion/src/physical_plan/hash_utils.rs
@@ -63,11 +63,6 @@ fn check_join_set_is_valid(
     right: &HashSet<Column>,
     on: &[(Column, Column)],
 ) -> Result<()> {
-    if on.is_empty() {
-        return Err(DataFusionError::Plan(
-            "The 'on' clause of a join cannot be empty".to_string(),
-        ));
-    }
     let on_left = &on.iter().map(|on| on.0.clone()).collect::<HashSet<_>>();
     let left_missing = on_left.difference(left).collect::<HashSet<_>>();
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 4421d62a3430..f9828bed7670 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -1580,8 +1580,6 @@ async fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec<Vec<String>> {
     let msg = format!("Executing physical plan for '{}': {:?}", sql, plan);
     let results = collect(plan).await.expect(&msg);
 
-    assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref());
-
     result_vec(&results)
 }
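[Note: dropping the logical-vs-optimized schema assertion here is a direct consequence of the build/probe swap in HashBuildProbeOrder, which reverses cross join output order. The next patch restores the assertion by re-projecting the swapped join back into the original column order, roughly like this sketch (hypothetical helper name, same builder API):

    use datafusion::error::Result;
    use datafusion::logical_plan::{Expr, LogicalPlan, LogicalPlanBuilder};

    // Swap the inputs so the smaller side becomes the build side, then
    // re-project the columns in their original (left, right) order.
    fn swap_cross_join_preserving_order(
        left: &LogicalPlan,
        right: &LogicalPlan,
    ) -> Result<LogicalPlan> {
        let original_order = left
            .schema()
            .fields()
            .iter()
            .chain(right.schema().fields().iter())
            .map(|f| Expr::Column(f.qualified_column()));
        LogicalPlanBuilder::from(right)
            .cross_join(left)?
            .project(original_order)?
            .build()
    }
]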
From fff2e1dadba31ea0860fd31c3594ae27285330e8 Mon Sep 17 00:00:00 2001
From: Qingping Hou
Date: Sun, 9 May 2021 19:05:16 -0700
Subject: [PATCH 06/25] maintain field order during plan optimization using
 projections

---
 .../src/optimizer/hash_build_probe_order.rs | 21 ++++++++++++-------
 datafusion/tests/sql.rs                     |  2 ++
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/datafusion/src/optimizer/hash_build_probe_order.rs b/datafusion/src/optimizer/hash_build_probe_order.rs
index 23b5a0e44988..578a20dca8e5 100644
--- a/datafusion/src/optimizer/hash_build_probe_order.rs
+++ b/datafusion/src/optimizer/hash_build_probe_order.rs
@@ -22,7 +22,7 @@
 
 use std::sync::Arc;
 
-use crate::logical_plan::LogicalPlan;
+use crate::logical_plan::{Expr, LogicalPlan, LogicalPlanBuilder};
 use crate::optimizer::optimizer::OptimizerRule;
 use crate::{error::Result, prelude::JoinType};
 
@@ -147,12 +147,19 @@ impl OptimizerRule for HashBuildProbeOrder {
                 let left = self.optimize(left)?;
                 let right = self.optimize(right)?;
                 if should_swap_join_order(&left, &right) {
-                    // Swap left and right
-                    Ok(LogicalPlan::CrossJoin {
-                        left: Arc::new(right),
-                        right: Arc::new(left),
-                        schema: schema.clone(),
-                    })
+                    let swapped = LogicalPlanBuilder::from(&right).cross_join(&left)?;
+                    // wrap plan with projection to maintain column order
+                    let left_cols = left
+                        .schema()
+                        .fields()
+                        .iter()
+                        .map(|f| Expr::Column(f.qualified_column()));
+                    let right_cols = right
+                        .schema()
+                        .fields()
+                        .iter()
+                        .map(|f| Expr::Column(f.qualified_column()));
+                    swapped.project(left_cols.chain(right_cols))?.build()
                 } else {
                     // Keep join as is
                     Ok(LogicalPlan::CrossJoin {
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index f9828bed7670..4421d62a3430 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -1580,6 +1580,8 @@ async fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec<Vec<String>> {
     let msg = format!("Executing physical plan for '{}': {:?}", sql, plan);
     let results = collect(plan).await.expect(&msg);
 
+    assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref());
+
     result_vec(&results)
 }
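[Looking ahead to the next patch: table scans stop carrying an optional name, and anonymous providers get a "?table?" placeholder instead. A sketch of registering and scanning an unnamed in-memory table under that API, assuming an empty single-partition MemTable:

    use std::sync::Arc;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::datasource::MemTable;
    use datafusion::error::Result;
    use datafusion::logical_plan::{LogicalPlanBuilder, UNNAMED_TABLE};

    fn main() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "a",
            DataType::Int32,
            false,
        )]));
        let provider = Arc::new(MemTable::try_new(schema, vec![vec![]])?);
        // anonymous sources get the "?table?" placeholder instead of None
        let plan = LogicalPlanBuilder::scan(UNNAMED_TABLE, provider, None)?.build()?;
        assert_eq!(format!("{:?}", plan), "TableScan: ?table? projection=None");
        Ok(())
    }
]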
From eaf1edce5c1bc6ae6ece025bcd108332fe1020aa Mon Sep 17 00:00:00 2001
From: Qingping Hou
Date: Fri, 14 May 2021 21:54:32 -0700
Subject: [PATCH 07/25] change TableScan name from Option<String> to String

---
 datafusion/src/execution/context.rs           |  61 +++++-----
 datafusion/src/execution/dataframe_impl.rs    |  18 +--
 datafusion/src/logical_plan/builder.rs        |  52 ++++-----
 datafusion/src/logical_plan/dfschema.rs       |   1 +
 datafusion/src/logical_plan/expr.rs           |  22 +++-
 datafusion/src/logical_plan/mod.rs            |  10 +-
 datafusion/src/logical_plan/plan.rs           |  12 +-
 datafusion/src/optimizer/filter_push_down.rs  |   8 +-
 .../src/optimizer/hash_build_probe_order.rs   |   4 +-
 .../src/optimizer/projection_push_down.rs     |   2 +-
 datafusion/src/physical_plan/planner.rs       | 107 +++++++++++++++++-
 datafusion/src/sql/planner.rs                 |   2 +-
 datafusion/tests/custom_sources.rs            |  11 +-
 13 files changed, 205 insertions(+), 105 deletions(-)

diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index c5aebfbdb7b9..5788f5938ffa 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -52,7 +52,7 @@ use crate::datasource::TableProvider;
 use crate::error::{DataFusionError, Result};
 use crate::execution::dataframe_impl::DataFrameImpl;
 use crate::logical_plan::{
-    FunctionRegistry, LogicalPlan, LogicalPlanBuilder, ToDFSchema,
+    FunctionRegistry, LogicalPlan, LogicalPlanBuilder, UNNAMED_TABLE,
 };
 use crate::optimizer::constant_folding::ConstantFolding;
 use crate::optimizer::filter_push_down::FilterPushDown;
@@ -294,19 +294,9 @@ impl ExecutionContext {
         &mut self,
         provider: Arc<dyn TableProvider>,
     ) -> Result<Arc<dyn DataFrame>> {
-        // FIMXE: add table name method to table provider?
-        let schema = provider.schema();
-        let table_scan = LogicalPlan::TableScan {
-            table_name: None,
-            source: provider,
-            projected_schema: schema.to_dfschema_ref()?,
-            projection: None,
-            filters: vec![],
-            limit: None,
-        };
         Ok(Arc::new(DataFrameImpl::new(
             self.state.clone(),
-            &LogicalPlanBuilder::from(&table_scan).build()?,
+            &LogicalPlanBuilder::scan(UNNAMED_TABLE, provider, None)?.build()?,
         )))
     }
 
@@ -411,7 +401,7 @@ impl ExecutionContext {
         match schema.table(table_ref.table()) {
             Some(ref provider) => {
                 let plan = LogicalPlanBuilder::scan(
-                    Some(table_ref.table()),
+                    &table_ref.table(),
                     Arc::clone(provider),
                     None,
                 )?
                .build()?;
@@ -1106,8 +1106,11 @@ mod tests {
             _ => panic!("expect optimized_plan to be projection"),
         }
 
-        let expected = "Projection: #b\
-        \n  TableScan: projection=Some([1])";
+        let expected = format!(
+            "Projection: #{}.b\
+            \n  TableScan: {} projection=Some([1])",
+            UNNAMED_TABLE, UNNAMED_TABLE
+        );
         assert_eq!(format!("{:?}", optimized_plan), expected);
 
         let physical_plan = ctx.create_physical_plan(&optimized_plan)?;
@@ -1928,7 +1921,7 @@ mod tests {
         let plan = LogicalPlanBuilder::scan_empty(None, schema.as_ref(), None)?
             .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
-            .project(vec![col("c1"), col("SUM(c2)").alias("total_salary")])?
+            .project(vec![col("c1"), sum(col("c2")).alias("total_salary")])?
             .build()?;
 
         let plan = ctx.optimize(&plan)?;
@@ -2024,11 +2017,11 @@ mod tests {
             .unwrap();
 
         let expected = vec![
-            "+-----------+",
-            "| sqrt(t.i) |",
-            "+-----------+",
-            "| 1 |",
-            "+-----------+",
+            "+---------+",
+            "| sqrt(i) |",
+            "+---------+",
+            "| 1 |",
+            "+---------+",
         ];
 
         let results = plan_and_collect(&mut ctx, "SELECT sqrt(i) FROM t")
@@ -2086,11 +2079,11 @@ mod tests {
         let result = plan_and_collect(&mut ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?;
 
         let expected = vec![
-            "+--------------+",
-            "| MY_FUNC(t.i) |",
-            "+--------------+",
-            "| 1 |",
-            "+--------------+",
+            "+------------+",
+            "| MY_FUNC(i) |",
+            "+------------+",
+            "| 1 |",
+            "+------------+",
        ];
         assert_batches_eq!(expected, &result);
 
@@ -2385,14 +2378,14 @@ mod tests {
         let result = collect(plan).await?;
 
         let expected = vec![
-            "+-----+-----+-----------------+",
-            "| a | b | my_add(t.a,t.b) |",
-            "+-----+-----+-----------------+",
-            "| 1 | 2 | 3 |",
-            "| 10 | 12 | 22 |",
-            "| 10 | 12 | 22 |",
-            "| 100 | 120 | 220 |",
-            "+-----+-----+-----------------+",
+            "+-----+-----+-------------+",
+            "| a | b | my_add(a,b) |",
+            "+-----+-----+-------------+",
+            "| 1 | 2 | 3 |",
+            "| 10 | 12 | 22 |",
+            "| 10 | 12 | 22 |",
+            "| 100 | 120 | 220 |",
+            "+-----+-----+-------------+",
         ];
         assert_batches_eq!(expected, &result);
 
diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs
index b15a1b09cbd3..776bce762d00 100644
--- a/datafusion/src/execution/dataframe_impl.rs
+++ b/datafusion/src/execution/dataframe_impl.rs
@@ -241,15 +241,15 @@ mod tests {
 
         assert_batches_sorted_eq!(
             vec![
-                "+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
-                "| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | AVG(aggregate_test_100.c12) | SUM(aggregate_test_100.c12) | COUNT(aggregate_test_100.c12) | COUNT(DISTINCT aggregate_test_100.c12) |",
-                "+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
-                "| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |",
-                "| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |",
-                "| c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 |",
-                "| d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 |",
-                "| e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 |",
-                "+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
+                "+----+----------------------+--------------------+---------------------+--------------------+------------+---------------------+",
+                "| c1 | MIN(c12) | MAX(c12) | AVG(c12) | SUM(c12) | COUNT(c12) | COUNT(DISTINCT c12) |",
+                "+----+----------------------+--------------------+---------------------+--------------------+------------+---------------------+",
+                "| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |",
+                "| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |",
+                "| c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 |",
+                "| d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 |",
+                "| e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 |",
+                "+----+----------------------+--------------------+---------------------+--------------------+------------+---------------------+",
             ],
             &df
         );
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index 2f54f198d1e9..6c113b51bb84 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -34,10 +34,14 @@ use crate::{
 use super::dfschema::ToDFSchema;
 use super::{exprlist_to_fields, Expr, JoinType, LogicalPlan, PlanType, StringifiedPlan};
 use crate::logical_plan::{
-    normalize_col, normalize_cols, Column, DFField, DFSchema, DFSchemaRef, Partitioning,
+    columnize_expr, normalize_col, normalize_cols, Column, DFField, DFSchema,
+    DFSchemaRef, Partitioning,
 };
 use std::collections::HashSet;
 
+/// Default table name for unnamed table
+pub const UNNAMED_TABLE: &str = "?table?";
+
 pub enum JoinConstraint {
     On,
     Using,
 }
@@ -108,7 +112,7 @@ impl LogicalPlanBuilder {
         projection: Option<Vec<usize>>,
     ) -> Result<Self> {
         let provider = Arc::new(MemTable::try_new(schema, partitions)?);
-        Self::scan(None, provider, projection)
+        Self::scan(UNNAMED_TABLE, provider, projection)
     }
 
     /// Scan a CSV data source
@@ -118,7 +122,7 @@ impl LogicalPlanBuilder {
         projection: Option<Vec<usize>>,
     ) -> Result<Self> {
         let provider = Arc::new(CsvFile::try_new(path, options)?);
-        Self::scan(None, provider, projection)
+        Self::scan(path, provider, projection)
     }
 
     /// Scan a Parquet data source
@@ -128,7 +132,7 @@ impl LogicalPlanBuilder {
         max_concurrency: usize,
     ) -> Result<Self> {
         let provider = Arc::new(ParquetTable::try_new(path, max_concurrency)?);
-        Self::scan(None, provider, projection)
+        Self::scan(path, provider, projection)
     }
 
     /// Scan an empty data source, mainly used in tests
@@ -139,21 +143,19 @@ impl LogicalPlanBuilder {
     ) -> Result<Self> {
         let table_schema = Arc::new(table_schema.clone());
         let provider = Arc::new(EmptyTable::new(table_schema));
-        Self::scan(name, provider, projection)
+        Self::scan(name.unwrap_or(UNNAMED_TABLE), provider, projection)
     }
 
     /// Convert a table provider into a builder with a TableScan
     pub fn scan(
-        table_name: Option<&str>,
+        table_name: &str,
         provider: Arc<dyn TableProvider>,
         projection: Option<Vec<usize>>,
     ) -> Result<Self> {
-        if let Some(name) = table_name {
-            if name.is_empty() {
-                return Err(DataFusionError::Plan(
-                    "table_name cannot be empty".to_string(),
-                ));
-            }
+        if table_name.is_empty() {
+            return Err(DataFusionError::Plan(
+                "table_name cannot be empty".to_string(),
+            ));
         }
 
         let schema = provider.schema();
 
         let projected_schema = projection
             .as_ref()
             .map(|p| DFSchema {
                 fields: p
                     .iter()
-                    .map(|i| match table_name {
-                        // FIXME: move if check outside
-                        Some(name) => {
-                            DFField::from_qualified(name, schema.field(*i).clone())
-                        }
-                        None => DFField::from(schema.field(*i).clone()),
+                    .map(|i| {
+                        DFField::from_qualified(table_name, schema.field(*i).clone())
                     })
                     .collect(),
             })
             .unwrap_or_else(|| {
-                // FIXME: remove unwrap
-                match table_name {
-                    Some(name) => DFSchema::try_from_qualified(name, &schema).unwrap(),
-                    None => DFSchema::new(
-                        schema
-                            .fields()
-                            .iter()
-                            .map(|f| DFField::from(f.clone()))
-                            .collect(),
-                    )
-                    .unwrap(),
-                }
+                DFSchema::try_from_qualified(table_name, &schema).unwrap()
             });
 
-        // FIXME: check for empty table name
         let table_scan = LogicalPlan::TableScan {
-            table_name: table_name.clone().map(|s| s.to_string()),
+            table_name: table_name.to_string(),
             source: provider,
             projected_schema: Arc::new(projected_schema),
             projection,
             filters: vec![],
             limit: None,
         };
@@ -219,7 +205,7 @@ impl LogicalPlanBuilder {
                         .push(Expr::Column(input_schema.field(i).qualified_column()))
                 });
             }
-            _ => projected_expr.push(normalized_e),
+            _ => projected_expr.push(columnize_expr(normalized_e, input_schema)),
         }
     }
 
diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs
index d39697595559..517c02a58b71 100644
--- a/datafusion/src/logical_plan/dfschema.rs
+++ b/datafusion/src/logical_plan/dfschema.rs
@@ -227,6 +227,7 @@ impl DFSchema {
         }
     }
 
+    /// Check to see if unqualified field names match field names in Arrow schema
     pub fn matches_arrow_schema(&self, arrow_schema: &Schema) -> bool {
         self.fields
             .iter()
diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index 49d0aec020cc..383ada0854a5 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -418,7 +418,7 @@ impl Expr {
         }
     }
 
-    /// Returns the name of this expression based on [arrow::datatypes::Schema].
+    /// Returns the name of this expression based on [crate::logical_plan::DFSchema].
     ///
     /// This represents how a column with this expression is named when no alias is chosen
     pub fn name(&self, input_schema: &DFSchema) -> Result<String> {
@@ -1031,6 +1031,24 @@ pub fn col(ident: &str) -> Expr {
     Expr::Column(ident.into())
 }
 
+/// Convert an expression into Column expression if it's already provided as input
+pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr {
+    match e {
+        Expr::Column(_) => e,
+        Expr::Alias(inner_expr, name) => {
+            Expr::Alias(Box::new(columnize_expr(*inner_expr, input_schema)), name)
+        }
+        _ => match e.name(input_schema) {
+            Ok(name) => match input_schema.field_with_unqualified_name(&name) {
+                Ok(field) => Expr::Column(field.qualified_column()),
+                // expression not provided as input, do not convert to a column reference
+                Err(_) => e,
+            },
+            Err(_) => e,
+        },
+    }
+}
+
 /// Recursively normalize all Column expressions in a given expression tree
 pub fn normalize_col(e: Expr, schemas: &[&DFSchemaRef]) -> Result<Expr> {
     struct ColumnNormalizer<'a, 'b> {
@@ -1519,7 +1537,7 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
             }
         }
         other => Err(DataFusionError::NotImplemented(format!(
-            "Physical plan does not support logical expression {:?}",
+            "Logical plan does not support logical expression {:?}",
             other
         ))),
     }
 }
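[The `columnize_expr` hook above is what lets the aggregate test earlier in this patch project the aggregate expression itself. A sketch of the builder flow it enables, with made-up column names; the rewrite resolves the expression to the field the Aggregate node already emits, roughly the equivalent of `col("SUM(c2)")`, instead of planning the sum a second time:

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::error::Result;
    use datafusion::logical_plan::{col, sum, LogicalPlan, LogicalPlanBuilder};

    fn aggregate_then_project() -> Result<LogicalPlan> {
        let schema = Schema::new(vec![
            Field::new("c1", DataType::Utf8, false),
            Field::new("c2", DataType::Float64, false),
        ]);
        LogicalPlanBuilder::scan_empty(Some("t"), &schema, None)?
            .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
            // columnize_expr turns this into a reference to the
            // aggregate's output column, aliased to "total"
            .project(vec![col("c1"), sum(col("c2")).alias("total")])?
            .build()
    }
]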
diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs
index ba4426d0351c..5bcf8ec765be 100644
--- a/datafusion/src/logical_plan/mod.rs
+++ b/datafusion/src/logical_plan/mod.rs
@@ -29,15 +29,15 @@ mod extension;
 mod operators;
 mod plan;
 mod registry;
-pub use builder::{union_with_alias, LogicalPlanBuilder};
+pub use builder::{union_with_alias, LogicalPlanBuilder, UNNAMED_TABLE};
 pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema};
 pub use display::display_schema;
 pub use expr::{
     abs, acos, and, array, ascii, asin, atan, avg, binary_expr, bit_length, btrim, case,
-    ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count,
-    count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list,
-    initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min,
-    normalize_col, normalize_cols, octet_length, or, regexp_match, regexp_replace,
+    ceil, character_length, chr, col, columnize_expr, combine_filters, concat, concat_ws,
+    cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor,
+    in_list, initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5,
+    min, normalize_col, normalize_cols, octet_length, or, regexp_match, regexp_replace,
     repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512,
     signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex,
     translate, trim, trunc, upper, when, Column, Expr, ExprRewriter, ExpressionVisitor,
diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs
index a7df24f0da74..4685aca67120 100644
--- a/datafusion/src/logical_plan/plan.rs
+++ b/datafusion/src/logical_plan/plan.rs
@@ -137,7 +137,7 @@ pub enum LogicalPlan {
     /// Produces rows from a table provider by reference or from the context
     TableScan {
         /// The name of the table
-        table_name: Option<String>,
+        table_name: String,
         /// The source of the table
         source: Arc<dyn TableProvider>,
         /// Optional column indices to use as a projection
@@ -629,16 +629,10 @@ impl LogicalPlan {
                     ref limit,
                     ..
                 } => {
-                    let sep = match table_name {
-                        Some(_) => " ",
-                        None => "",
-                    };
                     write!(
                         f,
-                        "TableScan: {}{}projection={:?}",
-                        table_name.as_ref().map(|s| s.as_str()).unwrap_or(""),
-                        sep,
-                        projection
+                        "TableScan: {} projection={:?}",
+                        table_name, projection
                     )?;
 
                     if !filters.is_empty() {
diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs
index 027a7360a322..adf7e289ad55 100644
--- a/datafusion/src/optimizer/filter_push_down.rs
+++ b/datafusion/src/optimizer/filter_push_down.rs
@@ -1013,7 +1013,7 @@ mod tests {
         let test_provider = PushDownProvider { filter_support };
 
         let table_scan = LogicalPlan::TableScan {
-            table_name: None,
+            table_name: "test".to_string(),
            filters: vec![],
            projected_schema: Arc::new(DFSchema::try_from(
                (*test_provider.schema()).clone(),
@@ -1033,7 +1033,7 @@ mod tests {
         let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
 
         let expected = "\
-        TableScan: projection=None, filters=[#a Eq Int64(1)]";
+        TableScan: test projection=None, filters=[#a Eq Int64(1)]";
         assert_optimized_plan_eq(&plan, expected);
         Ok(())
     }
@@ -1045,7 +1045,7 @@ mod tests {
 
         let expected = "\
         Filter: #a Eq Int64(1)\
-        \n  TableScan: projection=None, filters=[#a Eq Int64(1)]";
+        \n  TableScan: test projection=None, filters=[#a Eq Int64(1)]";
         assert_optimized_plan_eq(&plan, expected);
         Ok(())
     }
@@ -1057,7 +1057,7 @@ mod tests {
 
         let expected = "\
         Filter: #a Eq Int64(1)\
-        \n  TableScan: projection=None";
+        \n  TableScan: test projection=None";
         assert_optimized_plan_eq(&plan, expected);
         Ok(())
     }
diff --git a/datafusion/src/optimizer/hash_build_probe_order.rs b/datafusion/src/optimizer/hash_build_probe_order.rs
index 578a20dca8e5..eeca46a7b3cb 100644
--- a/datafusion/src/optimizer/hash_build_probe_order.rs
+++ b/datafusion/src/optimizer/hash_build_probe_order.rs
@@ -265,7 +265,7 @@ mod tests {
     #[test]
     fn test_swap_order() {
         let lp_left = LogicalPlan::TableScan {
-            table_name: Some("left".to_string()),
+            table_name: "left".to_string(),
             projection: None,
             source: Arc::new(TestTableProvider { num_rows: 1000 }),
             projected_schema: Arc::new(DFSchema::empty()),
@@ -274,7 +274,7 @@ mod tests {
         };
 
         let lp_right = LogicalPlan::TableScan {
-            table_name: Some("right".to_string()),
+            table_name: "right".to_string(),
             projection: None,
             source: Arc::new(TestTableProvider { num_rows: 100 }),
             projected_schema: Arc::new(DFSchema::empty()),
diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs
index 5eb1e3be6680..9b9be1572255 100644
--- a/datafusion/src/optimizer/projection_push_down.rs
+++ b/datafusion/src/optimizer/projection_push_down.rs
@@ -257,7 +257,7 @@ fn optimize_plan(
             ..
         } => {
             let (projection, projected_schema) = get_projected_schema(
-                table_name.as_ref(),
+                Some(&table_name),
                 &source.schema(),
                 required_columns,
                 has_projection,
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 7e38980baa21..2ee43914db88 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -52,11 +52,114 @@ use arrow::datatypes::{Schema, SchemaRef};
 use expressions::col;
 use log::debug;
 
+fn create_function_physical_name(
+    fun: &str,
+    distinct: bool,
+    args: &[Expr],
+    input_schema: &DFSchema,
+) -> Result<String> {
+    let names: Vec<String> = args
+        .iter()
+        .map(|e| physical_name(e, input_schema))
+        .collect::<Result<_>>()?;
+    let distinct_str = match distinct {
+        true => "DISTINCT ",
+        false => "",
+    };
+    Ok(format!("{}({}{})", fun, distinct_str, names.join(",")))
+}
+
 fn physical_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
-    // FIXME: finish this
     match e {
         Expr::Column(c) => Ok(c.name.clone()),
-        _ => e.name(&input_schema),
+        Expr::Alias(_, name) => Ok(name.clone()),
+        Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")),
+        Expr::Literal(value) => Ok(format!("{:?}", value)),
+        Expr::BinaryExpr { left, op, right } => {
+            let left = physical_name(left, input_schema)?;
+            let right = physical_name(right, input_schema)?;
+            Ok(format!("{} {:?} {}", left, op, right))
+        }
+        Expr::Case {
+            expr,
+            when_then_expr,
+            else_expr,
+        } => {
+            let mut name = "CASE ".to_string();
+            if let Some(e) = expr {
+                name += &format!("{:?} ", e);
+            }
+            for (w, t) in when_then_expr {
+                name += &format!("WHEN {:?} THEN {:?} ", w, t);
+            }
+            if let Some(e) = else_expr {
+                name += &format!("ELSE {:?} ", e);
+            }
+            name += "END";
+            Ok(name)
+        }
+        Expr::Cast { expr, data_type } => {
+            let expr = physical_name(expr, input_schema)?;
+            Ok(format!("CAST({} AS {:?})", expr, data_type))
+        }
+        Expr::TryCast { expr, data_type } => {
+            let expr = physical_name(expr, input_schema)?;
+            Ok(format!("TRY_CAST({} AS {:?})", expr, data_type))
+        }
+        Expr::Not(expr) => {
+            let expr = physical_name(expr, input_schema)?;
+            Ok(format!("NOT {}", expr))
+        }
+        Expr::Negative(expr) => {
+            let expr = physical_name(expr, input_schema)?;
+            Ok(format!("(- {})", expr))
+        }
+        Expr::IsNull(expr) => {
+            let expr = physical_name(expr, input_schema)?;
+            Ok(format!("{} IS NULL", expr))
+        }
+        Expr::IsNotNull(expr) => {
+            let expr = physical_name(expr, input_schema)?;
+            Ok(format!("{} IS NOT NULL", expr))
+        }
+        Expr::ScalarFunction { fun, args, .. } => {
+            create_function_physical_name(&fun.to_string(), false, args, input_schema)
+        }
+        Expr::ScalarUDF { fun, args, .. } => {
+            create_function_physical_name(&fun.name, false, args, input_schema)
+        }
+        Expr::AggregateFunction {
+            fun,
+            distinct,
+            args,
+            ..
+        } => {
+            create_function_physical_name(&fun.to_string(), *distinct, args, input_schema)
+        }
+        Expr::AggregateUDF { fun, args } => {
+            let mut names = Vec::with_capacity(args.len());
+            for e in args {
+                names.push(physical_name(e, input_schema)?);
+            }
+            Ok(format!("{}({})", fun.name, names.join(",")))
+        }
+        Expr::InList {
+            expr,
+            list,
+            negated,
+        } => {
+            let expr = physical_name(expr, input_schema)?;
+            let list = list.iter().map(|expr| physical_name(expr, input_schema));
+            if *negated {
+                Ok(format!("{} NOT IN ({:?})", expr, list))
+            } else {
+                Ok(format!("{} IN ({:?})", expr, list))
+            }
+        }
+        other => Err(DataFusionError::NotImplemented(format!(
+            "Physical plan does not support logical expression {:?}",
+            other
+        ))),
     }
 }
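[The net effect of `physical_name` versus the logical `create_name`: logical plans keep the table qualifier, physical field names drop it, which is why expected output headers in this patch change from `sqrt(t.i)` to `sqrt(i)` and from `MIN(aggregate_test_100.c12)` to `MIN(c12)`. A toy illustration of the two naming schemes, with hypothetical helpers rather than the DataFusion functions themselves:

    // Minimal illustration only; the real implementations are create_name
    // in expr.rs and physical_name above.
    fn logical_name(relation: Option<&str>, column: &str, agg: &str) -> String {
        match relation {
            Some(r) => format!("{}({}.{})", agg, r, column),
            None => format!("{}({})", agg, column),
        }
    }

    fn physical_field_name(column: &str, agg: &str) -> String {
        // qualifiers are dropped from physical field names
        format!("{}({})", agg, column)
    }

    fn main() {
        assert_eq!(logical_name(Some("t"), "val", "COUNT"), "COUNT(t.val)");
        assert_eq!(physical_field_name("val", "COUNT"), "COUNT(val)");
    }
]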
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index 34eef538a3d2..7efbfccd5bfb 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -417,7 +417,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     alias
                         .as_ref()
                         .map(|a| a.name.value.as_str())
-                        .or(Some(&table_name)),
+                        .unwrap_or(&table_name),
                     provider,
                     None,
                 )?
diff --git a/datafusion/tests/custom_sources.rs b/datafusion/tests/custom_sources.rs
index a00dd6ac2821..3567c019eb94 100644
--- a/datafusion/tests/custom_sources.rs
+++ b/datafusion/tests/custom_sources.rs
@@ -27,7 +27,9 @@ use datafusion::{
 };
 
 use datafusion::execution::context::ExecutionContext;
-use datafusion::logical_plan::{col, Expr, LogicalPlan, LogicalPlanBuilder};
+use datafusion::logical_plan::{
+    col, Expr, LogicalPlan, LogicalPlanBuilder, UNNAMED_TABLE,
+};
 use datafusion::physical_plan::{
     ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream,
 };
@@ -181,8 +183,11 @@ async fn custom_source_dataframe() -> Result<()> {
         _ => panic!("expect optimized_plan to be projection"),
     }
 
-    let expected = "Projection: #c2\
-    \n  TableScan: projection=Some([1])";
+    let expected = format!(
+        "Projection: #{}.c2\
+        \n  TableScan: {} projection=Some([1])",
+        UNNAMED_TABLE, UNNAMED_TABLE
+    );
     assert_eq!(format!("{:?}", optimized_plan), expected);
 
     let physical_plan = ctx.create_physical_plan(&optimized_plan)?;

From 7f253c76a9ec1145ea30b6d8731fdb49d2dc056d Mon Sep 17 00:00:00 2001
From: Qingping Hou
Date: Sun, 16 May 2021 16:27:45 -0700
Subject: [PATCH 08/25] WIP: fix ballista

---
 ballista/rust/core/proto/ballista.proto       | 41 ++++++----
 .../core/src/serde/logical_plan/from_proto.rs | 67 +++++++++++------
 .../rust/core/src/serde/logical_plan/mod.rs   | 14 ++--
 .../core/src/serde/logical_plan/to_proto.rs   | 54 +++++++++----
 ballista/rust/core/src/serde/mod.rs           | 11 +++
 .../src/serde/physical_plan/from_proto.rs     | 23 +++++-
 .../rust/core/src/serde/physical_plan/mod.rs  | 32 ++++-----
 .../core/src/serde/physical_plan/to_proto.rs  | 17 +++--
 8 files changed, 181 insertions(+), 78 deletions(-)

diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 07419d09b7a9..873454516e6f 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -28,11 +28,29 @@ option java_outer_classname = "BallistaProto";
 // Ballista Logical Plan
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 
+message ColumnRelation {
+  string relation = 1;
+}
+
+message Column {
+  string name = 1;
+  ColumnRelation relation = 2;
+}
+
+message DfField{
+  Field field = 1;
+  ColumnRelation qualifier = 2;
+}
+
+message DfSchema {
+  repeated DfField columns = 1;
+}
+
 // logical expressions
message LogicalExprNode {
   oneof ExprType {
     // column references
-    string column_name = 1;
+    Column column = 1;
 
     // alias
     AliasNode alias = 2;
@@ -263,7 +281,7 @@ message CreateExternalTableNode{
   string location = 2;
   FileType file_type = 3;
   bool has_header = 4;
-  Schema schema = 5;
+  DfSchema schema = 5;
 }
 
 enum FileType{
@@ -277,11 +295,6 @@ message ExplainNode{
   bool verbose = 2;
 }
 
-message DfField{
-  string qualifier = 2;
-  Field field = 1;
-}
-
 message AggregateNode {
   LogicalPlanNode input = 1;
   repeated LogicalExprNode group_expr = 2;
@@ -299,8 +312,8 @@ message JoinNode {
   LogicalPlanNode left = 1;
   LogicalPlanNode right = 2;
   JoinType join_type = 3;
-  repeated string left_join_column = 4;
-  repeated string right_join_column = 5;
+  repeated Column left_join_column = 4;
+  repeated Column right_join_column = 5;
 }
 
 message LimitNode {
@@ -376,11 +389,15 @@ message HashJoinExecNode {
 }
 
-message JoinOn {
-  string left = 1;
-  string right = 2;
+message PhysicalColumn {
+  string name = 1;
+  uint32 index = 2;
 }
 
+message JoinOn {
+  PhysicalColumn left = 1;
+  PhysicalColumn right = 2;
+}
 
 message EmptyExecNode {
   bool produce_one_row = 1;
diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index 6987035394c6..9353e9aca75a 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -19,6 +19,7 @@
 
 use std::{
     convert::{From, TryInto},
+    sync::Arc,
     unimplemented,
 };
 
@@ -29,7 +30,8 @@ use crate::{convert_box_required, convert_required};
 use arrow::datatypes::{DataType, Field, Schema};
 use datafusion::logical_plan::{
     abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin,
-    sqrt, tan, trunc, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator,
+    sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinType, LogicalPlan,
+    LogicalPlanBuilder, Operator,
 };
 use datafusion::physical_plan::aggregates::AggregateFunction;
 use datafusion::physical_plan::csv::CsvReadOptions;
@@ -101,8 +103,8 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
                     .has_header(scan.has_header);
 
                 let mut projection = None;
-                if let Some(column_names) = &scan.projection {
-                    let column_indices = column_names
+                if let Some(columns) = &scan.projection {
+                    let column_indices = columns
                         .columns
                         .iter()
                         .map(|name| schema.index_of(name))
@@ -220,10 +222,10 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
                     .map_err(|e| e.into())
             }
             LogicalPlanType::Join(join) => {
-                let left_keys: Vec<&str> =
-                    join.left_join_column.iter().map(|i| i.as_str()).collect();
-                let right_keys: Vec<&str> =
-                    join.right_join_column.iter().map(|i| i.as_str()).collect();
+                let left_keys: Vec<Column> =
+                    join.left_join_column.iter().map(|i| i.into()).collect();
+                let right_keys: Vec<Column> =
+                    join.right_join_column.iter().map(|i| i.into()).collect();
                 let join_type =
                     protobuf::JoinType::from_i32(join.join_type).ok_or_else(|| {
                         proto_error(format!(
@@ -241,8 +243,8 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
                     .join(
                         &convert_box_required!(join.right)?,
                         join_type,
-                        &left_keys,
-                        &right_keys,
+                        left_keys,
+                        right_keys,
                    )?
                    .build()
                    .map_err(|e| e.into())
            }
        }
    }
}

-impl TryInto<DFSchema> for protobuf::Schema {
+impl From<&protobuf::Column> for Column {
+    fn from(c: &protobuf::Column) -> Column {
+        Column {
+            relation: c.relation.map(|r| r.relation),
+            name: c.name,
+        }
+    }
+}
+
+impl TryInto<DFSchema> for &protobuf::DfSchema {
     type Error = BallistaError;
-    fn try_into(self) -> Result<DFSchema, Self::Error> {
-        let schema: Schema = (&self).try_into()?;
-        schema.try_into().map_err(BallistaError::DataFusionError)
+
+    fn try_into(self) -> Result<DFSchema, Self::Error> {
+        let fields = self
+            .columns
+            .iter()
+            .map(|c| c.try_into())
+            .collect::<Result<Vec<DFField>, _>>()?;
+        Ok(DFSchema::new(fields)?)
     }
 }
 
-impl TryInto<DFSchemaRef> for protobuf::Schema {
+impl TryInto<DFSchemaRef> for protobuf::DfSchema {
     type Error = BallistaError;
+
     fn try_into(self) -> Result<DFSchemaRef, Self::Error> {
-        use datafusion::logical_plan::ToDFSchema;
-        let schema: Schema = (&self).try_into()?;
-        schema
-            .to_dfschema_ref()
-            .map_err(BallistaError::DataFusionError)
+        let dfschema: DFSchema = (&self).try_into()?;
+        Ok(Arc::new(dfschema))
+    }
+}
+
+impl TryInto<DFField> for &protobuf::DfField {
+    type Error = BallistaError;
+
+    fn try_into(self) -> Result<DFField, Self::Error> {
+        let field: Field = convert_required!(self.field)?;
+
+        Ok(match self.qualifier {
+            Some(q) => DFField::from_qualified(&q.relation, field),
+            None => DFField::from(field),
+        })
     }
 }
 
@@ -883,7 +910,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
                 op: from_proto_binary_op(&binary_expr.op)?,
                 right: Box::new(parse_required_expr(&binary_expr.r)?),
             }),
-            ExprType::ColumnName(column_name) => Ok(Expr::Column(column_name.to_owned())),
+            ExprType::Column(column) => Ok(Expr::Column(column.into())),
             ExprType::Literal(literal) => {
                 use datafusion::scalar::ScalarValue;
                 let scalar_value: datafusion::scalar::ScalarValue = literal.try_into()?;
diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs
index 48dd96c4d3f3..2684c3122189 100644
--- a/ballista/rust/core/src/serde/logical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/logical_plan/mod.rs
@@ -28,7 +28,7 @@ mod roundtrip_tests {
     use core::panic;
     use datafusion::physical_plan::functions::BuiltinScalarFunction::Sqrt;
     use datafusion::{
-        logical_plan::{Expr, LogicalPlan, LogicalPlanBuilder},
+        logical_plan::{col, Expr, LogicalPlan, LogicalPlanBuilder},
         physical_plan::csv::CsvReadOptions,
         prelude::*,
         scalar::ScalarValue,
@@ -63,10 +63,8 @@ mod roundtrip_tests {
 
         let test_batch_sizes = [usize::MIN, usize::MAX, 43256];
 
-        let test_expr: Vec<Expr> = vec![
-            Expr::Column("c1".to_string()) + Expr::Column("c2".to_string()),
-            Expr::Literal((4.0).into()),
-        ];
+        let test_expr: Vec<Expr> =
+            vec![col("c1") + col("c2"), Expr::Literal((4.0).into())];
 
         let schema = Schema::new(vec![
             Field::new("id", DataType::Int32, false),
@@ -719,7 +717,7 @@ mod roundtrip_tests {
             CsvReadOptions::new().schema(&schema).has_header(true),
             Some(vec![3, 4]),
         )
-        .and_then(|plan| plan.join(&scan_plan, JoinType::Inner, &["id"], &["id"]))
+        .and_then(|plan| plan.join(&scan_plan, JoinType::Inner, vec!["id"], vec!["id"]))
         .and_then(|plan| plan.build())
         .map_err(BallistaError::DataFusionError)?;
 
@@ -806,7 +804,7 @@ mod roundtrip_tests {
 
     #[test]
     fn roundtrip_is_null() -> Result<()> {
-        let test_expr = Expr::IsNull(Box::new(Expr::Column("id".into())));
+        let test_expr = Expr::IsNull(Box::new(col("id")));
 
         roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr);
 
@@ -816,7 +814,7 @@ mod roundtrip_tests {
 
     #[test]
     fn roundtrip_is_not_null() -> Result<()> {
Expr::IsNotNull(Box::new(Expr::Column("id".into()))); + let test_expr = Expr::IsNotNull(Box::new(col("id"))); roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr); diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 01b669d26446..b5b012d70203 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -29,7 +29,7 @@ use crate::serde::{protobuf, BallistaError}; use arrow::datatypes::{DataType, Schema}; use datafusion::datasource::CsvFile; -use datafusion::logical_plan::{Expr, JoinType, LogicalPlan}; +use datafusion::logical_plan::{Column, Expr, JoinType, LogicalPlan}; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::{datasource::parquet::ParquetTable, logical_plan::exprlist_to_fields}; use protobuf::{ @@ -810,8 +810,8 @@ impl TryInto for &LogicalPlan { JoinType::Right => protobuf::JoinType::Right, JoinType::Full => protobuf::JoinType::Full, }; - let left_join_column = on.iter().map(|on| on.0.to_owned()).collect(); - let right_join_column = on.iter().map(|on| on.1.to_owned()).collect(); + let left_join_column = on.iter().map(|on| on.0.into()).collect(); + let right_join_column = on.iter().map(|on| on.1.into()).collect(); Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( protobuf::JoinNode { @@ -902,13 +902,6 @@ impl TryInto for &LogicalPlan { schema: df_schema, } => { use datafusion::sql::parser::FileType; - let schema: Schema = df_schema.as_ref().clone().into(); - let pb_schema: protobuf::Schema = (&schema).try_into().map_err(|e| { - BallistaError::General(format!( - "Could not convert schema into protobuf: {:?}", - e - )) - })?; let pb_file_type: protobuf::FileType = match file_type { FileType::NdJson => protobuf::FileType::NdJson, @@ -923,7 +916,7 @@ impl TryInto for &LogicalPlan { location: location.clone(), file_type: pb_file_type as i32, has_header: *has_header, - schema: Some(pb_schema), + schema: Some(df_schema.into()), }, )), }) @@ -965,9 +958,9 @@ impl TryInto for &Expr { use datafusion::scalar::ScalarValue; use protobuf::scalar_value::Value; match self { - Expr::Column(name) => { + Expr::Column(c) => { let expr = protobuf::LogicalExprNode { - expr_type: Some(ExprType::ColumnName(name.clone())), + expr_type: Some(ExprType::Column(c.into())), }; Ok(expr) } @@ -1165,6 +1158,23 @@ impl TryInto for &Expr { } } +impl From for protobuf::Column { + fn from(c: Column) -> protobuf::Column { + protobuf::Column { + relation: c + .relation + .map(|relation| protobuf::ColumnRelation { relation }), + name: c.name, + } + } +} + +impl From<&Column> for protobuf::Column { + fn from(c: &Column) -> protobuf::Column { + c.clone().into() + } +} + #[allow(clippy::from_over_into)] impl Into for &Schema { fn into(self) -> protobuf::Schema { @@ -1178,6 +1188,24 @@ impl Into for &Schema { } } +impl From<&datafusion::logical_plan::DFField> for protobuf::DfField { + fn from(f: &datafusion::logical_plan::DFField) -> protobuf::DfField { + protobuf::DfField { + field: Some(f.field().into()), + qualifier: f.qualifier().map(|r| protobuf::ColumnRelation { + relation: r.to_string(), + }), + } + } +} + +impl From<&datafusion::logical_plan::DFSchemaRef> for protobuf::DfSchema { + fn from(s: &datafusion::logical_plan::DFSchemaRef) -> protobuf::DfSchema { + let columns = s.fields().iter().map(|f| f.into()).collect::>(); + protobuf::DfSchema { columns } + } +} + impl TryFrom<&arrow::datatypes::DataType> 
for protobuf::ScalarType { type Error = BallistaError; fn try_from(value: &arrow::datatypes::DataType) -> Result { diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index b96163999f39..58a0c18707ee 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -57,6 +57,17 @@ macro_rules! convert_required { }}; } +#[macro_export] +macro_rules! into_required { + ($PB:expr) => {{ + if let Some(field) = $PB.as_ref() { + Ok(field.into()) + } else { + Err(proto_error("Missing required field in protobuf")) + } + }}; +} + #[macro_export] macro_rules! convert_box_required { ($PB:expr) => {{ diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 9c35c9d88941..69b48600e8b2 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -27,7 +27,7 @@ use crate::serde::protobuf::repartition_exec_node::PartitionMethod; use crate::serde::protobuf::LogicalExprNode; use crate::serde::scheduler::PartitionLocation; use crate::serde::{proto_error, protobuf}; -use crate::{convert_box_required, convert_required}; +use crate::{convert_box_required, convert_required, into_required}; use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::catalog::catalog::{ @@ -284,11 +284,15 @@ impl TryInto> for &protobuf::PhysicalPlanNode { let left: Arc = convert_box_required!(hashjoin.left)?; let right: Arc = convert_box_required!(hashjoin.right)?; - let on: Vec<(String, String)> = hashjoin + let on: Vec<(Column, Column)> = hashjoin .on .iter() - .map(|col| (col.left.clone(), col.right.clone())) - .collect(); + .map(|col| { + let left = into_required!(col.left)?; + let right = into_required!(col.right)?; + Ok((left, right)) + }) + .collect::>()?; let join_type = protobuf::JoinType::from_i32(hashjoin.join_type) .ok_or_else(|| { proto_error(format!( @@ -305,7 +309,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { Ok(Arc::new(HashJoinExec::try_new( left, right, - &on, + on, &join_type, PartitionMode::CollectLeft, )?)) @@ -382,6 +386,15 @@ impl TryInto> for &protobuf::PhysicalPlanNode { } } +impl From<&protobuf::PhysicalColumn> for Column { + fn from(c: &protobuf::PhysicalColumn) -> Column { + Column { + index: c.index as usize, + name: c.name, + } + } +} + fn compile_expr( expr: &protobuf::LogicalExprNode, schema: &Schema, diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs index e7985cc84a9a..78826027dc62 100644 --- a/ballista/rust/core/src/serde/physical_plan/mod.rs +++ b/ballista/rust/core/src/serde/physical_plan/mod.rs @@ -27,7 +27,7 @@ mod roundtrip_tests { use datafusion::physical_plan::ColumnarValue; use datafusion::physical_plan::{ empty::EmptyExec, - expressions::{Avg, Column, PhysicalSortExpr}, + expressions::{col, Avg, Column, PhysicalSortExpr}, hash_aggregate::{AggregateMode, HashAggregateExec}, hash_join::HashJoinExec, limit::{GlobalLimitExec, LocalLimitExec}, @@ -79,36 +79,36 @@ mod roundtrip_tests { let field_a = Field::new("col", DataType::Int64, false); let schema_left = Schema::new(vec![field_a.clone()]); let schema_right = Schema::new(vec![field_a]); + let on = vec![( + Column::new("col", schema_left.index_of("col")?), + Column::new("col", schema_right.index_of("col")?), + )]; roundtrip_test(Arc::new(HashJoinExec::try_new( Arc::new(EmptyExec::new(false, Arc::new(schema_left))), Arc::new(EmptyExec::new(false, 
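The new into_required! macro above complements convert_required!: it unwraps an optional protobuf field and converts it with From instead of TryInto, so the conversion step itself cannot fail. A compilable miniature of the pattern; the error type and the sample types here are simplified stand-ins:

    #[derive(Debug)]
    struct ProtoError(String);

    fn proto_error(msg: &str) -> ProtoError {
        ProtoError(msg.to_string())
    }

    macro_rules! into_required {
        ($PB:expr) => {{
            if let Some(field) = $PB.as_ref() {
                Ok(field.into())
            } else {
                Err(proto_error("Missing required field in protobuf"))
            }
        }};
    }

    struct ProtoCol { name: String }

    #[derive(Debug)]
    struct Col { name: String }

    impl From<&ProtoCol> for Col {
        fn from(c: &ProtoCol) -> Col {
            Col { name: c.name.clone() }
        }
    }

    fn main() -> Result<(), ProtoError> {
        let maybe: Option<ProtoCol> = Some(ProtoCol { name: "id".into() });
        let col: Col = into_required!(maybe)?;
        assert_eq!(col.name, "id");
        Ok(())
    }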
Arc::new(schema_right))), - &[("col".to_string(), "col".to_string())], + on, &JoinType::Inner, PartitionMode::CollectLeft, )?)) } - fn col(name: &str) -> Arc { - Arc::new(Column::new(name)) - } - #[test] fn rountrip_hash_aggregate() -> Result<()> { use arrow::datatypes::{DataType, Field, Schema}; + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let groups: Vec<(Arc, String)> = - vec![(col("a"), "unused".to_string())]; + vec![(col("a", &schema)?, "unused".to_string())]; let aggregates: Vec> = vec![Arc::new(Avg::new( - col("b"), + col("b", &schema)?, "AVG(b)".to_string(), DataType::Float64, ))]; - let field_a = Field::new("a", DataType::Int64, false); - let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); - roundtrip_test(Arc::new(HashAggregateExec::try_new( AggregateMode::Final, groups.clone(), @@ -131,9 +131,9 @@ mod roundtrip_tests { let field_b = Field::new("b", DataType::Int64, false); let field_c = Field::new("c", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b, field_c])); - let not = Arc::new(NotExpr::new(col("a"))); + let not = Arc::new(NotExpr::new(col("a", &schema)?)); let in_list = Arc::new(InListExpr::new( - col("b"), + col("b", &schema)?, vec![ lit(ScalarValue::Int64(Some(1))), lit(ScalarValue::Int64(Some(2))), @@ -156,14 +156,14 @@ mod roundtrip_tests { let schema = Arc::new(Schema::new(vec![field_a, field_b])); let sort_exprs = vec![ PhysicalSortExpr { - expr: col("a"), + expr: col("a", &schema)?, options: SortOptions { descending: true, nulls_first: false, }, }, PhysicalSortExpr { - expr: col("b"), + expr: col("b", &schema)?, options: SortOptions { descending: false, nulls_first: true, diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index 8a5fd71083f7..d771845022d2 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -124,8 +124,14 @@ impl TryInto for Arc { .on() .iter() .map(|tuple| protobuf::JoinOn { - left: tuple.0.to_owned(), - right: tuple.1.to_owned(), + left: Some(protobuf::PhysicalColumn { + name: tuple.0.name().to_string(), + index: tuple.0.index() as u32, + }), + right: Some(protobuf::PhysicalColumn { + name: tuple.1.name().to_string(), + index: tuple.1.index() as u32, + }), }) .collect(); let join_type = match exec.join_type() { @@ -404,8 +410,11 @@ impl TryFrom> for protobuf::LogicalExprNode { if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::ColumnName( - expr.name().to_owned(), + expr_type: Some(protobuf::logical_expr_node::ExprType::Column( + protobuf::Column { + name: expr.name().to_owned(), + relation: None, + }, )), }) } else if let Some(expr) = expr.downcast_ref::() { From 5c413ddf65faea0282c3ffe6f40ab34cc493152e Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sun, 23 May 2021 19:30:59 -0700 Subject: [PATCH 09/25] separate logical and physical expressions in proto, fix ballista build --- ballista/rust/core/proto/ballista.proto | 120 ++++++- .../core/src/serde/logical_plan/from_proto.rs | 183 +--------- .../core/src/serde/logical_plan/to_proto.rs | 4 +- ballista/rust/core/src/serde/mod.rs | 182 ++++++++++ .../src/serde/physical_plan/from_proto.rs | 310 ++++++++++++----- 
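With physical columns now carrying an index as well as a name, the test helper col takes the schema so the index can be resolved up front; that is why the schema construction moves above the expressions in these hunks. A simplified model of that resolution, using stand-in types:

    #[derive(Debug, Clone, PartialEq)]
    struct Column { name: String, index: usize }

    impl Column {
        fn new(name: &str, index: usize) -> Self {
            Column { name: name.to_string(), index }
        }
    }

    struct Schema { field_names: Vec<String> }

    impl Schema {
        fn index_of(&self, name: &str) -> Result<usize, String> {
            self.field_names
                .iter()
                .position(|n| n == name)
                .ok_or_else(|| format!("no field named {}", name))
        }
    }

    fn col(name: &str, schema: &Schema) -> Result<Column, String> {
        Ok(Column::new(name, schema.index_of(name)?))
    }

    fn main() -> Result<(), String> {
        let schema = Schema { field_names: vec!["a".into(), "b".into()] };
        assert_eq!(col("b", &schema)?, Column::new("b", 1));
        Ok(())
    }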
.../core/src/serde/physical_plan/to_proto.rs | 151 ++++---- .../src/physical_plan/expressions/mod.rs | 2 +- datafusion/src/physical_plan/functions.rs | 329 +++++++++--------- 8 files changed, 773 insertions(+), 508 deletions(-) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 873454516e6f..efc8bbce2b9d 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -350,6 +350,107 @@ message PhysicalPlanNode { } } +// physical expressions +message PhysicalExprNode { + oneof ExprType { + // column references + PhysicalColumn column = 1; + + ScalarValue literal = 2; + + // binary expressions + PhysicalBinaryExprNode binary_expr = 3; + + // aggregate expressions + PhysicalAggregateExprNode aggregate_expr = 4; + + // null checks + PhysicalIsNull is_null_expr = 5; + PhysicalIsNotNull is_not_null_expr = 6; + PhysicalNot not_expr = 7; + + PhysicalCaseNode case_ = 8; + PhysicalCastNode cast = 9; + PhysicalSortExprNode sort = 10; + PhysicalNegativeNode negative = 11; + PhysicalInListNode in_list = 12; + PhysicalScalarFunctionNode scalar_function = 13; + PhysicalTryCastNode try_cast = 14; + } +} + +message PhysicalAggregateExprNode { + AggregateFunction aggr_function = 1; + PhysicalExprNode expr = 2; +} + +message PhysicalIsNull { + PhysicalExprNode expr = 1; +} + +message PhysicalIsNotNull { + PhysicalExprNode expr = 1; +} + +message PhysicalNot { + PhysicalExprNode expr = 1; +} + +message PhysicalAliasNode { + PhysicalExprNode expr = 1; + string alias = 2; +} + +message PhysicalBinaryExprNode { + PhysicalExprNode l = 1; + PhysicalExprNode r = 2; + string op = 3; +} + +message PhysicalSortExprNode { + PhysicalExprNode expr = 1; + bool asc = 2; + bool nulls_first = 3; +} + +message PhysicalWhenThen { + PhysicalExprNode when_expr = 1; + PhysicalExprNode then_expr = 2; +} + +message PhysicalInListNode { + PhysicalExprNode expr = 1; + repeated PhysicalExprNode list = 2; + bool negated = 3; +} + +message PhysicalCaseNode { + PhysicalExprNode expr = 1; + repeated PhysicalWhenThen when_then_expr = 2; + PhysicalExprNode else_expr = 3; +} + +message PhysicalScalarFunctionNode { + string name = 1; + ScalarFunction fun = 2; + repeated PhysicalExprNode args = 3; + ArrowType return_type = 4; +} + +message PhysicalTryCastNode { + PhysicalExprNode expr = 1; + ArrowType arrow_type = 2; +} + +message PhysicalCastNode { + PhysicalExprNode expr = 1; + ArrowType arrow_type = 2; +} + +message PhysicalNegativeNode { + PhysicalExprNode expr = 1; +} + message UnresolvedShuffleExecNode { repeated uint32 query_stage_ids = 1; Schema schema = 2; @@ -358,7 +459,7 @@ message UnresolvedShuffleExecNode { message FilterExecNode { PhysicalPlanNode input = 1; - LogicalExprNode expr = 2; + PhysicalExprNode expr = 2; } message ParquetScanExecNode { @@ -406,7 +507,7 @@ message EmptyExecNode { message ProjectionExecNode { PhysicalPlanNode input = 1; - repeated LogicalExprNode expr = 2; + repeated PhysicalExprNode expr = 2; repeated string expr_name = 3; } @@ -416,8 +517,8 @@ enum AggregateMode { } message HashAggregateExecNode { - repeated LogicalExprNode group_expr = 1; - repeated LogicalExprNode aggr_expr = 2; + repeated PhysicalExprNode group_expr = 1; + repeated PhysicalExprNode aggr_expr = 2; AggregateMode mode = 3; PhysicalPlanNode input = 4; repeated string group_expr_name = 5; @@ -443,7 +544,7 @@ message LocalLimitExecNode { message SortExecNode { PhysicalPlanNode input = 1; - repeated LogicalExprNode expr = 2; + repeated PhysicalExprNode 
expr = 2; } message CoalesceBatchesExecNode { @@ -455,11 +556,16 @@ message MergeExecNode { PhysicalPlanNode input = 1; } +message PhysicalHashRepartition { + repeated PhysicalExprNode hash_expr = 1; + uint64 partition_count = 2; +} + message RepartitionExecNode{ PhysicalPlanNode input = 1; oneof partition_method { uint64 round_robin = 2; - HashRepartition hash = 3; + PhysicalHashRepartition hash = 3; uint64 unknown = 4; } } @@ -736,7 +842,7 @@ message ScalarListValue{ message ScalarValue{ - oneof value{ + oneof value { bool bool_value = 1; string utf8_value = 2; string large_utf8_value = 3; diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 9353e9aca75a..b1fff46c7611 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -24,14 +24,14 @@ use std::{ }; use crate::error::BallistaError; -use crate::serde::{proto_error, protobuf}; +use crate::serde::{from_proto_binary_op, proto_error, protobuf}; use crate::{convert_box_required, convert_required}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::logical_plan::{ abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin, sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinType, LogicalPlan, - LogicalPlanBuilder, Operator, + LogicalPlanBuilder, }; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::csv::CsvReadOptions; @@ -255,6 +255,7 @@ impl TryInto for &protobuf::LogicalPlanNode { impl From<&protobuf::Column> for Column { fn from(c: &protobuf::Column) -> Column { + let c = c.clone(); Column { relation: c.relation.map(|r| r.relation), name: c.name, @@ -290,7 +291,7 @@ impl TryInto for &protobuf::DfField { fn try_into(self) -> Result { let field: Field = convert_required!(self.field)?; - Ok(match self.qualifier { + Ok(match &self.qualifier { Some(q) => DFField::from_qualified(&q.relation, field), None => DFField::from(field), }) @@ -349,151 +350,6 @@ impl TryInto for &protobuf::scalar_type::Datatype { } } -impl TryInto for &protobuf::arrow_type::ArrowTypeEnum { - type Error = BallistaError; - fn try_into(self) -> Result { - use arrow::datatypes::DataType; - use protobuf::arrow_type; - Ok(match self { - arrow_type::ArrowTypeEnum::None(_) => DataType::Null, - arrow_type::ArrowTypeEnum::Bool(_) => DataType::Boolean, - arrow_type::ArrowTypeEnum::Uint8(_) => DataType::UInt8, - arrow_type::ArrowTypeEnum::Int8(_) => DataType::Int8, - arrow_type::ArrowTypeEnum::Uint16(_) => DataType::UInt16, - arrow_type::ArrowTypeEnum::Int16(_) => DataType::Int16, - arrow_type::ArrowTypeEnum::Uint32(_) => DataType::UInt32, - arrow_type::ArrowTypeEnum::Int32(_) => DataType::Int32, - arrow_type::ArrowTypeEnum::Uint64(_) => DataType::UInt64, - arrow_type::ArrowTypeEnum::Int64(_) => DataType::Int64, - arrow_type::ArrowTypeEnum::Float16(_) => DataType::Float16, - arrow_type::ArrowTypeEnum::Float32(_) => DataType::Float32, - arrow_type::ArrowTypeEnum::Float64(_) => DataType::Float64, - arrow_type::ArrowTypeEnum::Utf8(_) => DataType::Utf8, - arrow_type::ArrowTypeEnum::LargeUtf8(_) => DataType::LargeUtf8, - arrow_type::ArrowTypeEnum::Binary(_) => DataType::Binary, - arrow_type::ArrowTypeEnum::FixedSizeBinary(size) => { - DataType::FixedSizeBinary(*size) - } - arrow_type::ArrowTypeEnum::LargeBinary(_) => DataType::LargeBinary, - arrow_type::ArrowTypeEnum::Date32(_) => DataType::Date32, - arrow_type::ArrowTypeEnum::Date64(_) => 
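The `let c = c.clone();` line added above is a borrow-checker fix for the From<&protobuf::Column> conversion introduced earlier in the series: the original body tried to move relation and name out of a shared reference. A minimal reproduction of the problem and the fix, with stand-in types:

    #[derive(Clone)]
    struct ProtoColumn { name: String }

    struct Column { name: String }

    impl From<&ProtoColumn> for Column {
        fn from(c: &ProtoColumn) -> Column {
            // `Column { name: c.name }` would fail to compile here:
            // it moves a field out of `&ProtoColumn`.
            let c = c.clone(); // own the data first, then move freely
            Column { name: c.name }
        }
    }

    fn main() {
        let proto = ProtoColumn { name: "id".into() };
        let col: Column = (&proto).into();
        assert_eq!(col.name, "id");
    }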
DataType::Date64, - arrow_type::ArrowTypeEnum::Duration(time_unit) => { - DataType::Duration(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?) - } - arrow_type::ArrowTypeEnum::Timestamp(protobuf::Timestamp { - time_unit, - timezone, - }) => DataType::Timestamp( - protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?, - match timezone.len() { - 0 => None, - _ => Some(timezone.to_owned()), - }, - ), - arrow_type::ArrowTypeEnum::Time32(time_unit) => { - DataType::Time32(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?) - } - arrow_type::ArrowTypeEnum::Time64(time_unit) => { - DataType::Time64(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?) - } - arrow_type::ArrowTypeEnum::Interval(interval_unit) => DataType::Interval( - protobuf::IntervalUnit::from_i32_to_arrow(*interval_unit)?, - ), - arrow_type::ArrowTypeEnum::Decimal(protobuf::Decimal { - whole, - fractional, - }) => DataType::Decimal(*whole as usize, *fractional as usize), - arrow_type::ArrowTypeEnum::List(list) => { - let list_type: &protobuf::Field = list - .as_ref() - .field_type - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))? - .as_ref(); - DataType::List(Box::new(list_type.try_into()?)) - } - arrow_type::ArrowTypeEnum::LargeList(list) => { - let list_type: &protobuf::Field = list - .as_ref() - .field_type - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))? - .as_ref(); - DataType::LargeList(Box::new(list_type.try_into()?)) - } - arrow_type::ArrowTypeEnum::FixedSizeList(list) => { - let list_type: &protobuf::Field = list - .as_ref() - .field_type - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))? 
- .as_ref(); - let list_size = list.list_size; - DataType::FixedSizeList(Box::new(list_type.try_into()?), list_size) - } - arrow_type::ArrowTypeEnum::Struct(strct) => DataType::Struct( - strct - .sub_field_types - .iter() - .map(|field| field.try_into()) - .collect::, _>>()?, - ), - arrow_type::ArrowTypeEnum::Union(union) => DataType::Union( - union - .union_types - .iter() - .map(|field| field.try_into()) - .collect::, _>>()?, - ), - arrow_type::ArrowTypeEnum::Dictionary(dict) => { - let pb_key_datatype = dict - .as_ref() - .key - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?; - let pb_value_datatype = dict - .as_ref() - .value - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?; - let key_datatype: DataType = pb_key_datatype.as_ref().try_into()?; - let value_datatype: DataType = pb_value_datatype.as_ref().try_into()?; - DataType::Dictionary(Box::new(key_datatype), Box::new(value_datatype)) - } - }) - } -} - -#[allow(clippy::from_over_into)] -impl Into for protobuf::PrimitiveScalarType { - fn into(self) -> arrow::datatypes::DataType { - use arrow::datatypes::DataType; - match self { - protobuf::PrimitiveScalarType::Bool => DataType::Boolean, - protobuf::PrimitiveScalarType::Uint8 => DataType::UInt8, - protobuf::PrimitiveScalarType::Int8 => DataType::Int8, - protobuf::PrimitiveScalarType::Uint16 => DataType::UInt16, - protobuf::PrimitiveScalarType::Int16 => DataType::Int16, - protobuf::PrimitiveScalarType::Uint32 => DataType::UInt32, - protobuf::PrimitiveScalarType::Int32 => DataType::Int32, - protobuf::PrimitiveScalarType::Uint64 => DataType::UInt64, - protobuf::PrimitiveScalarType::Int64 => DataType::Int64, - protobuf::PrimitiveScalarType::Float32 => DataType::Float32, - protobuf::PrimitiveScalarType::Float64 => DataType::Float64, - protobuf::PrimitiveScalarType::Utf8 => DataType::Utf8, - protobuf::PrimitiveScalarType::LargeUtf8 => DataType::LargeUtf8, - protobuf::PrimitiveScalarType::Date32 => DataType::Date32, - protobuf::PrimitiveScalarType::TimeMicrosecond => { - DataType::Time64(arrow::datatypes::TimeUnit::Microsecond) - } - protobuf::PrimitiveScalarType::TimeNanosecond => { - DataType::Time64(arrow::datatypes::TimeUnit::Nanosecond) - } - protobuf::PrimitiveScalarType::Null => DataType::Null, - } - } -} - //Does not typecheck lists fn typechecked_scalar_value_conversion( tested_type: &protobuf::scalar_value::Value, @@ -925,16 +781,9 @@ impl TryInto for &protobuf::LogicalExprNode { expr.aggr_function )) })?; - let fun = match aggr_function { - protobuf::AggregateFunction::Min => AggregateFunction::Min, - protobuf::AggregateFunction::Max => AggregateFunction::Max, - protobuf::AggregateFunction::Sum => AggregateFunction::Sum, - protobuf::AggregateFunction::Avg => AggregateFunction::Avg, - protobuf::AggregateFunction::Count => AggregateFunction::Count, - }; Ok(Expr::AggregateFunction { - fun, + fun: aggr_function.into(), args: vec![parse_required_expr(&expr.expr)?], distinct: false, //TODO }) @@ -1106,28 +955,6 @@ impl TryInto for &protobuf::LogicalExprNode { } } -fn from_proto_binary_op(op: &str) -> Result { - match op { - "And" => Ok(Operator::And), - "Or" => Ok(Operator::Or), - "Eq" => Ok(Operator::Eq), - "NotEq" => Ok(Operator::NotEq), - "LtEq" => Ok(Operator::LtEq), - "Lt" => Ok(Operator::Lt), - "Gt" => Ok(Operator::Gt), - "GtEq" => Ok(Operator::GtEq), - "Plus" => Ok(Operator::Plus), - "Minus" => Ok(Operator::Minus), - "Multiply" => 
Ok(Operator::Multiply), - "Divide" => Ok(Operator::Divide), - "Like" => Ok(Operator::Like), - other => Err(proto_error(format!( - "Unsupported binary operator '{:?}'", - other - ))), - } -} - impl TryInto for &protobuf::ScalarType { type Error = BallistaError; fn try_into(self) -> Result { diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index b5b012d70203..5654be6f498a 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -810,8 +810,8 @@ impl TryInto for &LogicalPlan { JoinType::Right => protobuf::JoinType::Right, JoinType::Full => protobuf::JoinType::Full, }; - let left_join_column = on.iter().map(|on| on.0.into()).collect(); - let right_join_column = on.iter().map(|on| on.1.into()).collect(); + let (left_join_column, right_join_column) = + on.iter().map(|(l, r)| (l.into(), r.into())).unzip(); Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( protobuf::JoinNode { diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index 58a0c18707ee..e92ca2cba885 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -20,6 +20,9 @@ use std::{convert::TryInto, io::Cursor}; +use datafusion::logical_plan::Operator; +use datafusion::physical_plan::aggregates::AggregateFunction; + use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction}; use prost::Message; @@ -78,3 +81,182 @@ macro_rules! convert_box_required { } }}; } + +pub(crate) fn from_proto_binary_op(op: &str) -> Result { + match op { + "And" => Ok(Operator::And), + "Or" => Ok(Operator::Or), + "Eq" => Ok(Operator::Eq), + "NotEq" => Ok(Operator::NotEq), + "LtEq" => Ok(Operator::LtEq), + "Lt" => Ok(Operator::Lt), + "Gt" => Ok(Operator::Gt), + "GtEq" => Ok(Operator::GtEq), + "Plus" => Ok(Operator::Plus), + "Minus" => Ok(Operator::Minus), + "Multiply" => Ok(Operator::Multiply), + "Divide" => Ok(Operator::Divide), + "Like" => Ok(Operator::Like), + other => Err(proto_error(format!( + "Unsupported binary operator '{:?}'", + other + ))), + } +} + +impl From for AggregateFunction { + fn from(agg_fun: protobuf::AggregateFunction) -> AggregateFunction { + match agg_fun { + protobuf::AggregateFunction::Min => AggregateFunction::Min, + protobuf::AggregateFunction::Max => AggregateFunction::Max, + protobuf::AggregateFunction::Sum => AggregateFunction::Sum, + protobuf::AggregateFunction::Avg => AggregateFunction::Avg, + protobuf::AggregateFunction::Count => AggregateFunction::Count, + } + } +} + +impl TryInto for &protobuf::arrow_type::ArrowTypeEnum { + type Error = BallistaError; + fn try_into(self) -> Result { + use arrow::datatypes::DataType; + use protobuf::arrow_type; + Ok(match self { + arrow_type::ArrowTypeEnum::None(_) => DataType::Null, + arrow_type::ArrowTypeEnum::Bool(_) => DataType::Boolean, + arrow_type::ArrowTypeEnum::Uint8(_) => DataType::UInt8, + arrow_type::ArrowTypeEnum::Int8(_) => DataType::Int8, + arrow_type::ArrowTypeEnum::Uint16(_) => DataType::UInt16, + arrow_type::ArrowTypeEnum::Int16(_) => DataType::Int16, + arrow_type::ArrowTypeEnum::Uint32(_) => DataType::UInt32, + arrow_type::ArrowTypeEnum::Int32(_) => DataType::Int32, + arrow_type::ArrowTypeEnum::Uint64(_) => DataType::UInt64, + arrow_type::ArrowTypeEnum::Int64(_) => DataType::Int64, + arrow_type::ArrowTypeEnum::Float16(_) => DataType::Float16, + arrow_type::ArrowTypeEnum::Float32(_) => DataType::Float32, + 
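from_proto_binary_op, moved here so the logical and physical deserializers can share it, parses the strings produced by format!("{:?}", op) on the serialization side, so the match arms must track the Debug names of Operator exactly. A self-contained round-trip sketch of that convention, with a reduced operator set:

    #[derive(Debug, Clone, Copy, PartialEq)]
    enum Operator { And, Or, Eq, Plus }

    fn to_proto_op(op: Operator) -> String {
        // Mirrors `format!("{:?}", expr.op())` in to_proto.rs.
        format!("{:?}", op)
    }

    fn from_proto_op(op: &str) -> Result<Operator, String> {
        match op {
            "And" => Ok(Operator::And),
            "Or" => Ok(Operator::Or),
            "Eq" => Ok(Operator::Eq),
            "Plus" => Ok(Operator::Plus),
            other => Err(format!("Unsupported binary operator '{}'", other)),
        }
    }

    fn main() {
        for op in [Operator::And, Operator::Or, Operator::Eq, Operator::Plus] {
            assert_eq!(from_proto_op(&to_proto_op(op)).unwrap(), op);
        }
    }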
arrow_type::ArrowTypeEnum::Float64(_) => DataType::Float64, + arrow_type::ArrowTypeEnum::Utf8(_) => DataType::Utf8, + arrow_type::ArrowTypeEnum::LargeUtf8(_) => DataType::LargeUtf8, + arrow_type::ArrowTypeEnum::Binary(_) => DataType::Binary, + arrow_type::ArrowTypeEnum::FixedSizeBinary(size) => { + DataType::FixedSizeBinary(*size) + } + arrow_type::ArrowTypeEnum::LargeBinary(_) => DataType::LargeBinary, + arrow_type::ArrowTypeEnum::Date32(_) => DataType::Date32, + arrow_type::ArrowTypeEnum::Date64(_) => DataType::Date64, + arrow_type::ArrowTypeEnum::Duration(time_unit) => { + DataType::Duration(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?) + } + arrow_type::ArrowTypeEnum::Timestamp(protobuf::Timestamp { + time_unit, + timezone, + }) => DataType::Timestamp( + protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?, + match timezone.len() { + 0 => None, + _ => Some(timezone.to_owned()), + }, + ), + arrow_type::ArrowTypeEnum::Time32(time_unit) => { + DataType::Time32(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?) + } + arrow_type::ArrowTypeEnum::Time64(time_unit) => { + DataType::Time64(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?) + } + arrow_type::ArrowTypeEnum::Interval(interval_unit) => DataType::Interval( + protobuf::IntervalUnit::from_i32_to_arrow(*interval_unit)?, + ), + arrow_type::ArrowTypeEnum::Decimal(protobuf::Decimal { + whole, + fractional, + }) => DataType::Decimal(*whole as usize, *fractional as usize), + arrow_type::ArrowTypeEnum::List(list) => { + let list_type: &protobuf::Field = list + .as_ref() + .field_type + .as_ref() + .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))? + .as_ref(); + DataType::List(Box::new(list_type.try_into()?)) + } + arrow_type::ArrowTypeEnum::LargeList(list) => { + let list_type: &protobuf::Field = list + .as_ref() + .field_type + .as_ref() + .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))? + .as_ref(); + DataType::LargeList(Box::new(list_type.try_into()?)) + } + arrow_type::ArrowTypeEnum::FixedSizeList(list) => { + let list_type: &protobuf::Field = list + .as_ref() + .field_type + .as_ref() + .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))? 
+ .as_ref(); + let list_size = list.list_size; + DataType::FixedSizeList(Box::new(list_type.try_into()?), list_size) + } + arrow_type::ArrowTypeEnum::Struct(strct) => DataType::Struct( + strct + .sub_field_types + .iter() + .map(|field| field.try_into()) + .collect::, _>>()?, + ), + arrow_type::ArrowTypeEnum::Union(union) => DataType::Union( + union + .union_types + .iter() + .map(|field| field.try_into()) + .collect::, _>>()?, + ), + arrow_type::ArrowTypeEnum::Dictionary(dict) => { + let pb_key_datatype = dict + .as_ref() + .key + .as_ref() + .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?; + let pb_value_datatype = dict + .as_ref() + .value + .as_ref() + .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?; + let key_datatype: DataType = pb_key_datatype.as_ref().try_into()?; + let value_datatype: DataType = pb_value_datatype.as_ref().try_into()?; + DataType::Dictionary(Box::new(key_datatype), Box::new(value_datatype)) + } + }) + } +} + +#[allow(clippy::from_over_into)] +impl Into for protobuf::PrimitiveScalarType { + fn into(self) -> arrow::datatypes::DataType { + use arrow::datatypes::DataType; + match self { + protobuf::PrimitiveScalarType::Bool => DataType::Boolean, + protobuf::PrimitiveScalarType::Uint8 => DataType::UInt8, + protobuf::PrimitiveScalarType::Int8 => DataType::Int8, + protobuf::PrimitiveScalarType::Uint16 => DataType::UInt16, + protobuf::PrimitiveScalarType::Int16 => DataType::Int16, + protobuf::PrimitiveScalarType::Uint32 => DataType::UInt32, + protobuf::PrimitiveScalarType::Int32 => DataType::Int32, + protobuf::PrimitiveScalarType::Uint64 => DataType::UInt64, + protobuf::PrimitiveScalarType::Int64 => DataType::Int64, + protobuf::PrimitiveScalarType::Float32 => DataType::Float32, + protobuf::PrimitiveScalarType::Float64 => DataType::Float64, + protobuf::PrimitiveScalarType::Utf8 => DataType::Utf8, + protobuf::PrimitiveScalarType::LargeUtf8 => DataType::LargeUtf8, + protobuf::PrimitiveScalarType::Date32 => DataType::Date32, + protobuf::PrimitiveScalarType::TimeMicrosecond => { + DataType::Time64(arrow::datatypes::TimeUnit::Microsecond) + } + protobuf::PrimitiveScalarType::TimeNanosecond => { + DataType::Time64(arrow::datatypes::TimeUnit::Nanosecond) + } + protobuf::PrimitiveScalarType::Null => DataType::Null, + } + } +} diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 69b48600e8b2..299b9db4acd9 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -18,15 +18,14 @@ //! Serde code to convert from protocol buffers to Rust data structures. 
use std::collections::HashMap; -use std::convert::TryInto; +use std::convert::{TryFrom, TryInto}; use std::sync::Arc; use crate::error::BallistaError; use crate::execution_plans::{ShuffleReaderExec, UnresolvedShuffleExec}; use crate::serde::protobuf::repartition_exec_node::PartitionMethod; -use crate::serde::protobuf::LogicalExprNode; use crate::serde::scheduler::PartitionLocation; -use crate::serde::{proto_error, protobuf}; +use crate::serde::{from_proto_binary_op, proto_error, protobuf}; use crate::{convert_box_required, convert_required, into_required}; use arrow::datatypes::{DataType, Schema, SchemaRef}; @@ -36,9 +35,7 @@ use datafusion::catalog::catalog::{ use datafusion::execution::context::{ ExecutionConfig, ExecutionContextState, ExecutionProps, }; -use datafusion::logical_plan::{DFSchema, Expr}; use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunction}; -use datafusion::physical_plan::expressions::col; use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; use datafusion::physical_plan::hash_join::PartitionMode; use datafusion::physical_plan::merge::MergeExec; @@ -47,8 +44,13 @@ use datafusion::physical_plan::{ coalesce_batches::CoalesceBatchesExec, csv::CsvExec, empty::EmptyExec, - expressions::{Avg, Column, PhysicalSortExpr}, + expressions::{ + col, Avg, BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, + IsNullExpr, Literal, NegativeExpr, NotExpr, PhysicalSortExpr, TryCastExpr, + DEFAULT_DATAFUSION_CAST_OPTIONS, + }, filter::FilterExec, + functions::{self, BuiltinScalarFunction, ScalarFunctionExpr}, hash_join::HashJoinExec, hash_utils::JoinType, limit::{GlobalLimitExec, LocalLimitExec}, @@ -61,7 +63,7 @@ use datafusion::physical_plan::{ use datafusion::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr}; use datafusion::prelude::CsvReadOptions; use log::debug; -use protobuf::logical_expr_node::ExprType; +use protobuf::physical_expr_node::ExprType; use protobuf::physical_plan_node::PhysicalPlanType; impl TryInto> for &protobuf::PhysicalPlanNode { @@ -82,23 +84,23 @@ impl TryInto> for &protobuf::PhysicalPlanNode { .expr .iter() .zip(projection.expr_name.iter()) - .map(|(expr, name)| { - compile_expr(expr, &input.schema()).map(|e| (e, name.to_string())) - }) - .collect::, _>>()?; + .map(|(expr, name)| Ok((expr.try_into()?, name.to_string()))) + .collect::, String)>, Self::Error>>( + )?; Ok(Arc::new(ProjectionExec::try_new(exprs, input)?)) } PhysicalPlanType::Filter(filter) => { let input: Arc = convert_box_required!(filter.input)?; - let predicate = compile_expr( - filter.expr.as_ref().ok_or_else(|| { + let predicate = filter + .expr + .as_ref() + .ok_or_else(|| { BallistaError::General( "filter (FilterExecNode) in PhysicalPlanNode is missing." .to_owned(), ) - })?, - &input.schema(), - )?; + })? 
+ .try_into()?; Ok(Arc::new(FilterExec::try_new(predicate, input)?)) } PhysicalPlanType::CsvScan(scan) => { @@ -149,7 +151,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { let expr = hash_part .hash_expr .iter() - .map(|e| compile_expr(e, &input.schema())) + .map(|e| e.try_into()) .collect::>, _>>()?; Ok(Arc::new(RepartitionExec::try_new( @@ -208,29 +210,10 @@ impl TryInto> for &protobuf::PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - compile_expr(expr, &input.schema()).map(|e| (e, name.to_string())) + expr.try_into().map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; - let logical_agg_expr: Vec<(Expr, String)> = hash_agg - .aggr_expr - .iter() - .zip(hash_agg.aggr_expr_name.iter()) - .map(|(expr, name)| expr.try_into().map(|expr| (expr, name.clone()))) - .collect::, _>>()?; - - let df_planner = DefaultPhysicalPlanner::default(); - let catalog_list = - Arc::new(MemoryCatalogList::new()) as Arc; - let ctx_state = ExecutionContextState { - catalog_list, - scalar_functions: Default::default(), - var_provider: Default::default(), - aggregate_functions: Default::default(), - config: ExecutionConfig::new(), - execution_props: ExecutionProps::new(), - }; - let input_schema = hash_agg .input_schema .as_ref() @@ -243,35 +226,46 @@ impl TryInto> for &protobuf::PhysicalPlanNode { let physical_schema: SchemaRef = SchemaRef::new((&input_schema).try_into()?); - let mut physical_aggr_expr = vec![]; + let physical_aggr_expr: Vec> = hash_agg + .aggr_expr + .iter() + .zip(hash_agg.aggr_expr_name.iter()) + .map(|(expr, name)| { + let expr_type = expr.expr_type.as_ref().ok_or_else(|| { + proto_error("Unexpected empty aggregate physical expression") + })?; + + match expr_type { + ExprType::AggregateExpr(agg_node) => { + let aggr_function = + protobuf::AggregateFunction::from_i32( + agg_node.aggr_function, + ) + .ok_or_else( + || { + proto_error(format!( + "Received an unknown aggregate function: {}", + agg_node.aggr_function + )) + }, + )?; - for (expr, name) in &logical_agg_expr { - match expr { - Expr::AggregateFunction { fun, args, .. } => { - let arg = df_planner - .create_physical_expr( - &args[0], + Ok(create_aggregate_expr( + &aggr_function.into(), + false, + &[convert_box_required!(agg_node.expr)?], &physical_schema, - &ctx_state, - ) - .map_err(|e| { - BallistaError::General(format!("{:?}", e)) - })?; - physical_aggr_expr.push(create_aggregate_expr( - &fun, - false, - &[arg], - &physical_schema, - name.to_string(), - )?); + name.to_string(), + )?) + } + _ => Err(BallistaError::General( + "Invalid aggregate expression for HashAggregateExec" + .to_string(), + )), } - _ => { - return Err(BallistaError::General( - "Invalid expression for HashAggregateExec".to_string(), - )) - } - } - } + }) + .collect::, _>>()?; + Ok(Arc::new(HashAggregateExec::try_new( agg_mode, group, @@ -292,7 +286,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { let right = into_required!(col.right)?; Ok((left, right)) }) - .collect::>()?; + .collect::>()?; let join_type = protobuf::JoinType::from_i32(hashjoin.join_type) .ok_or_else(|| { proto_error(format!( @@ -341,7 +335,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { self )) })?; - if let protobuf::logical_expr_node::ExprType::Sort(sort_expr) = expr { + if let protobuf::physical_expr_node::ExprType::Sort(sort_expr) = expr { let expr = sort_expr .expr .as_ref() @@ -353,7 +347,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { })? 
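The rewrite above drops the detour through the logical planner: aggregate expressions are rebuilt by matching the serialized physical node and calling create_aggregate_expr directly. A compilable miniature of that dispatch, using stand-in types where the real code builds Arc<dyn AggregateExpr> values:

    #[allow(dead_code)] // mirror the full variant set even though the demo uses one
    #[derive(Debug)]
    enum AggregateFunction { Min, Max, Sum, Avg, Count }

    enum PhysicalExprNode {
        Column { index: usize },
        AggregateExpr { fun: AggregateFunction, arg: Box<PhysicalExprNode> },
    }

    fn build_aggregate(node: &PhysicalExprNode, name: &str) -> Result<String, String> {
        match node {
            PhysicalExprNode::AggregateExpr { fun, arg } => {
                let input = match arg.as_ref() {
                    PhysicalExprNode::Column { index } => format!("col#{}", index),
                    _ => return Err("unsupported aggregate input".into()),
                };
                Ok(format!("{:?}({}) AS {}", fun, input, name))
            }
            _ => Err("Invalid aggregate expression for HashAggregateExec".into()),
        }
    }

    fn main() -> Result<(), String> {
        let node = PhysicalExprNode::AggregateExpr {
            fun: AggregateFunction::Avg,
            arg: Box::new(PhysicalExprNode::Column { index: 1 }),
        };
        assert_eq!(build_aggregate(&node, "AVG(b)")?, "Avg(col#1) AS AVG(b)");
        Ok(())
    }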
.as_ref(); Ok(PhysicalSortExpr { - expr: compile_expr(expr, &input.schema())?, + expr: expr.try_into()?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -388,29 +382,171 @@ impl TryInto> for &protobuf::PhysicalPlanNode { impl From<&protobuf::PhysicalColumn> for Column { fn from(c: &protobuf::PhysicalColumn) -> Column { - Column { - index: c.index as usize, - name: c.name, + Column::new(&c.name, c.index as usize) + } +} + +impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { + fn from(f: &protobuf::ScalarFunction) -> BuiltinScalarFunction { + use protobuf::ScalarFunction; + match f { + ScalarFunction::Sqrt => BuiltinScalarFunction::Sqrt, + ScalarFunction::Sin => BuiltinScalarFunction::Sin, + ScalarFunction::Cos => BuiltinScalarFunction::Cos, + ScalarFunction::Tan => BuiltinScalarFunction::Tan, + ScalarFunction::Asin => BuiltinScalarFunction::Asin, + ScalarFunction::Acos => BuiltinScalarFunction::Acos, + ScalarFunction::Atan => BuiltinScalarFunction::Atan, + ScalarFunction::Exp => BuiltinScalarFunction::Exp, + ScalarFunction::Log => BuiltinScalarFunction::Log, + ScalarFunction::Log2 => BuiltinScalarFunction::Log2, + ScalarFunction::Log10 => BuiltinScalarFunction::Log10, + ScalarFunction::Floor => BuiltinScalarFunction::Floor, + ScalarFunction::Ceil => BuiltinScalarFunction::Ceil, + ScalarFunction::Round => BuiltinScalarFunction::Round, + ScalarFunction::Trunc => BuiltinScalarFunction::Trunc, + ScalarFunction::Abs => BuiltinScalarFunction::Abs, + ScalarFunction::Signum => BuiltinScalarFunction::Signum, + ScalarFunction::Octetlength => BuiltinScalarFunction::OctetLength, + ScalarFunction::Concat => BuiltinScalarFunction::Concat, + ScalarFunction::Lower => BuiltinScalarFunction::Lower, + ScalarFunction::Upper => BuiltinScalarFunction::Upper, + ScalarFunction::Trim => BuiltinScalarFunction::Trim, + ScalarFunction::Ltrim => BuiltinScalarFunction::Ltrim, + ScalarFunction::Rtrim => BuiltinScalarFunction::Rtrim, + ScalarFunction::Totimestamp => BuiltinScalarFunction::ToTimestamp, + ScalarFunction::Array => BuiltinScalarFunction::Array, + ScalarFunction::Nullif => BuiltinScalarFunction::NullIf, + ScalarFunction::Datetrunc => BuiltinScalarFunction::DateTrunc, + ScalarFunction::Md5 => BuiltinScalarFunction::MD5, + ScalarFunction::Sha224 => BuiltinScalarFunction::SHA224, + ScalarFunction::Sha256 => BuiltinScalarFunction::SHA256, + ScalarFunction::Sha384 => BuiltinScalarFunction::SHA384, + ScalarFunction::Sha512 => BuiltinScalarFunction::SHA512, + ScalarFunction::Ln => BuiltinScalarFunction::Ln, } } } -fn compile_expr( - expr: &protobuf::LogicalExprNode, - schema: &Schema, -) -> Result, BallistaError> { - let df_planner = DefaultPhysicalPlanner::default(); - let catalog_list = Arc::new(MemoryCatalogList::new()) as Arc; - let state = ExecutionContextState { - catalog_list, - scalar_functions: HashMap::new(), - var_provider: HashMap::new(), - aggregate_functions: HashMap::new(), - config: ExecutionConfig::new(), - execution_props: ExecutionProps::new(), - }; - let expr: Expr = expr.try_into()?; - df_planner - .create_physical_expr(&expr, schema, &state) - .map_err(|e| BallistaError::General(format!("{:?}", e))) +impl TryFrom<&protobuf::PhysicalExprNode> for Arc { + type Error = BallistaError; + + fn try_from(expr: &protobuf::PhysicalExprNode) -> Result { + let expr_type = expr + .expr_type + .as_ref() + .ok_or_else(|| proto_error("Unexpected empty physical expression"))?; + + let pexpr: Arc = match expr_type { + ExprType::Column(c) => { + let pcol: 
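The TryFrom impl beginning here is the heart of the new physical deserializer: a recursive conversion over the boxed protobuf expression tree, with explicit errors for node kinds (aggregates, sorts) that cannot stand alone as physical expressions. A compilable miniature of the recursion, with stand-in types and a tiny operator set:

    use std::convert::TryFrom;

    enum ProtoExpr {
        Column { name: String, index: u32 },
        Not(Box<ProtoExpr>),
        Binary { l: Box<ProtoExpr>, op: String, r: Box<ProtoExpr> },
    }

    #[derive(Debug)]
    enum PhysExpr {
        Column { name: String, index: usize },
        Not(Box<PhysExpr>),
        Binary { l: Box<PhysExpr>, op: String, r: Box<PhysExpr> },
    }

    impl TryFrom<&ProtoExpr> for PhysExpr {
        type Error = String;

        fn try_from(e: &ProtoExpr) -> Result<Self, String> {
            Ok(match e {
                ProtoExpr::Column { name, index } => PhysExpr::Column {
                    name: name.clone(),
                    index: *index as usize,
                },
                // Recurse through boxed children, propagating errors with `?`.
                ProtoExpr::Not(inner) => {
                    PhysExpr::Not(Box::new(PhysExpr::try_from(inner.as_ref())?))
                }
                ProtoExpr::Binary { l, op, r } => match op.as_str() {
                    "And" | "Or" | "Eq" => PhysExpr::Binary {
                        l: Box::new(PhysExpr::try_from(l.as_ref())?),
                        op: op.clone(),
                        r: Box::new(PhysExpr::try_from(r.as_ref())?),
                    },
                    other => return Err(format!("Unsupported binary operator '{}'", other)),
                },
            })
        }
    }

    fn main() -> Result<(), String> {
        let proto = ProtoExpr::Not(Box::new(ProtoExpr::Binary {
            l: Box::new(ProtoExpr::Column { name: "a".into(), index: 0 }),
            op: "Eq".into(),
            r: Box::new(ProtoExpr::Column { name: "b".into(), index: 1 }),
        }));
        let phys = PhysExpr::try_from(&proto)?;
        println!("{:?}", phys);
        Ok(())
    }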
Column = c.into(); + Arc::new(pcol) + } + ExprType::Literal(scalar) => { + Arc::new(Literal::new(convert_required!(scalar.value)?)) + } + ExprType::BinaryExpr(binary_expr) => Arc::new(BinaryExpr::new( + convert_box_required!(&binary_expr.l)?, + from_proto_binary_op(&binary_expr.op)?, + convert_box_required!(&binary_expr.r)?, + )), + ExprType::AggregateExpr(_) => { + return Err(BallistaError::General( + "Cannot convert aggregate expr node to physical expression" + .to_owned(), + )); + } + ExprType::Sort(_) => { + return Err(BallistaError::General( + "Cannot convert sort expr node to physical expression".to_owned(), + )); + } + ExprType::IsNullExpr(e) => { + Arc::new(IsNullExpr::new(convert_box_required!(e.expr)?)) + } + ExprType::IsNotNullExpr(e) => { + Arc::new(IsNotNullExpr::new(convert_box_required!(e.expr)?)) + } + ExprType::NotExpr(e) => { + Arc::new(NotExpr::new(convert_box_required!(e.expr)?)) + } + ExprType::Negative(e) => { + Arc::new(NegativeExpr::new(convert_box_required!(e.expr)?)) + } + ExprType::InList(e) => Arc::new(InListExpr::new( + convert_box_required!(e.expr)?, + e.list + .iter() + .map(|x| x.try_into()) + .collect::, _>>()?, + e.negated, + )), + ExprType::Case(e) => Arc::new(CaseExpr::try_new( + e.expr.as_ref().map(|e| e.as_ref().try_into()).transpose()?, + e.when_then_expr + .iter() + .map(|e| { + Ok(( + convert_required!(e.when_expr)?, + convert_required!(e.then_expr)?, + )) + }) + .collect::, BallistaError>>()? + .as_slice(), + e.else_expr + .as_ref() + .map(|e| e.as_ref().try_into()) + .transpose()?, + )?), + ExprType::Cast(e) => Arc::new(CastExpr::new( + convert_box_required!(e.expr)?, + convert_required!(e.arrow_type)?, + DEFAULT_DATAFUSION_CAST_OPTIONS, + )), + ExprType::TryCast(e) => Arc::new(TryCastExpr::new( + convert_box_required!(e.expr)?, + convert_required!(e.arrow_type)?, + )), + ExprType::ScalarFunction(e) => { + let scalar_function = protobuf::ScalarFunction::from_i32(e.fun) + .ok_or_else(|| { + proto_error(format!( + "Received an unknown scalar function: {}", + e.fun, + )) + })?; + + let args = e + .args + .iter() + .map(|x| x.try_into()) + .collect::, _>>()?; + + let catalog_list = + Arc::new(MemoryCatalogList::new()) as Arc; + let ctx_state = ExecutionContextState { + catalog_list, + scalar_functions: Default::default(), + var_provider: Default::default(), + aggregate_functions: Default::default(), + config: ExecutionConfig::new(), + execution_props: ExecutionProps::new(), + }; + + let fun_expr = functions::create_physical_fun( + &(&scalar_function).into(), + &ctx_state, + )?; + + Arc::new(ScalarFunctionExpr::new( + &e.name, + fun_expr, + args, + &convert_required!(e.return_type)?, + )) + } + }; + + Ok(pexpr) + } } diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index d771845022d2..5a01e4d1559e 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -297,7 +297,7 @@ impl TryInto for Arc { let pb_partition_method = match exec.partitioning() { Partitioning::Hash(exprs, partition_count) => { - PartitionMethod::Hash(protobuf::HashRepartition { + PartitionMethod::Hash(protobuf::PhysicalHashRepartition { hash_expr: exprs .iter() .map(|expr| expr.clone().try_into()) @@ -327,13 +327,13 @@ impl TryInto for Arc { .expr() .iter() .map(|expr| { - let sort_expr = Box::new(protobuf::SortExprNode { + let sort_expr = Box::new(protobuf::PhysicalSortExprNode { expr: Some(Box::new(expr.expr.to_owned().try_into()?)), 
asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::Sort( + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Sort( sort_expr, )), }) @@ -370,10 +370,10 @@ impl TryInto for Arc { } } -impl TryInto for Arc { +impl TryInto for Arc { type Error = BallistaError; - fn try_into(self) -> Result { + fn try_into(self) -> Result { let aggr_function = if self.as_any().downcast_ref::().is_some() { Ok(protobuf::AggregateFunction::Avg.into()) } else if self.as_any().downcast_ref::().is_some() { @@ -386,14 +386,14 @@ impl TryInto for Arc { self ))) }?; - let expressions: Vec = self + let expressions: Vec = self .expressions() .iter() .map(|e| e.clone().try_into()) .collect::, BallistaError>>()?; - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::AggregateExpr( - Box::new(protobuf::AggregateExprNode { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( + Box::new(protobuf::PhysicalAggregateExprNode { aggr_function, expr: Some(Box::new(expressions[0].clone())), }), @@ -402,93 +402,100 @@ impl TryInto for Arc { } } -impl TryFrom> for protobuf::LogicalExprNode { +impl TryFrom> for protobuf::PhysicalExprNode { type Error = BallistaError; fn try_from(value: Arc) -> Result { let expr = value.as_any(); if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::Column( - protobuf::Column { - name: expr.name().to_owned(), - relation: None, + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: expr.name().to_string(), + index: expr.index() as u32, }, )), }) } else if let Some(expr) = expr.downcast_ref::() { - let binary_expr = Box::new(protobuf::BinaryExprNode { + let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { l: Some(Box::new(expr.left().to_owned().try_into()?)), r: Some(Box::new(expr.right().to_owned().try_into()?)), op: format!("{:?}", expr.op()), }); - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::BinaryExpr( + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( binary_expr, )), }) } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::Case(Box::new( - protobuf::CaseNode { - expr: expr - .expr() - .as_ref() - .map(|exp| exp.clone().try_into().map(Box::new)) - .transpose()?, - when_then_expr: expr - .when_then_expr() - .iter() - .map(|(when_expr, then_expr)| { - try_parse_when_then_expr(when_expr, then_expr) - }) - .collect::, Self::Error>>()?, - else_expr: expr - .else_expr() - .map(|a| a.clone().try_into().map(Box::new)) - .transpose()?, - }, - ))), + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::Case( + Box::new( + protobuf::PhysicalCaseNode { + expr: expr + .expr() + .as_ref() + .map(|exp| exp.clone().try_into().map(Box::new)) + .transpose()?, + when_then_expr: expr + .when_then_expr() + .iter() + .map(|(when_expr, then_expr)| { + try_parse_when_then_expr(when_expr, then_expr) + }) + .collect::, + Self::Error, + >>()?, + else_expr: expr + .else_expr() + .map(|a| a.clone().try_into().map(Box::new)) + .transpose()?, + }, + ), + ), + ), }) } else if let Some(expr) = 
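Serialization runs in the opposite direction: the concrete expression type must be recovered from an Arc<dyn PhysicalExpr> by downcasting through Any, one else-if arm per supported type, exactly as the chain here does. A stand-alone illustration of that dispatch pattern with a simplified trait and types:

    use std::any::Any;
    use std::sync::Arc;

    trait PhysicalExpr: Any {
        fn as_any(&self) -> &dyn Any;
    }

    struct Column { name: String }
    impl PhysicalExpr for Column {
        fn as_any(&self) -> &dyn Any { self }
    }

    struct Literal { value: i64 }
    impl PhysicalExpr for Literal {
        fn as_any(&self) -> &dyn Any { self }
    }

    fn serialize(expr: Arc<dyn PhysicalExpr>) -> Result<String, String> {
        let any = expr.as_any();
        if let Some(c) = any.downcast_ref::<Column>() {
            Ok(format!("column({})", c.name))
        } else if let Some(l) = any.downcast_ref::<Literal>() {
            Ok(format!("literal({})", l.value))
        } else {
            // Mirrors the fall-through error for unsupported expressions.
            Err("unsupported physical expression".to_string())
        }
    }

    fn main() -> Result<(), String> {
        assert_eq!(serialize(Arc::new(Column { name: "a".into() }))?, "column(a)");
        assert_eq!(serialize(Arc::new(Literal { value: 4 }))?, "literal(4)");
        Ok(())
    }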
expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::NotExpr( - Box::new(protobuf::Not { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr( + Box::new(protobuf::PhysicalNot { expr: Some(Box::new(expr.arg().to_owned().try_into()?)), }), )), }) } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::IsNullExpr( - Box::new(protobuf::IsNull { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( + Box::new(protobuf::PhysicalIsNull { expr: Some(Box::new(expr.arg().to_owned().try_into()?)), }), )), }) } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::IsNotNullExpr( - Box::new(protobuf::IsNotNull { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( + Box::new(protobuf::PhysicalIsNotNull { expr: Some(Box::new(expr.arg().to_owned().try_into()?)), }), )), }) } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { + Ok(protobuf::PhysicalExprNode { expr_type: Some( - protobuf::logical_expr_node::ExprType::InList( + protobuf::physical_expr_node::ExprType::InList( Box::new( - protobuf::InListNode { + protobuf::PhysicalInListNode { expr: Some(Box::new(expr.expr().to_owned().try_into()?)), list: expr .list() .iter() .map(|a| a.clone().try_into()) .collect::, + Vec, Self::Error, >>()?, negated: expr.negated(), @@ -498,32 +505,32 @@ impl TryFrom> for protobuf::LogicalExprNode { ), }) } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::Negative( - Box::new(protobuf::NegativeNode { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Negative( + Box::new(protobuf::PhysicalNegativeNode { expr: Some(Box::new(expr.arg().to_owned().try_into()?)), }), )), }) } else if let Some(lit) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::Literal( + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( lit.value().try_into()?, )), }) } else if let Some(cast) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::Cast(Box::new( - protobuf::CastNode { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( + protobuf::PhysicalCastNode { expr: Some(Box::new(cast.expr().clone().try_into()?)), arrow_type: Some(cast.cast_type().into()), }, ))), }) } else if let Some(cast) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::TryCast( - Box::new(protobuf::TryCastNode { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast( + Box::new(protobuf::PhysicalTryCastNode { expr: Some(Box::new(cast.expr().clone().try_into()?)), arrow_type: Some(cast.cast_type().into()), }), @@ -533,16 +540,18 @@ impl TryFrom> for protobuf::LogicalExprNode { let fun: BuiltinScalarFunction = BuiltinScalarFunction::from_str(expr.name())?; let fun: protobuf::ScalarFunction = (&fun).try_into()?; - let expr: Vec = expr + let args: Vec = expr .args() .iter() .map(|e| e.to_owned().try_into()) 
.collect::, _>>()?; - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::ScalarFunction( - protobuf::ScalarFunctionNode { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarFunction( + protobuf::PhysicalScalarFunctionNode { + name: expr.name().to_string(), fun: fun.into(), - expr, + args, + return_type: Some(expr.return_type().into()), }, )), }) @@ -558,8 +567,8 @@ impl TryFrom> for protobuf::LogicalExprNode { fn try_parse_when_then_expr( when_expr: &Arc, then_expr: &Arc, -) -> Result { - Ok(protobuf::WhenThen { +) -> Result { + Ok(protobuf::PhysicalWhenThen { when_expr: Some(when_expr.clone().try_into()?), then_expr: Some(then_expr.clone().try_into()?), }) diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 3d4d81efe35f..6f7669bafaa9 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -47,7 +47,7 @@ mod try_cast; pub use average::{avg_return_type, Avg, AvgAccumulator}; pub use binary::{binary, binary_operator_data_type, BinaryExpr}; pub use case::{case, CaseExpr}; -pub use cast::{cast, cast_with_options, CastExpr}; +pub use cast::{cast, cast_with_options, CastExpr, DEFAULT_DATAFUSION_CAST_OPTIONS}; pub use column::{col, Column}; pub use count::Count; pub use in_list::{in_list, InListExpr}; diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 16e9b9e6f764..b01a8cda3d52 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -704,38 +704,35 @@ macro_rules! invoke_if_unicode_expressions_feature_flag { }; } -/// Create a physical (function) expression. -/// This function errors when `args`' can't be coerced to a valid argument type of the function. -pub fn create_physical_expr( +/// Create a physical scalar function. 
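Unlike the old ScalarFunctionNode, the PhysicalScalarFunctionNode serialized above carries the resolved name and return type next to the function and its arguments, which lets deserialization rebuild the ScalarFunctionExpr without repeating the type-coercion pass. A small stand-in sketch of the record shape and why the stored type suffices:

    #[derive(Debug, Clone, PartialEq)]
    enum DataType { Float64 }

    // Stand-in for the proto message: the return type travels with the node.
    struct ScalarFunctionNode {
        name: String,
        args: Vec<String>,     // serialized child expressions, elided here
        return_type: DataType, // resolved once, at serialization time
    }

    fn rebuild(node: &ScalarFunctionNode) -> (String, DataType) {
        // No coercion pass needed: the stored return type is authoritative.
        (node.name.clone(), node.return_type.clone())
    }

    fn main() {
        let node = ScalarFunctionNode {
            name: "sqrt".into(),
            args: vec!["col(a)".into()],
            return_type: DataType::Float64,
        };
        assert_eq!(node.args.len(), 1);
        assert_eq!(rebuild(&node), ("sqrt".to_string(), DataType::Float64));
    }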
+pub fn create_physical_fun( fun: &BuiltinScalarFunction, - args: &[Arc], - input_schema: &Schema, ctx_state: &ExecutionContextState, -) -> Result> { - let fun_expr: ScalarFunctionImplementation = Arc::new(match fun { +) -> Result { + Ok(match fun { // math functions - BuiltinScalarFunction::Abs => math_expressions::abs, - BuiltinScalarFunction::Acos => math_expressions::acos, - BuiltinScalarFunction::Asin => math_expressions::asin, - BuiltinScalarFunction::Atan => math_expressions::atan, - BuiltinScalarFunction::Ceil => math_expressions::ceil, - BuiltinScalarFunction::Cos => math_expressions::cos, - BuiltinScalarFunction::Exp => math_expressions::exp, - BuiltinScalarFunction::Floor => math_expressions::floor, - BuiltinScalarFunction::Log => math_expressions::log10, - BuiltinScalarFunction::Ln => math_expressions::ln, - BuiltinScalarFunction::Log10 => math_expressions::log10, - BuiltinScalarFunction::Log2 => math_expressions::log2, - BuiltinScalarFunction::Round => math_expressions::round, - BuiltinScalarFunction::Signum => math_expressions::signum, - BuiltinScalarFunction::Sin => math_expressions::sin, - BuiltinScalarFunction::Sqrt => math_expressions::sqrt, - BuiltinScalarFunction::Tan => math_expressions::tan, - BuiltinScalarFunction::Trunc => math_expressions::trunc, + BuiltinScalarFunction::Abs => Arc::new(math_expressions::abs), + BuiltinScalarFunction::Acos => Arc::new(math_expressions::acos), + BuiltinScalarFunction::Asin => Arc::new(math_expressions::asin), + BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan), + BuiltinScalarFunction::Ceil => Arc::new(math_expressions::ceil), + BuiltinScalarFunction::Cos => Arc::new(math_expressions::cos), + BuiltinScalarFunction::Exp => Arc::new(math_expressions::exp), + BuiltinScalarFunction::Floor => Arc::new(math_expressions::floor), + BuiltinScalarFunction::Log => Arc::new(math_expressions::log10), + BuiltinScalarFunction::Ln => Arc::new(math_expressions::ln), + BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10), + BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2), + BuiltinScalarFunction::Round => Arc::new(math_expressions::round), + BuiltinScalarFunction::Signum => Arc::new(math_expressions::signum), + BuiltinScalarFunction::Sin => Arc::new(math_expressions::sin), + BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt), + BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan), + BuiltinScalarFunction::Trunc => Arc::new(math_expressions::trunc), // string functions - BuiltinScalarFunction::Array => array_expressions::array, - BuiltinScalarFunction::Ascii => |args| match args[0].data_type() { + BuiltinScalarFunction::Array => Arc::new(array_expressions::array), + BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::ascii::)(args) } @@ -746,8 +743,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function ascii", other, ))), - }, - BuiltinScalarFunction::BitLength => |args| match &args[0] { + }), + BuiltinScalarFunction::BitLength => Arc::new(|args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( @@ -758,8 +755,8 @@ pub fn create_physical_expr( )), _ => unreachable!(), }, - }, - BuiltinScalarFunction::Btrim => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Btrim => Arc::new(|args| match args[0].data_type() { 
DataType::Utf8 => { make_scalar_function(string_expressions::btrim::)(args) } @@ -770,55 +767,47 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function btrim", other, ))), - }, - BuiltinScalarFunction::CharacterLength => |args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int32Type, - "character_length" - ); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int64Type, - "character_length" - ); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function character_length", - other, - ))), - }, + }), + BuiltinScalarFunction::CharacterLength => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + character_length, + Int32Type, + "character_length" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + character_length, + Int64Type, + "character_length" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function character_length", + other, + ))), + }) + } BuiltinScalarFunction::Chr => { - |args| make_scalar_function(string_expressions::chr)(args) + Arc::new(|args| make_scalar_function(string_expressions::chr)(args)) } - BuiltinScalarFunction::Concat => string_expressions::concat, + BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), BuiltinScalarFunction::ConcatWithSeparator => { - |args| make_scalar_function(string_expressions::concat_ws)(args) + Arc::new(|args| make_scalar_function(string_expressions::concat_ws)(args)) } - BuiltinScalarFunction::DatePart => datetime_expressions::date_part, - BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, + BuiltinScalarFunction::DatePart => Arc::new(datetime_expressions::date_part), + BuiltinScalarFunction::DateTrunc => Arc::new(datetime_expressions::date_trunc), BuiltinScalarFunction::Now => { // bind value for now at plan time - let fun_expr = Arc::new(datetime_expressions::make_now( + Arc::new(datetime_expressions::make_now( ctx_state.execution_props.query_execution_start_time, - )); - - // TODO refactor code to not return here, but instead fall through below - let args = vec![]; - let arg_types = vec![]; // has no args - return Ok(Arc::new(ScalarFunctionExpr::new( - &format!("{}", fun), - fun_expr, - args, - &return_type(&fun, &arg_types)?, - ))); + )) } - BuiltinScalarFunction::InitCap => |args| match args[0].data_type() { + BuiltinScalarFunction::InitCap => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::initcap::)(args) } @@ -829,8 +818,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function initcap", other, ))), - }, - BuiltinScalarFunction::Left => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left"); make_scalar_function(func)(args) @@ -843,9 +832,9 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function left", other, ))), - }, - BuiltinScalarFunction::Lower => string_expressions::lower, - BuiltinScalarFunction::Lpad => |args| match args[0].data_type() { + }), + 
BuiltinScalarFunction::Lower => Arc::new(string_expressions::lower), + BuiltinScalarFunction::Lpad => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(lpad, i32, "lpad"); make_scalar_function(func)(args) @@ -858,8 +847,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function lpad", other, ))), - }, - BuiltinScalarFunction::Ltrim => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Ltrim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::ltrim::)(args) } @@ -870,12 +859,12 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function ltrim", other, ))), - }, + }), BuiltinScalarFunction::MD5 => { - invoke_if_crypto_expressions_feature_flag!(md5, "md5") + Arc::new(invoke_if_crypto_expressions_feature_flag!(md5, "md5")) } - BuiltinScalarFunction::NullIf => nullif_func, - BuiltinScalarFunction::OctetLength => |args| match &args[0] { + BuiltinScalarFunction::NullIf => Arc::new(nullif_func), + BuiltinScalarFunction::OctetLength => Arc::new(|args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( @@ -886,52 +875,56 @@ pub fn create_physical_expr( )), _ => unreachable!(), }, - }, - BuiltinScalarFunction::RegexpMatch => |args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_match, - i32, - "regexp_match" - ); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_match, - i64, - "regexp_match" - ); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function regexp_match", - other - ))), - }, - BuiltinScalarFunction::RegexpReplace => |args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_replace, - i32, - "regexp_replace" - ); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_replace, - i64, - "regexp_replace" - ); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function regexp_replace", - other, - ))), - }, - BuiltinScalarFunction::Repeat => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::RegexpMatch => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_match, + i32, + "regexp_match" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_match, + i64, + "regexp_match" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function regexp_match", + other + ))), + }) + } + BuiltinScalarFunction::RegexpReplace => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_replace, + i32, + "regexp_replace" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_replace, + i64, + "regexp_replace" + ); + make_scalar_function(func)(args) + } + other => 
Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function regexp_replace", + other, + ))), + }) + } + BuiltinScalarFunction::Repeat => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::repeat::)(args) } @@ -942,8 +935,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function repeat", other, ))), - }, - BuiltinScalarFunction::Replace => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Replace => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::replace::)(args) } @@ -954,8 +947,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function replace", other, ))), - }, - BuiltinScalarFunction::Reverse => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Reverse => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(reverse, i32, "reverse"); @@ -970,8 +963,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function reverse", other, ))), - }, - BuiltinScalarFunction::Right => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Right => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(right, i32, "right"); @@ -986,8 +979,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function right", other, ))), - }, - BuiltinScalarFunction::Rpad => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Rpad => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(rpad, i32, "rpad"); make_scalar_function(func)(args) @@ -1000,8 +993,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function rpad", other, ))), - }, - BuiltinScalarFunction::Rtrim => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Rtrim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::rtrim::)(args) } @@ -1012,20 +1005,20 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function rtrim", other, ))), - }, + }), BuiltinScalarFunction::SHA224 => { - invoke_if_crypto_expressions_feature_flag!(sha224, "sha224") + Arc::new(invoke_if_crypto_expressions_feature_flag!(sha224, "sha224")) } BuiltinScalarFunction::SHA256 => { - invoke_if_crypto_expressions_feature_flag!(sha256, "sha256") + Arc::new(invoke_if_crypto_expressions_feature_flag!(sha256, "sha256")) } BuiltinScalarFunction::SHA384 => { - invoke_if_crypto_expressions_feature_flag!(sha384, "sha384") + Arc::new(invoke_if_crypto_expressions_feature_flag!(sha384, "sha384")) } BuiltinScalarFunction::SHA512 => { - invoke_if_crypto_expressions_feature_flag!(sha512, "sha512") + Arc::new(invoke_if_crypto_expressions_feature_flag!(sha512, "sha512")) } - BuiltinScalarFunction::SplitPart => |args| match args[0].data_type() { + BuiltinScalarFunction::SplitPart => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::split_part::)(args) } @@ -1036,8 +1029,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function split_part", other, ))), - }, - BuiltinScalarFunction::StartsWith => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::StartsWith => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::starts_with::)(args) } @@ -1048,8 
+1041,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function starts_with", other, ))), - }, - BuiltinScalarFunction::Strpos => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Strpos => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!( strpos, Int32Type, "strpos" @@ -1066,8 +1059,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function strpos", other, ))), - }, - BuiltinScalarFunction::Substr => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Substr => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(substr, i32, "substr"); @@ -1082,8 +1075,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function substr", other, ))), - }, - BuiltinScalarFunction::ToHex => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::ToHex => Arc::new(|args| match args[0].data_type() { DataType::Int32 => { make_scalar_function(string_expressions::to_hex::)(args) } @@ -1094,9 +1087,11 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function to_hex", other, ))), - }, - BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, - BuiltinScalarFunction::Translate => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::ToTimestamp => { + Arc::new(datetime_expressions::to_timestamp) + } + BuiltinScalarFunction::Translate => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!( translate, @@ -1117,8 +1112,8 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function translate", other, ))), - }, - BuiltinScalarFunction::Trim => |args| match args[0].data_type() { + }), + BuiltinScalarFunction::Trim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::btrim::)(args) } @@ -1129,10 +1124,20 @@ pub fn create_physical_expr( "Unsupported data type {:?} for function trim", other, ))), - }, - BuiltinScalarFunction::Upper => string_expressions::upper, - }); - // coerce + }), + BuiltinScalarFunction::Upper => Arc::new(string_expressions::upper), + }) +} + +/// Create a physical (function) expression. +/// This function errors when `args`' can't be coerced to a valid argument type of the function. 
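A minimal sketch (not part of the patch) of what this split buys: create_physical_fun resolves a BuiltinScalarFunction straight to its ScalarFunctionImplementation, so a caller can invoke a built-in on columnar data without building an expression tree, while create_physical_expr below keeps the argument-coercion logic. The module paths and the in-scope ctx_state are assumptions here:

    use std::sync::Arc;
    use arrow::array::{ArrayRef, Float64Array};
    use datafusion::error::Result;
    use datafusion::execution::context::ExecutionContextState;
    use datafusion::physical_plan::functions::{create_physical_fun, BuiltinScalarFunction};
    use datafusion::physical_plan::ColumnarValue;

    fn sqrt_column(ctx_state: &ExecutionContextState, input: ArrayRef) -> Result<ColumnarValue> {
        // Resolve the boxed implementation once...
        let sqrt_impl = create_physical_fun(&BuiltinScalarFunction::Sqrt, ctx_state)?;
        // ...then call it like any other ScalarFunctionImplementation closure.
        sqrt_impl(&[ColumnarValue::Array(input)])
    }

    // e.g. sqrt_column(&ctx_state, Arc::new(Float64Array::from(vec![1.0, 4.0, 9.0])))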
+pub fn create_physical_expr(
+    fun: &BuiltinScalarFunction,
+    args: &[Arc<dyn PhysicalExpr>],
+    input_schema: &Schema,
+    ctx_state: &ExecutionContextState,
+) -> Result<Arc<dyn PhysicalExpr>> {
+    let fun_expr = create_physical_fun(fun, ctx_state)?;
     let args = coerce(args, input_schema, &signature(fun))?;

     let arg_types = args

From 841159ff18c18d83934d346d0e184643eb1c29dd Mon Sep 17 00:00:00 2001
From: Qingping Hou
Date: Sun, 23 May 2021 20:38:05 -0700
Subject: [PATCH 10/25] fix join schema handling in projection push down
 optimizer

schema needs to be recalculated based on the newly optimized inputs
---
 datafusion/src/logical_plan/builder.rs      | 20 ++---
 datafusion/src/logical_plan/mod.rs          |  7 +-
 datafusion/src/logical_plan/plan.rs         | 11 +++
 .../src/optimizer/hash_build_probe_order.rs |  3 +
 .../src/optimizer/projection_push_down.rs   | 90 ++++++++++++++-----
 datafusion/src/optimizer/utils.rs           | 30 ++++---
 6 files changed, 119 insertions(+), 42 deletions(-)

diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index b42b8bfadd07..0b7fc3e84221 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -32,7 +32,10 @@ use crate::{
 };

 use super::dfschema::ToDFSchema;
-use super::{exprlist_to_fields, Expr, JoinType, LogicalPlan, PlanType, StringifiedPlan};
+use super::{
+    exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType,
+    StringifiedPlan,
+};
 use crate::logical_plan::{
     columnize_expr, normalize_col, normalize_cols, Column, DFField, DFSchema,
     DFSchemaRef, Partitioning,
 };
 use std::collections::HashSet;

 /// Default table name for unnamed table
 pub const UNNAMED_TABLE: &str = "?table?";

-pub enum JoinConstraint {
-    On,
-    Using,
-}
-
 /// Builder for logical plans
 ///
 /// ```
@@ -284,7 +282,7 @@ impl LogicalPlanBuilder {
             right.schema(),
             &on,
             &join_type,
-            JoinConstraint::On,
+            &JoinConstraint::On,
         )?;

         Ok(Self::from(&LogicalPlan::Join {
@@ -292,6 +290,7 @@ impl LogicalPlanBuilder {
             right: Arc::new(right.clone()),
             on,
             join_type,
+            join_constraint: JoinConstraint::On,
             schema: DFSchemaRef::new(join_schema),
         }))
     }
@@ -319,7 +318,7 @@ impl LogicalPlanBuilder {
             right.schema(),
             &on,
             &join_type,
-            JoinConstraint::Using,
+            &JoinConstraint::Using,
         )?;

         Ok(Self::from(&LogicalPlan::Join {
@@ -327,6 +326,7 @@ impl LogicalPlanBuilder {
             right: Arc::new(right.clone()),
             on,
             join_type,
+            join_constraint: JoinConstraint::Using,
             schema: DFSchemaRef::new(join_schema),
         }))
     }
@@ -400,12 +400,12 @@ impl LogicalPlanBuilder {

 /// Creates a schema for a join operation.
/// The fields from the left side are first -fn build_join_schema( +pub fn build_join_schema( left: &DFSchema, right: &DFSchema, on: &[(Column, Column)], join_type: &JoinType, - join_constraint: JoinConstraint, + join_constraint: &JoinConstraint, ) -> Result { let fields: Vec = match join_type { JoinType::Inner | JoinType::Left | JoinType::Full => { diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 5bcf8ec765be..d259003c5b34 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -29,7 +29,9 @@ mod extension; mod operators; mod plan; mod registry; -pub use builder::{union_with_alias, LogicalPlanBuilder, UNNAMED_TABLE}; +pub use builder::{ + build_join_schema, union_with_alias, LogicalPlanBuilder, UNNAMED_TABLE, +}; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ @@ -46,6 +48,7 @@ pub use expr::{ pub use extension::UserDefinedLogicalNode; pub use operators::Operator; pub use plan::{ - JoinType, LogicalPlan, Partitioning, PlanType, PlanVisitor, StringifiedPlan, + JoinConstraint, JoinType, LogicalPlan, Partitioning, PlanType, PlanVisitor, + StringifiedPlan, }; pub use registry::FunctionRegistry; diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index 28ffefd471c5..1ae37a190e76 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -45,6 +45,15 @@ pub enum JoinType { Full, } +/// Join constraint +#[derive(Debug, Clone, Copy)] +pub enum JoinConstraint { + /// Join ON + On, + /// Join USING + Using, +} + /// A LogicalPlan represents the different types of relational /// operators (such as Projection, Filter, etc) and can be created by /// the SQL query planner and the DataFrame API. 
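Carrying the constraint on the plan node lets later passes distinguish ON joins from USING joins when they rebuild schemas. A minimal sketch (not part of the patch) of how the two variants arise through the builder methods changed above; the join_using method name and the in-scope t1/t2 plans are assumptions:

    use datafusion::error::Result;
    use datafusion::logical_plan::{JoinType, LogicalPlan, LogicalPlanBuilder};

    fn on_vs_using(t1: &LogicalPlan, t2: &LogicalPlan) -> Result<(LogicalPlan, LogicalPlan)> {
        // JOIN ... ON t1.id = t2.id: recorded as JoinConstraint::On, so both
        // t1.id and t2.id remain addressable in the join output schema.
        let on_join = LogicalPlanBuilder::from(t1)
            .join(t2, JoinType::Inner, vec!["id"], vec!["id"])?
            .build()?;
        // JOIN ... USING (id): recorded as JoinConstraint::Using, which lets
        // downstream code deduplicate the shared key column.
        let using_join = LogicalPlanBuilder::from(t1)
            .join_using(t2, JoinType::Inner, vec!["id"])?
            .build()?;
        Ok((on_join, using_join))
    }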
@@ -108,6 +117,8 @@ pub enum LogicalPlan { on: Vec<(Column, Column)>, /// Join type join_type: JoinType, + /// Join constraint + join_constraint: JoinConstraint, /// The output schema, containing fields from the left and right inputs schema: DFSchemaRef, }, diff --git a/datafusion/src/optimizer/hash_build_probe_order.rs b/datafusion/src/optimizer/hash_build_probe_order.rs index 642a7ef4d936..fb0d95fc1e1d 100644 --- a/datafusion/src/optimizer/hash_build_probe_order.rs +++ b/datafusion/src/optimizer/hash_build_probe_order.rs @@ -120,6 +120,7 @@ impl OptimizerRule for HashBuildProbeOrder { right, on, join_type, + join_constraint, schema, } => { let left = self.optimize(left, execution_props)?; @@ -131,6 +132,7 @@ impl OptimizerRule for HashBuildProbeOrder { right: Arc::new(left), on: on.iter().map(|(l, r)| (r.clone(), l.clone())).collect(), join_type: swap_join_type(*join_type), + join_constraint: *join_constraint, schema: schema.clone(), }) } else { @@ -140,6 +142,7 @@ impl OptimizerRule for HashBuildProbeOrder { right: Arc::new(right), on: on.clone(), join_type: *join_type, + join_constraint: *join_constraint, schema: schema.clone(), }) } diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index c12109980707..5eef39064b57 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -21,7 +21,7 @@ use crate::error::Result; use crate::execution::context::ExecutionProps; use crate::logical_plan::{ - Column, DFField, DFSchema, DFSchemaRef, LogicalPlan, ToDFSchema, + build_join_schema, Column, DFField, DFSchema, DFSchemaRef, LogicalPlan, ToDFSchema, }; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; @@ -182,31 +182,45 @@ fn optimize_plan( right, on, join_type, - schema, + join_constraint, + .. 
        } => {
            for (l, r) in on {
                new_required_columns.insert(l.clone());
                new_required_columns.insert(r.clone());
            }
-            Ok(LogicalPlan::Join {
-                left: Arc::new(optimize_plan(
-                    optimizer,
-                    &left,
-                    &new_required_columns,
-                    true,
-                    execution_props,
-                )?),
-                right: Arc::new(optimize_plan(
-                    optimizer,
-                    &right,
-                    &new_required_columns,
-                    true,
-                    execution_props,
-                )?),
+            let optimized_left = Arc::new(optimize_plan(
+                optimizer,
+                &left,
+                &new_required_columns,
+                true,
+                execution_props,
+            )?);
+
+            let optimized_right = Arc::new(optimize_plan(
+                optimizer,
+                &right,
+                &new_required_columns,
+                true,
+                execution_props,
+            )?);
+
+            let schema = build_join_schema(
+                &optimized_left.schema(),
+                &optimized_right.schema(),
+                on,
+                join_type,
+                join_constraint,
+            )?;
+
+            Ok(LogicalPlan::Join {
+                left: optimized_left,
+                right: optimized_right,
                 join_type: *join_type,
+                join_constraint: *join_constraint,
                 on: on.clone(),
-                schema: schema.clone(),
+                schema: DFSchemaRef::new(schema),
             })
         }
         LogicalPlan::Aggregate {
@@ -382,8 +396,7 @@ fn optimize_plan(
 mod tests {

     use super::*;
-    use crate::logical_plan::{col, lit};
-    use crate::logical_plan::{max, min, Expr, LogicalPlanBuilder};
+    use crate::logical_plan::{col, lit, max, min, Expr, JoinType, LogicalPlanBuilder};
     use crate::test::*;
     use arrow::datatypes::DataType;

@@ -437,6 +450,43 @@ mod tests {
         Ok(())
     }

+    #[test]
+    fn join_schema_trim() -> Result<()> {
+        let table_scan = test_table_scan()?;
+
+        let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
+        let table2_scan =
+            LogicalPlanBuilder::scan_empty(Some("test2"), &schema, None)?.build()?;
+
+        let plan = LogicalPlanBuilder::from(&table_scan)
+            .join(&table2_scan, JoinType::Left, vec!["a"], vec!["c1"])?
+            .project(vec![col("a"), col("b"), col("c1")])?
+            .build()?;
+
+        // make sure projections are pushed down to table scan
+        let expected = "Projection: #test.a, #test.b, #test2.c1\
+        \n  Join: #test.a = #test2.c1\
+        \n    TableScan: test projection=Some([0, 1])\
+        \n    TableScan: test2 projection=Some([0])";
+
+        let optimized_plan = optimize(&plan)?;
+        let formatted_plan = format!("{:?}", optimized_plan);
+        assert_eq!(formatted_plan, expected);
+
+        // make sure the schema for the join node doesn't include the c column,
+        // which was pruned from the test table scan
+        let optimized_join = optimized_plan.inputs()[0];
+        assert_eq!(
+            **optimized_join.schema(),
+            DFSchema::new(vec![
+                DFField::new(Some("test"), "a", DataType::UInt32, false),
+                DFField::new(Some("test"), "b", DataType::UInt32, false),
+                DFField::new(Some("test2"), "c1", DataType::UInt32, false),
+            ])?,
+        );
+
+        Ok(())
+    }
+
     #[test]
     fn cast() -> Result<()> {
         let table_scan = test_table_scan()?;
diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs
index 5080db82fb97..1ac357db7b13 100644
--- a/datafusion/src/optimizer/utils.rs
+++ b/datafusion/src/optimizer/utils.rs
@@ -24,8 +24,8 @@ use arrow::datatypes::Schema;
 use super::optimizer::OptimizerRule;
 use crate::execution::context::ExecutionProps;
 use crate::logical_plan::{
-    Column, Expr, LogicalPlan, LogicalPlanBuilder, Operator, Partitioning, PlanType,
-    Recursion, StringifiedPlan, ToDFSchema,
+    build_join_schema, Column, DFSchemaRef, Expr, LogicalPlan, LogicalPlanBuilder,
+    Operator, Partitioning, PlanType, Recursion, StringifiedPlan, ToDFSchema,
 };
 use crate::prelude::lit;
 use crate::scalar::ScalarValue;
@@ -202,16 +202,26 @@ pub fn from_plan(
         }),
         LogicalPlan::Join {
             join_type,
+            join_constraint,
             on,
-            schema,
             ..
- } => Ok(LogicalPlan::Join { - left: Arc::new(inputs[0].clone()), - right: Arc::new(inputs[1].clone()), - join_type: *join_type, - on: on.clone(), - schema: schema.clone(), - }), + } => { + let schema = build_join_schema( + inputs[0].schema(), + inputs[1].schema(), + on, + join_type, + join_constraint, + )?; + Ok(LogicalPlan::Join { + left: Arc::new(inputs[0].clone()), + right: Arc::new(inputs[1].clone()), + join_type: *join_type, + join_constraint: *join_constraint, + on: on.clone(), + schema: DFSchemaRef::new(schema), + }) + } LogicalPlan::CrossJoin { .. } => { let left = &inputs[0]; let right = &inputs[1]; From 9ab47115eb0b5a119c934cc962c5368557606b71 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sun, 23 May 2021 20:38:57 -0700 Subject: [PATCH 11/25] tpch 7 & 8 are now passing! --- benchmarks/src/bin/tpch.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index cee555fe675e..e90bf597697f 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -661,6 +661,16 @@ mod tests { run_query(6).await } + #[tokio::test] + async fn run_q7() -> Result<()> { + run_query(7).await + } + + #[tokio::test] + async fn run_q8() -> Result<()> { + run_query(8).await + } + #[tokio::test] async fn run_q9() -> Result<()> { run_query(9).await From babb252e5d1b7002cecbb62585a95dfe2f766665 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sun, 23 May 2021 21:29:58 -0700 Subject: [PATCH 12/25] fix roundtrip_join test --- ballista/rust/core/src/serde/logical_plan/mod.rs | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index 2684c3122189..7619f1951aa7 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -709,13 +709,18 @@ mod roundtrip_tests { Field::new("salary", DataType::Int32, false), ]); - let scan_plan = LogicalPlanBuilder::empty(false) - .build() - .map_err(BallistaError::DataFusionError)?; + let scan_plan = LogicalPlanBuilder::scan_csv( + "employee1", + CsvReadOptions::new().schema(&schema).has_header(true), + Some(vec![0, 3, 4]), + )? 
+ .build() + .map_err(BallistaError::DataFusionError)?; + let plan = LogicalPlanBuilder::scan_csv( - "employee.csv", + "employee2", CsvReadOptions::new().schema(&schema).has_header(true), - Some(vec![3, 4]), + Some(vec![0, 3, 4]), ) .and_then(|plan| plan.join(&scan_plan, JoinType::Inner, vec!["id"], vec!["id"])) .and_then(|plan| plan.build()) From 6aaa148b60465add5ae8c7417b050541fc9bbfc9 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Tue, 1 Jun 2021 21:40:43 -0700 Subject: [PATCH 13/25] fix clippy warnings --- datafusion/src/logical_plan/dfschema.rs | 2 +- datafusion/src/logical_plan/expr.rs | 6 +++--- datafusion/src/optimizer/filter_push_down.rs | 2 +- datafusion/src/physical_plan/cross_join.rs | 2 +- datafusion/src/physical_plan/hash_join.rs | 4 ++-- datafusion/src/physical_plan/hash_utils.rs | 6 ++++-- 6 files changed, 12 insertions(+), 10 deletions(-) diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs index 517c02a58b71..583c0917dfaf 100644 --- a/datafusion/src/logical_plan/dfschema.rs +++ b/datafusion/src/logical_plan/dfschema.rs @@ -265,7 +265,7 @@ impl DFSchema { .into_iter() .map(|f| { DFField::new( - Some(qualifer.clone()), + Some(qualifer), f.name(), f.data_type().to_owned(), f.is_nullable(), diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index a8879c94a2f4..89662f4ea65f 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -87,7 +87,7 @@ impl Column { /// Normalize Column with qualifier based on provided dataframe schemas. pub fn normalize(self, schemas: &[&DFSchemaRef]) -> Result { - if !self.relation.is_none() { + if self.relation.is_some() { return Ok(self); } @@ -445,7 +445,7 @@ impl Expr { pub fn to_field(&self, input_schema: &DFSchema) -> Result { match self { Expr::Column(c) => Ok(DFField::new( - c.relation.as_ref().map(|s| s.as_str()), + c.relation.as_deref(), &c.name, self.get_type(input_schema)?, self.nullable(input_schema)?, @@ -1110,7 +1110,7 @@ pub fn normalize_cols( ) -> Result> { exprs .into_iter() - .map(|e| normalize_col(e.clone(), schemas)) + .map(|e| normalize_col(e, schemas)) .collect() } diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs index 76c6ac234b66..37056c7281ac 100644 --- a/datafusion/src/optimizer/filter_push_down.rs +++ b/datafusion/src/optimizer/filter_push_down.rs @@ -278,7 +278,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { expr => expr.clone(), }; - projection.insert(field.qualified_name().clone(), expr); + projection.insert(field.qualified_name(), expr); }); // re-write all filters based on this projection diff --git a/datafusion/src/physical_plan/cross_join.rs b/datafusion/src/physical_plan/cross_join.rs index 179ccb3ded86..f6f5da4cf8db 100644 --- a/datafusion/src/physical_plan/cross_join.rs +++ b/datafusion/src/physical_plan/cross_join.rs @@ -68,7 +68,7 @@ impl CrossJoinExec { ) -> Result { let left_schema = left.schema(); let right_schema = right.schema(); - check_join_is_valid(&left_schema, &right_schema, &vec![])?; + check_join_is_valid(&left_schema, &right_schema, &[])?; let left_schema = left.schema(); let left_fields = left_schema.fields().iter(); diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index 3968061b1dfe..f66303cc318a 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -476,6 +476,7 @@ struct HashJoinStream { is_exhausted: 
bool, } +#[allow(clippy::too_many_arguments)] impl HashJoinStream { fn new( schema: Arc, @@ -697,11 +698,10 @@ fn build_join_indexes( &keys_values, )? { left_indices.append_value(i)?; - right_indices.append_value(row as u32)?; } else { left_indices.append_null()?; - right_indices.append_value(row as u32)?; } + right_indices.append_value(row as u32)?; } } None => { diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index a2127217b42b..e2fde46960fd 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ -38,10 +38,12 @@ pub enum JoinType { /// The on clause of the join, as vector of (left, right) columns. pub type JoinOn = Vec<(Column, Column)>; +/// Reference for JoinOn. +pub type JoinOnRef<'a> = &'a [(Column, Column)]; /// Checks whether the schemas "left" and "right" and columns "on" represent a valid join. /// They are valid whenever their columns' intersection equals the set `on` -pub fn check_join_is_valid(left: &Schema, right: &Schema, on: &JoinOn) -> Result<()> { +pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> { let left: HashSet = left .fields() .iter() @@ -101,7 +103,7 @@ fn check_join_set_is_valid( pub fn build_join_schema( left: &Schema, right: &Schema, - on: &JoinOn, + on: JoinOnRef, join_type: &JoinType, ) -> Schema { let fields: Vec = match join_type { From 2ef668d4035ad928addf1e525614836560c44279 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sat, 5 Jun 2021 12:28:27 -0700 Subject: [PATCH 14/25] fix sql planner test error checking with matches `format("{:?}", err)` yields different results between stable and nightly rust. --- datafusion/src/sql/planner.rs | 64 +++++++++++++++++------------------ 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 64875e4667da..d292a5f674c2 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -1618,10 +1618,10 @@ mod tests { fn select_column_does_not_exist() { let sql = "SELECT doesnotexist FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "Plan(\"No field with unqualified name 'doesnotexist'\")", - format!("{:?}", err) - ); + assert!(matches!( + err, + DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'", + )); } #[test] @@ -1676,20 +1676,20 @@ mod tests { fn select_filter_column_does_not_exist() { let sql = "SELECT first_name FROM person WHERE doesnotexist = 'A'"; let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "Plan(\"No field with unqualified name 'doesnotexist'\")", - format!("{:?}", err) - ); + assert!(matches!( + err, + DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'", + )); } #[test] fn select_filter_cannot_use_alias() { let sql = "SELECT first_name AS x FROM person WHERE x = 'A'"; let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "Plan(\"No field with unqualified name 'x'\")", - format!("{:?}", err) - ); + assert!(matches!( + err, + DataFusionError::Plan(msg) if msg == "No field with unqualified name 'x'", + )); } #[test] @@ -2157,10 +2157,10 @@ mod tests { fn select_simple_aggregate_column_does_not_exist() { let sql = "SELECT MIN(doesnotexist) FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "Plan(\"No field with unqualified name 'doesnotexist'\")", - 
format!("{:?}", err) - ); + assert!(matches!( + err, + DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'", + )); } #[test] @@ -2247,20 +2247,20 @@ mod tests { fn select_simple_aggregate_with_groupby_and_column_in_group_by_does_not_exist() { let sql = "SELECT SUM(age) FROM person GROUP BY doesnotexist"; let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "Plan(\"No field with unqualified name 'doesnotexist'\")", - format!("{:?}", err) - ); + assert!(matches!( + err, + DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'", + )); } #[test] fn select_simple_aggregate_with_groupby_and_column_in_aggregate_does_not_exist() { let sql = "SELECT SUM(doesnotexist) FROM person GROUP BY first_name"; let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "Plan(\"No field with unqualified name 'doesnotexist'\")", - format!("{:?}", err) - ); + assert!(matches!( + err, + DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'", + )); } #[test] @@ -2277,10 +2277,10 @@ mod tests { fn select_unsupported_complex_interval() { let sql = "SELECT INTERVAL '1 year 1 day'"; let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "NotImplemented(\"DF does not support intervals that have both a Year/Month part as well as Days/Hours/Mins/Seconds: \\\"1 year 1 day\\\". Hint: try breaking the interval into two parts, one with Year/Month and the other with Days/Hours/Mins/Seconds - e.g. (NOW() + INTERVAL '1 year') + INTERVAL '1 day'\")", - format!("{:?}", err) - ); + assert!(matches!( + err, + DataFusionError::NotImplemented(msg) if msg == "DF does not support intervals that have both a Year/Month part as well as Days/Hours/Mins/Seconds: \"1 year 1 day\". Hint: try breaking the interval into two parts, one with Year/Month and the other with Days/Hours/Mins/Seconds - e.g. 
(NOW() + INTERVAL '1 year') + INTERVAL '1 day'", + )); } #[test] @@ -2297,10 +2297,10 @@ mod tests { fn select_simple_aggregate_with_groupby_cannot_use_alias() { let sql = "SELECT state AS x, MAX(age) FROM person GROUP BY x"; let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "Plan(\"No field with unqualified name 'x'\")", - format!("{:?}", err) - ); + assert!(matches!( + err, + DataFusionError::Plan(msg) if msg == "No field with unqualified name 'x'", + )); } #[test] From 7b70f04195a291146f094db10b1577199a881136 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sat, 5 Jun 2021 22:19:50 -0700 Subject: [PATCH 15/25] address FIXMEs --- datafusion/src/logical_plan/builder.rs | 3 +- datafusion/src/logical_plan/dfschema.rs | 20 ++++---- datafusion/src/logical_plan/expr.rs | 54 ++++++++------------ datafusion/src/physical_optimizer/pruning.rs | 3 +- datafusion/src/physical_plan/parquet.rs | 1 - 5 files changed, 33 insertions(+), 48 deletions(-) diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 05319f04e6fc..d80aeb912be8 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -169,7 +169,7 @@ impl LogicalPlanBuilder { .collect(), }) .unwrap_or_else(|| { - DFSchema::try_from_qualified(table_name, &schema).unwrap() + DFSchema::try_from_qualified_schema(table_name, &schema).unwrap() }); let table_scan = LogicalPlan::TableScan { @@ -273,7 +273,6 @@ impl LogicalPlanBuilder { .collect::>()?; let right_keys: Vec = right_keys .into_iter() - // FIXME: write a test for this .map(|c| c.into().normalize(&right.all_schemas())) .collect::>()?; let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs index 583c0917dfaf..1d85e0ef0a4e 100644 --- a/datafusion/src/logical_plan/dfschema.rs +++ b/datafusion/src/logical_plan/dfschema.rs @@ -89,8 +89,7 @@ impl DFSchema { } /// Create a `DFSchema` from an Arrow schema - // FIXME: change to a better name? 
- pub fn try_from_qualified(qualifier: &str, schema: &Schema) -> Result { + pub fn try_from_qualified_schema(qualifier: &str, schema: &Schema) -> Result { Self::new( schema .fields() @@ -323,7 +322,6 @@ impl Into for &DFSchema { /// Create a `DFSchema` from an Arrow schema impl TryFrom for DFSchema { type Error = DataFusionError; - // FIXME: change this to reference of schema fn try_from(schema: Schema) -> std::result::Result { Self::new( schema @@ -509,14 +507,14 @@ mod tests { #[test] fn from_qualified_schema() -> Result<()> { - let schema = DFSchema::try_from_qualified("t1", &test_schema_1())?; + let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; assert_eq!("t1.c0, t1.c1", schema.to_string()); Ok(()) } #[test] fn from_qualified_schema_into_arrow_schema() -> Result<()> { - let schema = DFSchema::try_from_qualified("t1", &test_schema_1())?; + let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let arrow_schema: Schema = schema.into(); let expected = "Field { name: \"c0\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }, \ Field { name: \"c1\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }"; @@ -526,8 +524,8 @@ mod tests { #[test] fn join_qualified() -> Result<()> { - let left = DFSchema::try_from_qualified("t1", &test_schema_1())?; - let right = DFSchema::try_from_qualified("t2", &test_schema_1())?; + let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let right = DFSchema::try_from_qualified_schema("t2", &test_schema_1())?; let join = left.join(&right)?; assert_eq!("t1.c0, t1.c1, t2.c0, t2.c1", join.to_string()); // test valid access @@ -542,8 +540,8 @@ mod tests { #[test] fn join_qualified_duplicate() -> Result<()> { - let left = DFSchema::try_from_qualified("t1", &test_schema_1())?; - let right = DFSchema::try_from_qualified("t1", &test_schema_1())?; + let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let right = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let join = left.join(&right); assert!(join.is_err()); assert_eq!( @@ -570,7 +568,7 @@ mod tests { #[test] fn join_mixed() -> Result<()> { - let left = DFSchema::try_from_qualified("t1", &test_schema_1())?; + let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let right = DFSchema::try_from(test_schema_2())?; let join = left.join(&right)?; assert_eq!("t1.c0, t1.c1, c100, c101", join.to_string()); @@ -588,7 +586,7 @@ mod tests { #[test] fn join_mixed_duplicate() -> Result<()> { - let left = DFSchema::try_from_qualified("t1", &test_schema_1())?; + let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let right = DFSchema::try_from(test_schema_1())?; let join = left.join(&right); assert!(join.is_err()); diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index bfe494daaa96..4da08c4d2143 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -54,27 +54,28 @@ impl Column { } } - /// Deserialize a flat name string into a column - pub fn from_flat_name(flat_name: &str) -> Self { + /// Deserialize a fully qualified name string into a column + pub fn from_qualified_name(flat_name: &str) -> Self { use sqlparser::tokenizer::Token; let dialect = sqlparser::dialect::GenericDialect {}; let mut tokenizer = sqlparser::tokenizer::Tokenizer::new(&dialect, flat_name); - // FIXME: remove unwrap - let tokens = tokenizer.tokenize().unwrap(); - - // 
any expression that's not in the form of foo.bar will be treated as unqualified - // column name - match tokens.as_slice() { - [Token::Word(relation), Token::Period, Token::Word(name)] => Column { - relation: Some(relation.value.clone()), - name: name.value.clone(), - }, - _ => Column { - relation: None, - name: String::from(flat_name), - }, + if let Ok(tokens) = tokenizer.tokenize() { + if let [Token::Word(relation), Token::Period, Token::Word(name)] = + tokens.as_slice() + { + return Column { + relation: Some(relation.value.clone()), + name: name.value.clone(), + }; + } } + // any expression that's not in the form of `foo.bar` will be treated as unqualified column + // name + return Column { + relation: None, + name: String::from(flat_name), + }; } /// Serialize column into a flat name string @@ -106,7 +107,7 @@ impl Column { impl From<&str> for Column { fn from(c: &str) -> Self { - Self::from_flat_name(c) + Self::from_qualified_name(c) } } @@ -1093,22 +1094,11 @@ pub fn normalize_col(e: Expr, schemas: &[&DFSchemaRef]) -> Result { impl<'a, 'b> ExprRewriter for ColumnNormalizer<'a, 'b> { fn mutate(&mut self, expr: Expr) -> Result { - if let Expr::Column(ref c) = expr { - // FIXME: reuse ColumnNormalizer::normalize? - if c.relation.is_none() { - for schema in self.schemas { - if let Ok(field) = schema.field_with_unqualified_name(&c.name) { - return Ok(Expr::Column(field.qualified_column())); - } - } - return Err(DataFusionError::Plan(format!( - "Column {} not found in provided schemas", - c.name, - ))); - } + if let Expr::Column(c) = expr { + Ok(Expr::Column(c.normalize(self.schemas)?)) + } else { + Ok(expr) } - - Ok(expr) } } diff --git a/datafusion/src/physical_optimizer/pruning.rs b/datafusion/src/physical_optimizer/pruning.rs index 116274175021..0035296be17f 100644 --- a/datafusion/src/physical_optimizer/pruning.rs +++ b/datafusion/src/physical_optimizer/pruning.rs @@ -199,7 +199,7 @@ impl PruningPredicate { #[derive(Debug, Default, Clone)] struct RequiredStatColumns { /// The statistics required to evaluate this predicate: - /// * The column name in the input schema + /// * The unqualified column in the input schema /// * Statistics type (e.g. Min or Max) /// * The field the statistics value should be placed in for /// pruning predicate evaluation @@ -465,7 +465,6 @@ fn build_single_column_expr( required_columns: &mut RequiredStatColumns, is_not: bool, // if true, treat as !col ) -> Option { - // FIXME(houqp): change logical column to physical column? let field = schema.field_with_name(&column.name).ok()?; if matches!(field.data_type(), &DataType::Boolean) { diff --git a/datafusion/src/physical_plan/parquet.rs b/datafusion/src/physical_plan/parquet.rs index a63139e0ac93..3d20a9bf98c1 100644 --- a/datafusion/src/physical_plan/parquet.rs +++ b/datafusion/src/physical_plan/parquet.rs @@ -497,7 +497,6 @@ macro_rules! get_statistic { // Extract the min or max value calling `func` or `bytes_func` on the ParquetStatistics as appropriate macro_rules! get_min_max_values { ($self:expr, $column:expr, $func:ident, $bytes_func:ident) => {{ - // FIXME: change to use physical column? 
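For reference, a small sketch (not part of the patch) of the behaviour from_qualified_name above implements: only an identifier of the exact shape word-period-word picks up a relation qualifier, and everything else stays unqualified. Column's public relation/name fields are assumed here:

    use datafusion::logical_plan::Column;

    let qualified = Column::from_qualified_name("t1.c1");
    assert_eq!(qualified.relation, Some("t1".to_string()));
    assert_eq!(qualified.name, "c1");

    // A bare identifier (or anything fancier, like "a.b.c") keeps relation = None.
    let bare = Column::from_qualified_name("c1");
    assert_eq!(bare.relation, None);
    assert_eq!(bare.name, "c1");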
        let (column_index, field) = if let Some((v, f)) =
            $self.parquet_schema.column_with_name(&$column.name)
        {
            (v, f)
        } else {

From fd3005fedc2f8538f697c4a9a05c9858b1796b6e Mon Sep 17 00:00:00 2001
From: Qingping Hou
Date: Sat, 12 Jun 2021 01:10:44 -0700
Subject: [PATCH 16/25] honor datafusion field name semantics

strip qualifier names in physical field names
---
 datafusion/src/execution/context.rs     | 318 ++++++++++++------------
 datafusion/src/lib.rs                   |  10 +-
 datafusion/src/logical_plan/builder.rs  |   8 +-
 datafusion/src/physical_plan/planner.rs |  47 +++-
 datafusion/tests/sql.rs                 |   2 +-
 5 files changed, 212 insertions(+), 173 deletions(-)

diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index 847ad42129fd..029bae42adbf 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -1269,15 +1269,15 @@ mod tests {
         assert_eq!(results.len(), 1);

         let expected = vec![
-            "+----+----+--------------+----------------+--------------+--------------+--------------+",
-            "| c1 | c2 | SUM(test.c2) | COUNT(test.c2) | MAX(test.c2) | MIN(test.c2) | AVG(test.c2) |",
-            "+----+----+--------------+----------------+--------------+--------------+--------------+",
-            "| 0  | 1  | 220          | 40             | 10           | 1            | 5.5          |",
-            "| 0  | 2  | 220          | 40             | 10           | 1            | 5.5          |",
-            "| 0  | 3  | 220          | 40             | 10           | 1            | 5.5          |",
-            "| 0  | 4  | 220          | 40             | 10           | 1            | 5.5          |",
-            "| 0  | 5  | 220          | 40             | 10           | 1            | 5.5          |",
-            "+----+----+--------------+----------------+--------------+--------------+--------------+",
+            "+----+----+---------+-----------+---------+---------+---------+",
+            "| c1 | c2 | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
+            "+----+----+---------+-----------+---------+---------+---------+",
+            "| 0  | 1  | 220     | 40        | 10      | 1       | 5.5     |",
+            "| 0  | 2  | 220     | 40        | 10      | 1       | 5.5     |",
+            "| 0  | 3  | 220     | 40        | 10      | 1       | 5.5     |",
+            "| 0  | 4  | 220     | 40        | 10      | 1       | 5.5     |",
+            "| 0  | 5  | 220     | 40        | 10      | 1       | 5.5     |",
+            "+----+----+---------+-----------+---------+---------+---------+",
         ];

         // window function shall respect ordering
@@ -1291,11 +1291,11 @@ mod tests {
         assert_eq!(results.len(), 1);

         let expected = vec![
-            "+--------------+--------------+",
-            "| SUM(test.c1) | SUM(test.c2) |",
-            "+--------------+--------------+",
-            "| 60           | 220          |",
-            "+--------------+--------------+",
+            "+---------+---------+",
+            "| SUM(c1) | SUM(c2) |",
+            "+---------+---------+",
+            "| 60      | 220     |",
+            "+---------+---------+",
         ];
         assert_batches_sorted_eq!(expected, &results);
@@ -1312,11 +1312,11 @@ mod tests {
         assert_eq!(results.len(), 1);

         let expected = vec![
-            "+--------------+--------------+",
-            "| SUM(test.c1) | SUM(test.c2) |",
-            "+--------------+--------------+",
-            "|              |              |",
-            "+--------------+--------------+",
+            "+---------+---------+",
+            "| SUM(c1) | SUM(c2) |",
+            "+---------+---------+",
+            "|         |         |",
+            "+---------+---------+",
         ];
         assert_batches_sorted_eq!(expected, &results);
@@ -1329,11 +1329,11 @@ mod tests {
         assert_eq!(results.len(), 1);

         let expected = vec![
-            "+--------------+--------------+",
-            "| AVG(test.c1) | AVG(test.c2) |",
-            "+--------------+--------------+",
-            "| 1.5          | 5.5          |",
-            "+--------------+--------------+",
+            "+---------+---------+",
+            "| AVG(c1) | AVG(c2) |",
+            "+---------+---------+",
+            "| 1.5     | 5.5     |",
+            "+---------+---------+",
         ];
         assert_batches_sorted_eq!(expected, &results);
@@ -1346,11 +1346,11 @@ mod tests {
         assert_eq!(results.len(), 1);

         let expected = vec![
-            "+--------------+--------------+",
-            "| MAX(test.c1) | MAX(test.c2) |",
-            "+--------------+--------------+",
-            "| 3            | 10           |",
"+--------------+--------------+", + "+---------+---------+", + "| MAX(c1) | MAX(c2) |", + "+---------+---------+", + "| 3 | 10 |", + "+---------+---------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1363,11 +1363,11 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+--------------+--------------+", - "| MIN(test.c1) | MIN(test.c2) |", - "+--------------+--------------+", - "| 0 | 1 |", - "+--------------+--------------+", + "+---------+---------+", + "| MIN(c1) | MIN(c2) |", + "+---------+---------+", + "| 0 | 1 |", + "+---------+---------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1379,14 +1379,14 @@ mod tests { let results = execute("SELECT c1, SUM(c2) FROM test GROUP BY c1", 4).await?; let expected = vec![ - "+----+--------------+", - "| c1 | SUM(test.c2) |", - "+----+--------------+", - "| 0 | 55 |", - "| 1 | 55 |", - "| 2 | 55 |", - "| 3 | 55 |", - "+----+--------------+", + "+----+---------+", + "| c1 | SUM(c2) |", + "+----+---------+", + "| 0 | 55 |", + "| 1 | 55 |", + "| 2 | 55 |", + "| 3 | 55 |", + "+----+---------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1398,14 +1398,14 @@ mod tests { let results = execute("SELECT c1, AVG(c2) FROM test GROUP BY c1", 4).await?; let expected = vec![ - "+----+--------------+", - "| c1 | AVG(test.c2) |", - "+----+--------------+", - "| 0 | 5.5 |", - "| 1 | 5.5 |", - "| 2 | 5.5 |", - "| 3 | 5.5 |", - "+----+--------------+", + "+----+---------+", + "| c1 | AVG(c2) |", + "+----+---------+", + "| 0 | 5.5 |", + "| 1 | 5.5 |", + "| 2 | 5.5 |", + "| 3 | 5.5 |", + "+----+---------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1449,14 +1449,14 @@ mod tests { let results = execute("SELECT c1, MAX(c2) FROM test GROUP BY c1", 4).await?; let expected = vec![ - "+----+--------------+", - "| c1 | MAX(test.c2) |", - "+----+--------------+", - "| 0 | 10 |", - "| 1 | 10 |", - "| 2 | 10 |", - "| 3 | 10 |", - "+----+--------------+", + "+----+---------+", + "| c1 | MAX(c2) |", + "+----+---------+", + "| 0 | 10 |", + "| 1 | 10 |", + "| 2 | 10 |", + "| 3 | 10 |", + "+----+---------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1468,14 +1468,14 @@ mod tests { let results = execute("SELECT c1, MIN(c2) FROM test GROUP BY c1", 4).await?; let expected = vec![ - "+----+--------------+", - "| c1 | MIN(test.c2) |", - "+----+--------------+", - "| 0 | 1 |", - "| 1 | 1 |", - "| 2 | 1 |", - "| 3 | 1 |", - "+----+--------------+", + "+----+---------+", + "| c1 | MIN(c2) |", + "+----+---------+", + "| 0 | 1 |", + "| 1 | 1 |", + "| 2 | 1 |", + "| 3 | 1 |", + "+----+---------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1516,11 +1516,11 @@ mod tests { .unwrap(); let expected = vec![ - "+----------------+-----------------+-----------------+---------------+", - "| COUNT(t.nanos) | COUNT(t.micros) | COUNT(t.millis) | COUNT(t.secs) |", - "+----------------+-----------------+-----------------+---------------+", - "| 3 | 3 | 3 | 3 |", - "+----------------+-----------------+-----------------+---------------+", + "+--------------+---------------+---------------+-------------+", + "| COUNT(nanos) | COUNT(micros) | COUNT(millis) | COUNT(secs) |", + "+--------------+---------------+---------------+-------------+", + "| 3 | 3 | 3 | 3 |", + "+--------------+---------------+---------------+-------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1543,7 +1543,7 @@ mod tests { let expected = vec![ 
"+----------------------------+----------------------------+-------------------------+---------------------+", - "| MIN(t.nanos) | MIN(t.micros) | MIN(t.millis) | MIN(t.secs) |", + "| MIN(nanos) | MIN(micros) | MIN(millis) | MIN(secs) |", "+----------------------------+----------------------------+-------------------------+---------------------+", "| 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123 | 2011-12-13 11:13:10 |", "+----------------------------+----------------------------+-------------------------+---------------------+", @@ -1569,7 +1569,7 @@ mod tests { let expected = vec![ "+-------------------------+-------------------------+-------------------------+---------------------+", - "| MAX(t.nanos) | MAX(t.micros) | MAX(t.millis) | MAX(t.secs) |", + "| MAX(nanos) | MAX(micros) | MAX(millis) | MAX(secs) |", "+-------------------------+-------------------------+-------------------------+---------------------+", "| 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10 |", "+-------------------------+-------------------------+-------------------------+---------------------+", @@ -1620,11 +1620,11 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+----------------+----------------+", - "| COUNT(test.c1) | COUNT(test.c2) |", - "+----------------+----------------+", - "| 10 | 10 |", - "+----------------+----------------+", + "+-----------+-----------+", + "| COUNT(c1) | COUNT(c2) |", + "+-----------+-----------+", + "| 10 | 10 |", + "+-----------+-----------+", ]; assert_batches_sorted_eq!(expected, &results); Ok(()) @@ -1636,11 +1636,11 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+----------------+----------------+", - "| COUNT(test.c1) | COUNT(test.c2) |", - "+----------------+----------------+", - "| 40 | 40 |", - "+----------------+----------------+", + "+-----------+-----------+", + "| COUNT(c1) | COUNT(c2) |", + "+-----------+-----------+", + "| 40 | 40 |", + "+-----------+-----------+", ]; assert_batches_sorted_eq!(expected, &results); Ok(()) @@ -1651,14 +1651,14 @@ mod tests { let results = execute("SELECT c1, COUNT(c2) FROM test GROUP BY c1", 4).await?; let expected = vec![ - "+----+----------------+", - "| c1 | COUNT(test.c2) |", - "+----+----------------+", - "| 0 | 10 |", - "| 1 | 10 |", - "| 2 | 10 |", - "| 3 | 10 |", - "+----+----------------+", + "+----+-----------+", + "| c1 | COUNT(c2) |", + "+----+-----------+", + "| 0 | 10 |", + "| 1 | 10 |", + "| 2 | 10 |", + "| 3 | 10 |", + "+----+-----------+", ]; assert_batches_sorted_eq!(expected, &results); Ok(()) @@ -1702,12 +1702,12 @@ mod tests { ).await?; let expected = vec![ - "+---------------------+--------------+", - "| week | SUM(test.c2) |", - "+---------------------+--------------+", - "| 2020-12-07 00:00:00 | 24 |", - "| 2020-12-14 00:00:00 | 156 |", - "+---------------------+--------------+", + "+---------------------+---------+", + "| week | SUM(c2) |", + "+---------------------+---------+", + "| 2020-12-07 00:00:00 | 24 |", + "| 2020-12-14 00:00:00 | 156 |", + "+---------------------+---------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1753,13 +1753,13 @@ mod tests { .expect("ran plan correctly"); let expected = vec![ - "+-----+--------------+", - "| str | COUNT(t.val) |", - "+-----+--------------+", - "| A | 4 |", - "| B | 1 |", - "| C | 1 |", - "+-----+--------------+", + "+-----+------------+", + "| str | COUNT(val) |", + "+-----+------------+", + "| A | 4 |", + "| B | 1 |", + "| 
C | 1 |", + "+-----+------------+", ]; assert_batches_sorted_eq!(expected, &results); } @@ -1804,13 +1804,13 @@ mod tests { .expect("ran plan correctly"); let expected = vec![ - "+------+--------------+", - "| dict | COUNT(t.val) |", - "+------+--------------+", - "| A | 4 |", - "| B | 1 |", - "| C | 1 |", - "+------+--------------+", + "+------+------------+", + "| dict | COUNT(val) |", + "+------+------------+", + "| A | 4 |", + "| B | 1 |", + "| C | 1 |", + "+------+------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1821,13 +1821,13 @@ mod tests { .expect("ran plan correctly"); let expected = vec![ - "+-----+---------------+", - "| val | COUNT(t.dict) |", - "+-----+---------------+", - "| 1 | 3 |", - "| 2 | 2 |", - "| 4 | 1 |", - "+-----+---------------+", + "+-----+-------------+", + "| val | COUNT(dict) |", + "+-----+-------------+", + "| 1 | 3 |", + "| 2 | 2 |", + "| 4 | 1 |", + "+-----+-------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1840,13 +1840,13 @@ mod tests { .expect("ran plan correctly"); let expected = vec![ - "+-----+------------------------+", - "| val | COUNT(DISTINCT t.dict) |", - "+-----+------------------------+", - "| 1 | 2 |", - "| 2 | 2 |", - "| 4 | 1 |", - "+-----+------------------------+", + "+-----+----------------------+", + "| val | COUNT(DISTINCT dict) |", + "+-----+----------------------+", + "| 1 | 2 |", + "| 2 | 2 |", + "| 4 | 1 |", + "+-----+----------------------+", ]; assert_batches_sorted_eq!(expected, &results); } @@ -1945,13 +1945,13 @@ mod tests { let results = run_count_distinct_integers_aggregated_scenario(partitions).await?; let expected = vec![ - "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", - "| c_group | COUNT(test.c_uint64) | COUNT(DISTINCT test.c_int8) | COUNT(DISTINCT test.c_int16) | COUNT(DISTINCT test.c_int32) | COUNT(DISTINCT test.c_int64) | COUNT(DISTINCT test.c_uint8) | COUNT(DISTINCT test.c_uint16) | COUNT(DISTINCT test.c_uint32) | COUNT(DISTINCT test.c_uint64) |", - "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", - "| a | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |", - "| b | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |", - "| c | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |", - "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", + "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", + "| c_group | COUNT(c_uint64) | COUNT(DISTINCT c_int8) | COUNT(DISTINCT c_int16) | COUNT(DISTINCT c_int32) | COUNT(DISTINCT c_int64) | COUNT(DISTINCT c_uint8) | COUNT(DISTINCT c_uint16) | COUNT(DISTINCT c_uint32) | COUNT(DISTINCT c_uint64) |", + 
"+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", + "| a | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |", + "| b | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |", + "| c | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |", + "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -1971,13 +1971,13 @@ mod tests { let results = run_count_distinct_integers_aggregated_scenario(partitions).await?; let expected = vec![ - "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", - "| c_group | COUNT(test.c_uint64) | COUNT(DISTINCT test.c_int8) | COUNT(DISTINCT test.c_int16) | COUNT(DISTINCT test.c_int32) | COUNT(DISTINCT test.c_int64) | COUNT(DISTINCT test.c_uint8) | COUNT(DISTINCT test.c_uint16) | COUNT(DISTINCT test.c_uint32) | COUNT(DISTINCT test.c_uint64) |", - "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", - "| a | 5 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 |", - "| b | 5 | 4 | 4 | 4 | 4 | 4 | 4 | 4 | 4 |", - "| c | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |", - "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", + "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", + "| c_group | COUNT(c_uint64) | COUNT(DISTINCT c_int8) | COUNT(DISTINCT c_int16) | COUNT(DISTINCT c_int32) | COUNT(DISTINCT c_int64) | COUNT(DISTINCT c_uint8) | COUNT(DISTINCT c_uint16) | COUNT(DISTINCT c_uint32) | COUNT(DISTINCT c_uint64) |", + "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", + "| a | 5 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 |", + "| b | 5 | 4 | 4 | 4 | 4 | 4 | 4 | 4 | 4 |", + "| c | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |", + "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", ]; assert_batches_sorted_eq!(expected, &results); @@ -2172,11 +2172,11 @@ mod tests { .unwrap(); let expected = vec![ - "+----------+", - "| MAX(t.i) |", - "+----------+", - "| 1 |", - "+----------+", + "+--------+", + "| MAX(i) |", + "+--------+", + "| 1 |", + "+--------+", ]; let results = plan_and_collect(&mut 
ctx, "SELECT max(i) FROM t") @@ -2235,11 +2235,11 @@ mod tests { let result = plan_and_collect(&mut ctx, "SELECT \"MY_AVG\"(i) FROM t").await?; let expected = vec![ - "+-------------+", - "| MY_AVG(t.i) |", - "+-------------+", - "| 1 |", - "+-------------+", + "+-----------+", + "| MY_AVG(i) |", + "+-----------+", + "| 1 |", + "+-----------+", ]; assert_batches_eq!(expected, &result); @@ -2335,11 +2335,11 @@ mod tests { assert_eq!(results.len(), 1); let expected = vec![ - "+--------------+--------------+-----------------+", - "| SUM(test.c1) | SUM(test.c2) | COUNT(UInt8(1)) |", - "+--------------+--------------+-----------------+", - "| 10 | 110 | 20 |", - "+--------------+--------------+-----------------+", + "+---------+---------+-----------------+", + "| SUM(c1) | SUM(c2) | COUNT(UInt8(1)) |", + "+---------+---------+-----------------+", + "| 10 | 110 | 20 |", + "+---------+---------+-----------------+", ]; assert_batches_eq!(expected, &results); @@ -2564,11 +2564,11 @@ mod tests { let result = plan_and_collect(&mut ctx, "SELECT MY_AVG(a) FROM t").await?; let expected = vec![ - "+-------------+", - "| my_avg(t.a) |", - "+-------------+", - "| 3 |", - "+-------------+", + "+-----------+", + "| my_avg(a) |", + "+-----------+", + "| 3 |", + "+-----------+", ]; assert_batches_eq!(expected, &result); diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index 56f01eb6a3ba..e4501a78ada4 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -95,11 +95,11 @@ //! let pretty_results = arrow::util::pretty::pretty_format_batches(&results)?; //! //! let expected = vec![ -//! "+---+----------------+", -//! "| a | MIN(example.b) |", -//! "+---+----------------+", -//! "| 1 | 2 |", -//! "+---+----------------+" +//! "+---+--------+", +//! "| a | MIN(b) |", +//! "+---+--------+", +//! "| 1 | 2 |", +//! "+---+--------+" //! ]; //! //! assert_eq!(pretty_results.trim().lines().collect::>(), expected); diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index d80aeb912be8..59e49175f6ba 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -195,15 +195,17 @@ impl LogicalPlanBuilder { let all_schemas = self.plan.all_schemas(); let mut projected_expr = vec![]; for e in expr { - let normalized_e = normalize_col(e, &all_schemas)?; - match normalized_e { + match e { Expr::Wildcard => { (0..input_schema.fields().len()).for_each(|i| { projected_expr .push(Expr::Column(input_schema.field(i).qualified_column())) }); } - _ => projected_expr.push(columnize_expr(normalized_e, input_schema)), + _ => projected_expr.push(columnize_expr( + normalize_col(e, &all_schemas)?, + input_schema, + )), } } diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 11161da2facf..58c6f0a57306 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -67,6 +67,7 @@ fn create_function_physical_name( .iter() .map(|e| physical_name(e, input_schema)) .collect::>()?; + let distinct_str = match distinct { true => "DISTINCT ", false => "", @@ -133,6 +134,9 @@ fn physical_name(e: &Expr, input_schema: &DFSchema) -> Result { Expr::ScalarUDF { fun, args, .. } => { create_function_physical_name(&fun.name, false, args, input_schema) } + Expr::WindowFunction { fun, args, .. 
+            create_function_physical_name(&fun.to_string(), false, args, input_schema)
+        }
         Expr::AggregateFunction {
             fun,
             distinct,
@@ -162,7 +166,7 @@ fn physical_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
             }
         }
         other => Err(DataFusionError::NotImplemented(format!(
-            "Physical plan does not support logical expression {:?}",
+            "Cannot derive physical field name for logical expression {:?}",
             other
         ))),
     }
@@ -382,9 +386,38 @@ impl DefaultPhysicalPlanner {
             LogicalPlan::Projection { input, expr, .. } => {
                 let input_exec = self.create_initial_plan(input, ctx_state)?;
                 let input_schema = input.as_ref().schema();
-                let runtime_expr = expr
+
+                let physical_exprs = expr
                     .iter()
                     .map(|e| {
+                        // For projections, the SQL planner and logical plan builder may
+                        // convert user-provided expressions into logical Column expressions
+                        // if their results are already provided by the input plans. Because
+                        // we work with qualified columns in the logical plane, derived
+                        // columns involving operators or functions will contain qualifiers
+                        // as well. This will result in logical columns with names like
+                        // `SUM(t1.c1)`, `t1.c1 + t1.c2`, etc.
+                        //
+                        // If we run these logical columns through the physical_name
+                        // function, we will get physical names with column qualifiers,
+                        // which violates DataFusion's field name semantics. To account
+                        // for this, we need to derive the physical name from the
+                        // physical input instead.
+                        //
+                        // This depends on the invariant that logical schema field index MUST match
+                        // with physical schema field index.
+                        let physical_name = if let Expr::Column(col) = e {
+                            match input_schema.index_of_column(&col) {
+                                Ok(idx) => {
+                                    // index physical field using logical field index
+                                    Ok(input_exec.schema().field(idx).name().to_string())
+                                }
+                                // logical column is not a derived column, safe to pass along to
+                                // physical_name
+                                Err(_) => physical_name(e, &input_schema),
+                            }
+                        } else {
+                            physical_name(e, &input_schema)
+                        };
+
                         tuple_err((
                             self.create_physical_expr(
                                 e,
@@ -392,11 +425,15 @@ impl DefaultPhysicalPlanner {
                                 input_schema,
                                 &ctx_state,
                             ),
-                            physical_name(e, &input_schema),
+                            physical_name,
                         ))
                     })
                     .collect::<Result<Vec<_>>>()?;
-                Ok(Arc::new(ProjectionExec::try_new(runtime_expr, input_exec)?))
+
+                Ok(Arc::new(ProjectionExec::try_new(
+                    physical_exprs,
+                    input_exec,
+                )?))
             }
             LogicalPlan::Filter {
                 input, predicate, ..
@@ -946,7 +983,7 @@ impl DefaultPhysicalPlanner {
                 // unpack aliased logical expressions, e.g. "sum(col) over () as total"
"sum(col) over () as total" let (name, e) = match e { Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), - _ => (e.name(logical_input_schema)?, e), + _ => (physical_name(e, logical_input_schema)?, e), }; match e { diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 6f99fafb2179..98c732441448 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -3371,7 +3371,7 @@ async fn test_physical_plan_display_indent() { "GlobalLimitExec: limit=10", " SortExec: [the_min@2 DESC]", " MergeExec", - " ProjectionExec: expr=[c1@0 as c1, MAX(aggregate_test_100.c12)@1 as MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)@2 as the_min]", + " ProjectionExec: expr=[c1@0 as c1, MAX(aggregate_test_100.c12)@1 as MAX(c12), MIN(aggregate_test_100.c12)@2 as the_min]", " HashAggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[MAX(c12), MIN(c12)]", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 3)", From 071f86b7612291137bbd9d3c00c32228334f12ba Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sat, 12 Jun 2021 14:16:48 -0700 Subject: [PATCH 17/25] add more comment --- datafusion/src/logical_plan/expr.rs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index f568a3100225..eae542f92cf1 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1081,7 +1081,21 @@ pub fn col(ident: &str) -> Expr { Expr::Column(ident.into()) } -/// Convert an expression into Column expression if it's already provided as input +/// Convert an expression into Column expression if it's already provided as input plan. +/// +/// For example, it rewrites: +/// +/// ```ignore +/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])? +/// .project(vec![col("c1"), sum(col("c2"))? +/// ``` +/// +/// Into: +/// +/// ```ignore +/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])? +/// .project(vec![col("c1"), col("SUM(#c2)")? +/// ``` pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr { match e { Expr::Column(_) => e, From 80a51688de3eeba87d4636eab243230e1272f242 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sat, 12 Jun 2021 14:38:28 -0700 Subject: [PATCH 18/25] enable more queries in benchmark/run.sh --- benchmarks/run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/run.sh b/benchmarks/run.sh index 8e36424da89f..21633d39c23a 100755 --- a/benchmarks/run.sh +++ b/benchmarks/run.sh @@ -20,7 +20,7 @@ set -e # This bash script is meant to be run inside the docker-compose environment. 
 cd /
-for query in 1 3 5 6 10 12
+for query in 1 3 5 6 7 8 9 10 12
 do
   /tpch benchmark ballista --host ballista-scheduler --port 50050 --query $query --path /data --format tbl --iterations 1 --debug
 done

From 713fbe1b74321f83d478bbc0ad86450428b55eaa Mon Sep 17 00:00:00 2001
From: Qingping Hou
Date: Sat, 12 Jun 2021 16:49:27 -0700
Subject: [PATCH 19/25] use unzip to avoid unnecessary iterators

---
 datafusion/src/sql/planner.rs | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index 6d49e8c3000b..e2b89aa0e86b 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -370,12 +370,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 // extract join keys
                 extract_join_keys(&expr, &mut keys)?;

-                // TODO: avoid two iterations
-                let left_keys: Vec<Column> =
-                    keys.iter().map(|pair| pair.0.clone()).collect();
-                let right_keys: Vec<Column> =
-                    keys.iter().map(|pair| pair.1.clone()).collect();
+                let (left_keys, right_keys): (Vec<Column>, Vec<Column>) =
+                    keys.into_iter().unzip();

                 // return the logical plan representing the join
                 LogicalPlanBuilder::from(&left)
                     .join(&right, join_type, left_keys, right_keys)?
@@ -483,12 +480,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     if left_schema.field_from_qualified_column(l).is_ok()
                         && right_schema.field_from_qualified_column(r).is_ok()
                     {
-                        // TODO: avoid clone here
                         join_keys.push((l.clone(), r.clone()));
                     } else if left_schema.field_from_qualified_column(r).is_ok()
                         && right_schema.field_from_qualified_column(l).is_ok()
                     {
-                        // TODO: avoid clone here
                         join_keys.push((r.clone(), l.clone()));
                     }
                 }
@@ -497,10 +492,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     LogicalPlanBuilder::from(&left).cross_join(right)?.build()?;
             } else {
                 let left_keys: Vec<Column> =
-                    // TODO: avoid clone here
                     join_keys.iter().map(|(l, _)| l.clone()).collect();
                 let right_keys: Vec<Column> =
-                    // TODO: avoid clone here
                     join_keys.iter().map(|(_, r)| r.clone()).collect();
                 let builder = LogicalPlanBuilder::from(&left);
                 left = builder
@@ -1542,7 +1535,6 @@ fn remove_join_expressions(
     match expr {
         Expr::BinaryExpr { left, op, right } => match op {
             Operator::Eq => match (left.as_ref(), right.as_ref()) {
-                // TODO: avoid clones
                 (Expr::Column(l), Expr::Column(r)) => {
                     if join_columns.contains(&(l.clone(), r.clone()))
                         || join_columns.contains(&(r.clone(), l.clone()))

From e4677b99c570741d2a45b2fe9ec9ef5ceca5140c Mon Sep 17 00:00:00 2001
From: Qingping Hou
Date: Sat, 12 Jun 2021 17:08:08 -0700
Subject: [PATCH 20/25] reduce diff by discarding style related changes

---
 datafusion/src/optimizer/utils.rs            | 12 ++++---
 .../physical_optimizer/coalesce_batches.rs   |  1 -
 datafusion/src/physical_optimizer/pruning.rs | 32 ++++++-------------
 3 files changed, 17 insertions(+), 28 deletions(-)

diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs
index ff6a9acbc499..c9447e526981 100644
--- a/datafusion/src/optimizer/utils.rs
+++ b/datafusion/src/optimizer/utils.rs
@@ -305,8 +305,10 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
         }
         Expr::Cast { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
         Expr::TryCast { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
-        Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(_) => Ok(vec![]),
+        Expr::Column(_) => Ok(vec![]),
         Expr::Alias(expr, ..) => Ok(vec![expr.as_ref().to_owned()]),
+        Expr::Literal(_) => Ok(vec![]),
+        Expr::ScalarVariable(_) => Ok(vec![]),
         Expr::Not(expr) => Ok(vec![expr.as_ref().to_owned()]),
         Expr::Negative(expr) => Ok(vec![expr.as_ref().to_owned()]),
         Expr::Sort { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
@@ -451,6 +453,9 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
         }
         Expr::Not(_) => Ok(Expr::Not(Box::new(expressions[0].clone()))),
         Expr::Negative(_) => Ok(Expr::Negative(Box::new(expressions[0].clone()))),
+        Expr::Column(_) => Ok(expr.clone()),
+        Expr::Literal(_) => Ok(expr.clone()),
+        Expr::ScalarVariable(_) => Ok(expr.clone()),
         Expr::Sort {
             asc, nulls_first, ..
         } => Ok(Expr::Sort {
@@ -479,13 +484,10 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
                 Ok(expr)
             }
         }
+        Expr::InList { .. } => Ok(expr.clone()),
         Expr::Wildcard { .. } => Err(DataFusionError::Internal(
             "Wildcard expressions are not valid in a logical query plan".to_owned(),
         )),
-        Expr::InList { .. }
-        | Expr::Column(_)
-        | Expr::Literal(_)
-        | Expr::ScalarVariable(_) => Ok(expr.clone()),
     }
 }
diff --git a/datafusion/src/physical_optimizer/coalesce_batches.rs b/datafusion/src/physical_optimizer/coalesce_batches.rs
index 9adee1ce2f2c..9af8911062df 100644
--- a/datafusion/src/physical_optimizer/coalesce_batches.rs
+++ b/datafusion/src/physical_optimizer/coalesce_batches.rs
@@ -37,7 +37,6 @@ impl CoalesceBatches {
         Self {}
     }
 }
-
 impl PhysicalOptimizerRule for CoalesceBatches {
     fn optimize(
         &self,
diff --git a/datafusion/src/physical_optimizer/pruning.rs b/datafusion/src/physical_optimizer/pruning.rs
index d5d21e225591..d792dcfbf0cd 100644
--- a/datafusion/src/physical_optimizer/pruning.rs
+++ b/datafusion/src/physical_optimizer/pruning.rs
@@ -734,25 +734,25 @@ mod tests {
         let required_columns = RequiredStatColumns::from(vec![
             // min of original column s1, named s1_min
             (
-                Column::from_name("s1".to_string()),
+                "s1".into(),
                 StatisticsType::Min,
                 Field::new("s1_min", DataType::Int32, true),
             ),
             // max of original column s2, named s2_max
             (
-                Column::from_name("s2".to_string()),
+                "s2".into(),
                 StatisticsType::Max,
                 Field::new("s2_max", DataType::Int32, true),
             ),
             // max of original column s3, named s3_max
             (
-                Column::from_name("s3".to_string()),
+                "s3".into(),
                 StatisticsType::Max,
                 Field::new("s3_max", DataType::Utf8, true),
             ),
             // min of original column s3, named s3_min
             (
-                Column::from_name("s3".to_string()),
+                "s3".into(),
                 StatisticsType::Min,
                 Field::new("s3_min", DataType::Utf8, true),
             ),
@@ -804,7 +804,7 @@ mod tests {
         // Request a record batch with s1_min as a timestamp
         let required_columns = RequiredStatColumns::from(vec![(
-            Column::from_name("s1".to_string()),
+            "s3".into(),
             StatisticsType::Min,
             Field::new(
                 "s1_min",
@@ -858,7 +858,7 @@ mod tests {
         // Request a record batch with s1_min as a timestamp
         let required_columns = RequiredStatColumns::from(vec![(
-            Column::from_name("s1".to_string()),
+            "s3".into(),
             StatisticsType::Min,
             Field::new("s1_min", DataType::Utf8, true),
         )]);
@@ -887,7 +887,7 @@ mod tests {
     fn test_build_statistics_inconsistent_length() {
         // return an inconsistent length to the actual statistics arrays
         let required_columns = RequiredStatColumns::from(vec![(
-            Column::from_name("s1".to_string()),
+            "s1".into(),
             StatisticsType::Min,
             Field::new("s1_min", DataType::Int64, true),
         )]);
@@ -1114,30 +1114,18 @@
         let c1_min_field = Field::new("c1_min", DataType::Int32, false);
         assert_eq!(
             required_columns.columns[0],
-            (
-                Column::from_name("c1".to_string()),
-                StatisticsType::Min,
-                c1_min_field
-            )
+            ("c1".into(), StatisticsType::Min, c1_min_field)
         );
         // c2 = 2 should add c2_min and c2_max
         let c2_min_field = Field::new("c2_min", DataType::Int32, false);
         assert_eq!(
             required_columns.columns[1],
-            (
-                Column::from_name("c2".to_string()),
-                StatisticsType::Min,
-                c2_min_field
-            )
+            ("c2".into(), StatisticsType::Min, c2_min_field)
         );
         let c2_max_field = Field::new("c2_max", DataType::Int32, false);
         assert_eq!(
             required_columns.columns[2],
-            (
-                Column::from_name("c2".to_string()),
-                StatisticsType::Max,
-                c2_max_field
-            )
+            ("c2".into(), StatisticsType::Max, c2_max_field)
         );
         // c2 = 3 shouldn't add any new statistics fields
         assert_eq!(required_columns.columns.len(), 3);

From 6f6ecdfc584602d691ce7b6e54c956133a91ce0b Mon Sep 17 00:00:00 2001
From: Qingping Hou
Date: Sat, 12 Jun 2021 17:49:31 -0700
Subject: [PATCH 21/25] simplify hash_join tests

---
 .../src/physical_plan/expressions/column.rs |  7 +-
 datafusion/src/physical_plan/hash_join.rs   | 77 ++++++++++---------
 2 files changed, 45 insertions(+), 39 deletions(-)

diff --git a/datafusion/src/physical_plan/expressions/column.rs b/datafusion/src/physical_plan/expressions/column.rs
index 0fd2f3e903cf..d6eafbb05384 100644
--- a/datafusion/src/physical_plan/expressions/column.rs
+++ b/datafusion/src/physical_plan/expressions/column.rs
@@ -43,6 +43,11 @@ impl Column {
         }
     }

+    /// Create a new column expression based on column name and schema
+    pub fn new_with_schema(name: &str, schema: &Schema) -> Result<Self> {
+        Ok(Column::new(name, schema.index_of(name)?))
+    }
+
     /// Get the column name
     pub fn name(&self) -> &str {
         &self.name
     }
@@ -84,5 +89,5 @@ impl PhysicalExpr for Column {
 /// Create a column expression
 pub fn col(name: &str, schema: &Schema) -> Result<Arc<dyn PhysicalExpr>> {
-    Ok(Arc::new(Column::new(name, schema.index_of(name)?)))
+    Ok(Arc::new(Column::new_with_schema(name, schema)?))
 }
diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index 53ade5f4fe4f..3b64d83be22f 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -1387,9 +1387,10 @@ mod tests {
             ("b1", &vec![4, 5, 6]),
             ("c2", &vec![70, 80, 90]),
         );
+
         let on = vec![(
-            Column::new("b1", left.schema().index_of("b1")?),
-            Column::new("b1", right.schema().index_of("b1")?),
+            Column::new_with_schema("b1", &left.schema())?,
+            Column::new_with_schema("b1", &right.schema())?,
         )];

         let (columns, batches) =
@@ -1425,8 +1426,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new("b1", left.schema().index_of("b1")?),
-            Column::new("b1", right.schema().index_of("b1")?),
+            Column::new_with_schema("b1", &left.schema())?,
+            Column::new_with_schema("b1", &right.schema())?,
         )];
         let (columns, batches) = partitioned_join_collect(
@@ -1466,8 +1467,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new("b1", left.schema().index_of("b1")?),
-            Column::new("b2", right.schema().index_of("b2")?),
+            Column::new_with_schema("b1", &left.schema())?,
+            Column::new_with_schema("b2", &right.schema())?,
         )];

         let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?;
@@ -1503,12 +1504,12 @@ mod tests {
         );
         let on = vec![
             (
-                Column::new("a1", left.schema().index_of("a1")?),
-                Column::new("a1", right.schema().index_of("a1")?),
+                Column::new_with_schema("a1", &left.schema())?,
+                Column::new_with_schema("a1", &right.schema())?,
             ),
             (
-                Column::new("b2", left.schema().index_of("b2")?),
-                Column::new("b2", right.schema().index_of("b2")?),
+
Column::new_with_schema("b2", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, ), ]; @@ -1555,12 +1556,12 @@ mod tests { ); let on = vec![ ( - Column::new("a1", left.schema().index_of("a1")?), - Column::new("a1", right.schema().index_of("a1")?), + Column::new_with_schema("a1", &left.schema())?, + Column::new_with_schema("a1", &right.schema())?, ), ( - Column::new("b2", left.schema().index_of("b2")?), - Column::new("b2", right.schema().index_of("b2")?), + Column::new_with_schema("b2", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, ), ]; @@ -1607,8 +1608,8 @@ mod tests { ); let on = vec![( - Column::new("b1", left.schema().index_of("b1")?), - Column::new("b1", right.schema().index_of("b1")?), + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, )]; let join = join(left, right, on, &JoinType::Inner)?; @@ -1673,8 +1674,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new("b1", left.schema().index_of("b1").unwrap()), - Column::new("b1", right.schema().index_of("b1").unwrap()), + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), )]; let join = join(left, right, on, &JoinType::Left).unwrap(); @@ -1714,8 +1715,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new("b1", left.schema().index_of("b1").unwrap()), - Column::new("b2", right.schema().index_of("b2").unwrap()), + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), )]; let join = join(left, right, on, &JoinType::Full).unwrap(); @@ -1752,8 +1753,8 @@ mod tests { ); let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![])); let on = vec![( - Column::new("b1", left.schema().index_of("b1").unwrap()), - Column::new("b1", right.schema().index_of("b1").unwrap()), + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), )]; let schema = right.schema(); let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); @@ -1787,8 +1788,8 @@ mod tests { ); let right = build_table_i32(("a2", &vec![]), ("b2", &vec![]), ("c2", &vec![])); let on = vec![( - Column::new("b1", left.schema().index_of("b1").unwrap()), - Column::new("b2", right.schema().index_of("b2").unwrap()), + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), )]; let schema = right.schema(); let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); @@ -1826,8 +1827,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new("b1", left.schema().index_of("b1")?), - Column::new("b1", right.schema().index_of("b1")?), + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, )]; let (columns, batches) = @@ -1862,8 +1863,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new("b1", left.schema().index_of("b1")?), - Column::new("b1", right.schema().index_of("b1")?), + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, )]; let (columns, batches) = partitioned_join_collect( @@ -1902,8 +1903,8 @@ mod tests { ("c2", &vec![70, 80, 90, 100]), ); let on = vec![( - Column::new("b1", left.schema().index_of("b1")?), - Column::new("b1", right.schema().index_of("b1")?), + Column::new_with_schema("b1", &left.schema())?, + 
Column::new_with_schema("b1", &right.schema())?, )]; let join = join(left, right, on, &JoinType::Semi)?; @@ -1941,8 +1942,8 @@ mod tests { ("c2", &vec![70, 80, 90, 100]), ); let on = vec![( - Column::new("b1", left.schema().index_of("b1")?), - Column::new("b1", right.schema().index_of("b1")?), + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, )]; let join = join(left, right, on, &JoinType::Anti)?; @@ -1978,8 +1979,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new("b1", left.schema().index_of("b1")?), - Column::new("b1", right.schema().index_of("b1")?), + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, )]; let (columns, batches) = join_collect(left, right, on, &JoinType::Right).await?; @@ -2014,8 +2015,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new("b1", left.schema().index_of("b1")?), - Column::new("b1", right.schema().index_of("b1")?), + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, )]; let (columns, batches) = @@ -2051,8 +2052,8 @@ mod tests { ("c2", &vec![70, 80, 90]), ); let on = vec![( - Column::new("b1", left.schema().index_of("b1").unwrap()), - Column::new("b2", right.schema().index_of("b2").unwrap()), + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), )]; let join = join(left, right, on, &JoinType::Full)?; From 16436170d6ae74276755b81f3ac399d78f6fb319 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sat, 12 Jun 2021 18:39:15 -0700 Subject: [PATCH 22/25] reduce diff for easier revuew --- datafusion/src/logical_plan/builder.rs | 19 ++++++++++--------- datafusion/src/logical_plan/dfschema.rs | 2 +- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 59e49175f6ba..4b4ed0fb9d41 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -160,17 +160,18 @@ impl LogicalPlanBuilder { let projected_schema = projection .as_ref() - .map(|p| DFSchema { - fields: p - .iter() - .map(|i| { - DFField::from_qualified(table_name, schema.field(*i).clone()) - }) - .collect(), + .map(|p| { + DFSchema::new( + p.iter() + .map(|i| { + DFField::from_qualified(table_name, schema.field(*i).clone()) + }) + .collect(), + ) }) .unwrap_or_else(|| { - DFSchema::try_from_qualified_schema(table_name, &schema).unwrap() - }); + DFSchema::try_from_qualified_schema(table_name, &schema) + })?; let table_scan = LogicalPlan::TableScan { table_name: table_name.to_string(), diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs index 1d85e0ef0a4e..ce1424365ab3 100644 --- a/datafusion/src/logical_plan/dfschema.rs +++ b/datafusion/src/logical_plan/dfschema.rs @@ -35,7 +35,7 @@ pub type DFSchemaRef = Arc; #[derive(Debug, Clone, PartialEq, Eq)] pub struct DFSchema { /// Fields - pub(crate) fields: Vec, + fields: Vec, } impl DFSchema { From cad0d5ec5515e48e6a3da6731c08947622e5494f Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sun, 13 Jun 2021 22:57:20 -0700 Subject: [PATCH 23/25] fix unnecessary reference clippy error --- datafusion/src/execution/context.rs | 2 +- .../src/optimizer/projection_push_down.rs | 6 +++--- datafusion/src/physical_plan/hash_join.rs | 2 +- datafusion/src/physical_plan/planner.rs | 18 +++++++++--------- datafusion/src/sql/planner.rs | 4 ++-- 5 files 
diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index 8f71c9d1476c..76bdea744a19 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -404,7 +404,7 @@ impl ExecutionContext {
                 match schema.table(table_ref.table()) {
                     Some(ref provider) => {
                         let plan = LogicalPlanBuilder::scan(
-                            &table_ref.table(),
+                            table_ref.table(),
                             Arc::clone(provider),
                             None,
                         )?
diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs
index 58e14e5d0398..17186972d7db 100644
--- a/datafusion/src/optimizer/projection_push_down.rs
+++ b/datafusion/src/optimizer/projection_push_down.rs
@@ -208,8 +208,8 @@ fn optimize_plan(
             )?);

             let schema = build_join_schema(
-                &optimized_left.schema(),
-                &optimized_right.schema(),
+                optimized_left.schema(),
+                optimized_right.schema(),
                 on,
                 join_type,
                 join_constraint,
@@ -336,7 +336,7 @@ fn optimize_plan(
             ..
         } => {
             let (projection, projected_schema) = get_projected_schema(
-                Some(&table_name),
+                Some(table_name),
                 &source.schema(),
                 required_columns,
                 has_projection,
diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index 116969411218..4c4cb59ec27e 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -134,7 +134,7 @@ impl HashJoinExec {
     ) -> Result<Self> {
         let left_schema = left.schema();
         let right_schema = right.schema();
-        check_join_is_valid(&left_schema, &right_schema, on)?;
+        check_join_is_valid(&left_schema, &right_schema, &on)?;
         let schema = Arc::new(build_join_schema(
             &left_schema,
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 16cb0c3287ff..775ec14f6301 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -300,7 +300,7 @@ impl DefaultPhysicalPlanner {
                         tuple_err((
                             self.create_physical_expr(
                                 e,
-                                &logical_input_schema,
+                                logical_input_schema,
                                 &physical_input_schema,
                                 ctx_state,
                             ),
@@ -398,17 +398,17 @@ impl DefaultPhysicalPlanner {
                         // This depends on the invariant that logical schema field index MUST match
                        // with physical schema field index.
                         let physical_name = if let Expr::Column(col) = e {
-                            match input_schema.index_of_column(&col) {
+                            match input_schema.index_of_column(col) {
                                 Ok(idx) => {
                                     // index physical field using logical field index
                                     Ok(input_exec.schema().field(idx).name().to_string())
                                 }
                                 // logical column is not a derived column, safe to pass along to
                                 // physical_name
-                                Err(_) => physical_name(e, &input_schema),
+                                Err(_) => physical_name(e, input_schema),
                             }
                         } else {
-                            physical_name(e, &input_schema)
+                            physical_name(e, input_schema)
                         };

                         tuple_err((
@@ -466,7 +466,7 @@ impl DefaultPhysicalPlanner {
                     .map(|e| {
                         self.create_physical_expr(
                             e,
-                            &input_dfschema,
+                            input_dfschema,
                             &input_schema,
                             ctx_state,
                         )
@@ -494,7 +494,7 @@ impl DefaultPhysicalPlanner {
                         nulls_first,
                     } => self.create_physical_sort_expr(
                         expr,
-                        &input_dfschema,
+                        input_dfschema,
                         &input_schema,
                         SortOptions {
                             descending: !*asc,
@@ -533,8 +533,8 @@ impl DefaultPhysicalPlanner {
                     .iter()
                     .map(|(l, r)| {
                         Ok((
-                            Column::new(&l.name, left_df_schema.index_of_column(&l)?),
-                            Column::new(&r.name, right_df_schema.index_of_column(&r)?),
+                            Column::new(&l.name, left_df_schema.index_of_column(l)?),
+                            Column::new(&r.name, right_df_schema.index_of_column(r)?),
                         ))
                     })
                     .collect::<Result<_>>()?;
@@ -691,7 +691,7 @@ impl DefaultPhysicalPlanner {
             )?),
             Expr::Column(c) => {
                 let idx = input_dfschema.index_of_column(c)?;
-                Ok(Arc::new(Column::new(c.name, idx)))
+                Ok(Arc::new(Column::new(&c.name, idx)))
             }
             Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))),
             Expr::ScalarVariable(variable_names) => {
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index 79fdb81e623e..ee124f93de7e 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -831,8 +831,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             .try_for_each(|col| match col {
                 Expr::Column(col) => {
                     match &col.relation {
-                        Some(r) => schema.field_with_qualified_name(r, col.name),
-                        None => schema.field_with_unqualified_name(col.name),
+                        Some(r) => schema.field_with_qualified_name(r, &col.name),
+                        None => schema.field_with_unqualified_name(&col.name),
                     }
                     .map_err(|_| {
                         DataFusionError::Plan(format!(

From 593785aaa891f0afc2ae030130db24aff03a9458 Mon Sep 17 00:00:00 2001
From: Qingping Hou
Date: Mon, 14 Jun 2021 23:18:42 -0700
Subject: [PATCH 24/25] incorporate code review feedback

---
 datafusion/src/logical_plan/dfschema.rs | 17 +----------------
 datafusion/src/logical_plan/expr.rs     |  2 +-
 2 files changed, 2 insertions(+), 17 deletions(-)

diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs
index 2d442bf58c25..c993e64b1a42 100644
--- a/datafusion/src/logical_plan/dfschema.rs
+++ b/datafusion/src/logical_plan/dfschema.rs
@@ -300,22 +300,7 @@ impl Into<Schema> for DFSchema {
 impl Into<Schema> for &DFSchema {
     /// Convert a DFSchema into a Schema
     fn into(self) -> Schema {
-        Schema::new(
-            self.fields
-                .iter()
-                .map(|f| {
-                    if f.qualifier().is_some() {
-                        Field::new(
-                            f.name().as_str(),
-                            f.data_type().to_owned(),
-                            f.is_nullable(),
-                        )
-                    } else {
-                        f.field.clone()
-                    }
-                })
-                .collect(),
-        )
+        Schema::new(self.fields.iter().map(|f| f.field.clone()).collect())
     }
 }
diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index eae542f92cf1..1c5cc770c94f 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -33,7 +33,7 @@ use std::collections::HashSet;
 use std::fmt;
 use std::sync::Arc;

-/// A named reference to a qualified filed in a schema.
+/// A named reference to a qualified field in a schema. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Column { /// relation/table name. From d26b54c4036117f3c5f2841008d1c0968641c5e1 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sun, 20 Jun 2021 14:20:30 -0700 Subject: [PATCH 25/25] fix window schema handling in projection pushdown optimizer --- .../src/optimizer/projection_push_down.rs | 32 +++++++------------ datafusion/src/physical_plan/planner.rs | 6 ++-- datafusion/src/physical_plan/windows.rs | 2 +- integration-tests/test_psql_parity.py | 2 +- 4 files changed, 15 insertions(+), 27 deletions(-) diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index 17186972d7db..2544d89d0492 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -21,7 +21,8 @@ use crate::error::Result; use crate::execution::context::ExecutionProps; use crate::logical_plan::{ - build_join_schema, Column, DFField, DFSchema, DFSchemaRef, LogicalPlan, ToDFSchema, + build_join_schema, Column, DFField, DFSchema, DFSchemaRef, LogicalPlan, + LogicalPlanBuilder, ToDFSchema, }; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; @@ -253,26 +254,15 @@ fn optimize_plan( &mut new_required_columns, )?; - let new_schema = DFSchema::new( - schema - .fields() - .iter() - .filter(|f| new_required_columns.contains(&f.qualified_column())) - .cloned() - .collect(), - )?; - - Ok(LogicalPlan::Window { - window_expr: new_window_expr, - input: Arc::new(optimize_plan( - optimizer, - input, - &new_required_columns, - true, - execution_props, - )?), - schema: DFSchemaRef::new(new_schema), - }) + LogicalPlanBuilder::from(&optimize_plan( + optimizer, + input, + &new_required_columns, + true, + execution_props, + )?) + .window(new_window_expr)? 
+    .build()
         }
         LogicalPlan::Aggregate {
             schema,
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 40719c000a5b..902b122e9b4d 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -265,10 +265,8 @@ impl DefaultPhysicalPlanner {
                 }

                 let input_exec = self.create_initial_plan(input, ctx_state)?;
-                let input_schema = input_exec.schema();
-
+                let physical_input_schema = input_exec.schema();
                 let logical_input_schema = input.as_ref().schema();
-                let physical_input_schema = input_exec.as_ref().schema();

                 let window_expr = window_expr
                     .iter()
@@ -285,7 +283,7 @@ impl DefaultPhysicalPlanner {
                 Ok(Arc::new(WindowAggExec::try_new(
                     window_expr,
                     input_exec.clone(),
-                    input_schema,
+                    physical_input_schema,
                 )?))
             }
             LogicalPlan::Aggregate {
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
index fdea92c922f5..cb1dda738bd4 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -320,7 +320,7 @@ impl WindowAggExec {
         input: Arc<dyn ExecutionPlan>,
         input_schema: SchemaRef,
     ) -> Result<Self> {
-        let schema = create_schema(&input.schema(), &window_expr)?;
+        let schema = create_schema(&input_schema, &window_expr)?;
         let schema = Arc::new(schema);
         Ok(WindowAggExec {
             input,
diff --git a/integration-tests/test_psql_parity.py b/integration-tests/test_psql_parity.py
index 4e0878c24b81..10ff5055e6f7 100644
--- a/integration-tests/test_psql_parity.py
+++ b/integration-tests/test_psql_parity.py
@@ -83,7 +83,7 @@ def test_parity(self):
         psql_output = pd.read_csv(io.BytesIO(generate_csv_from_psql(fname)))
         self.assertTrue(
             np.allclose(datafusion_output, psql_output),
-            msg=f"data fusion output={datafusion_output}, psql_output={psql_output}",
+            msg=f"datafusion output=\n{datafusion_output}, psql_output=\n{psql_output}",
         )
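A minimal, self-contained sketch of the qualified-to-physical name mapping that the projection comment in `physical_plan/planner.rs` above relies on: logical and physical schemas share field indexes, so a qualified logical column such as `t1.c1` is resolved to an index against the logical schema, and the unqualified physical field name is read from the physical schema at that same index. The `Column`, `DFField`, and `DFSchema` types below are simplified stand-ins for illustration only, not DataFusion's real API.

```rust
// Simplified stand-ins (assumed shapes, not the real DataFusion types).
struct Column {
    relation: Option<String>, // table qualifier, e.g. `t1`
    name: String,             // field name, e.g. `c1`
}

struct DFField {
    qualifier: Option<String>,
    name: String,
}

struct DFSchema {
    fields: Vec<DFField>,
}

impl DFSchema {
    /// Resolve a (possibly qualified) column to its field index.
    fn index_of_column(&self, col: &Column) -> Option<usize> {
        self.fields.iter().position(|f| {
            f.name == col.name
                && (col.relation.is_none() || f.qualifier == col.relation)
        })
    }
}

fn main() {
    // Logical schema: fields carry table qualifiers.
    let logical = DFSchema {
        fields: vec![
            DFField { qualifier: Some("t1".to_string()), name: "c1".to_string() },
            DFField { qualifier: Some("t1".to_string()), name: "c2".to_string() },
        ],
    };
    // Physical schema: same fields in the same order, but unqualified.
    let physical = vec!["c1".to_string(), "c2".to_string()];

    let col = Column { relation: Some("t1".to_string()), name: "c2".to_string() };
    let idx = logical.index_of_column(&col).expect("column should resolve");
    // The physical name comes from the physical schema at the same index,
    // not from stringifying the qualified logical column (which would give
    // `t1.c2` and violate physical field name semantics).
    assert_eq!(physical[idx], "c2");
    println!("t1.c2 -> physical field `{}` at index {}", physical[idx], idx);
}
```

If the expression is not a plain column (e.g. `SUM(t1.c1)`), the planner falls back to deriving the name from the logical expression itself, which is why the index-based lookup is only attempted for `Expr::Column`.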