diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 5aafd00cf1b0..d75cbaa73efe 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -28,11 +28,29 @@ option java_outer_classname = "BallistaProto"; // Ballista Logical Plan /////////////////////////////////////////////////////////////////////////////////////////////////// +message ColumnRelation { + string relation = 1; +} + +message Column { + string name = 1; + ColumnRelation relation = 2; +} + +message DfField{ + Field field = 1; + ColumnRelation qualifier = 2; +} + +message DfSchema { + repeated DfField columns = 1; +} + // logical expressions message LogicalExprNode { oneof ExprType { // column references - string column_name = 1; + Column column = 1; // alias AliasNode alias = 2; @@ -295,7 +313,7 @@ message CreateExternalTableNode{ string location = 2; FileType file_type = 3; bool has_header = 4; - Schema schema = 5; + DfSchema schema = 5; } enum FileType{ @@ -309,11 +327,6 @@ message ExplainNode{ bool verbose = 2; } -message DfField{ - string qualifier = 2; - Field field = 1; -} - message AggregateNode { LogicalPlanNode input = 1; repeated LogicalExprNode group_expr = 2; @@ -369,8 +382,8 @@ message JoinNode { LogicalPlanNode left = 1; LogicalPlanNode right = 2; JoinType join_type = 3; - repeated string left_join_column = 4; - repeated string right_join_column = 5; + repeated Column left_join_column = 4; + repeated Column right_join_column = 5; } message LimitNode { @@ -408,6 +421,119 @@ message PhysicalPlanNode { } } +// physical expressions +message PhysicalExprNode { + oneof ExprType { + // column references + PhysicalColumn column = 1; + + ScalarValue literal = 2; + + // binary expressions + PhysicalBinaryExprNode binary_expr = 3; + + // aggregate expressions + PhysicalAggregateExprNode aggregate_expr = 4; + + // null checks + PhysicalIsNull is_null_expr = 5; + PhysicalIsNotNull is_not_null_expr = 6; + PhysicalNot not_expr = 7; + + PhysicalCaseNode case_ = 8; + PhysicalCastNode cast = 9; + PhysicalSortExprNode sort = 10; + PhysicalNegativeNode negative = 11; + PhysicalInListNode in_list = 12; + PhysicalScalarFunctionNode scalar_function = 13; + PhysicalTryCastNode try_cast = 14; + + // window expressions + PhysicalWindowExprNode window_expr = 15; + } +} + +message PhysicalAggregateExprNode { + AggregateFunction aggr_function = 1; + PhysicalExprNode expr = 2; +} + +message PhysicalWindowExprNode { + oneof window_function { + AggregateFunction aggr_function = 1; + BuiltInWindowFunction built_in_function = 2; + // udaf = 3 + } + PhysicalExprNode expr = 4; +} + +message PhysicalIsNull { + PhysicalExprNode expr = 1; +} + +message PhysicalIsNotNull { + PhysicalExprNode expr = 1; +} + +message PhysicalNot { + PhysicalExprNode expr = 1; +} + +message PhysicalAliasNode { + PhysicalExprNode expr = 1; + string alias = 2; +} + +message PhysicalBinaryExprNode { + PhysicalExprNode l = 1; + PhysicalExprNode r = 2; + string op = 3; +} + +message PhysicalSortExprNode { + PhysicalExprNode expr = 1; + bool asc = 2; + bool nulls_first = 3; +} + +message PhysicalWhenThen { + PhysicalExprNode when_expr = 1; + PhysicalExprNode then_expr = 2; +} + +message PhysicalInListNode { + PhysicalExprNode expr = 1; + repeated PhysicalExprNode list = 2; + bool negated = 3; +} + +message PhysicalCaseNode { + PhysicalExprNode expr = 1; + repeated PhysicalWhenThen when_then_expr = 2; + PhysicalExprNode else_expr = 3; +} + +message PhysicalScalarFunctionNode { + string name 
= 1; + ScalarFunction fun = 2; + repeated PhysicalExprNode args = 3; + ArrowType return_type = 4; +} + +message PhysicalTryCastNode { + PhysicalExprNode expr = 1; + ArrowType arrow_type = 2; +} + +message PhysicalCastNode { + PhysicalExprNode expr = 1; + ArrowType arrow_type = 2; +} + +message PhysicalNegativeNode { + PhysicalExprNode expr = 1; +} + message UnresolvedShuffleExecNode { repeated uint32 query_stage_ids = 1; Schema schema = 2; @@ -416,7 +542,7 @@ message UnresolvedShuffleExecNode { message FilterExecNode { PhysicalPlanNode input = 1; - LogicalExprNode expr = 2; + PhysicalExprNode expr = 2; } message ParquetScanExecNode { @@ -447,11 +573,15 @@ message HashJoinExecNode { } -message JoinOn { - string left = 1; - string right = 2; +message PhysicalColumn { + string name = 1; + uint32 index = 2; } +message JoinOn { + PhysicalColumn left = 1; + PhysicalColumn right = 2; +} message EmptyExecNode { bool produce_one_row = 1; @@ -460,7 +590,7 @@ message EmptyExecNode { message ProjectionExecNode { PhysicalPlanNode input = 1; - repeated LogicalExprNode expr = 2; + repeated PhysicalExprNode expr = 2; repeated string expr_name = 3; } @@ -472,14 +602,14 @@ enum AggregateMode { message WindowAggExecNode { PhysicalPlanNode input = 1; - repeated LogicalExprNode window_expr = 2; + repeated PhysicalExprNode window_expr = 2; repeated string window_expr_name = 3; Schema input_schema = 4; } message HashAggregateExecNode { - repeated LogicalExprNode group_expr = 1; - repeated LogicalExprNode aggr_expr = 2; + repeated PhysicalExprNode group_expr = 1; + repeated PhysicalExprNode aggr_expr = 2; AggregateMode mode = 3; PhysicalPlanNode input = 4; repeated string group_expr_name = 5; @@ -510,7 +640,7 @@ message LocalLimitExecNode { message SortExecNode { PhysicalPlanNode input = 1; - repeated LogicalExprNode expr = 2; + repeated PhysicalExprNode expr = 2; } message CoalesceBatchesExecNode { @@ -522,11 +652,16 @@ message MergeExecNode { PhysicalPlanNode input = 1; } +message PhysicalHashRepartition { + repeated PhysicalExprNode hash_expr = 1; + uint64 partition_count = 2; +} + message RepartitionExecNode{ PhysicalPlanNode input = 1; oneof partition_method { uint64 round_robin = 2; - HashRepartition hash = 3; + PhysicalHashRepartition hash = 3; uint64 unknown = 4; } } @@ -803,7 +938,7 @@ message ScalarListValue{ message ScalarValue{ - oneof value{ + oneof value { bool bool_value = 1; string utf8_value = 2; string large_utf8_value = 3; diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index c2c1001b939c..1b7deb7b7126 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -18,7 +18,7 @@ //! Serde code to convert from protocol buffers to Rust data structures. 
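The proto changes above replace bare string column references with a structured Column message (a name plus an optional ColumnRelation qualifier), serialize DFSchema through the new DfField/DfSchema messages instead of flattening it to an Arrow Schema, and give physical plans a dedicated expression tree (PhysicalExprNode, whose PhysicalColumn carries a resolved column index) rather than reusing LogicalExprNode. A minimal round-trip sketch of the qualified column, assuming the prost-generated types above and the From impls added further down in this diff:

// Hypothetical sketch, not part of the diff: round-tripping a qualified
// column through the new protobuf::Column / protobuf::ColumnRelation
// messages via the From impls in from_proto.rs and to_proto.rs below.
use ballista_core::serde::protobuf;
use datafusion::logical_plan::Column;

fn roundtrip_qualified_column() {
    // A column qualified by its relation, e.g. `employee.id`.
    let logical = Column {
        relation: Some("employee".to_string()),
        name: "id".to_string(),
    };

    // Logical -> proto: the Option<String> qualifier maps to an optional
    // ColumnRelation message.
    let proto: protobuf::Column = logical.clone().into();
    assert_eq!(proto.name, "id");
    assert_eq!(
        proto.relation.as_ref().map(|r| r.relation.as_str()),
        Some("employee")
    );

    // Proto -> logical: the reverse impl restores the qualifier, so
    // expressions keep their table prefix across the wire.
    let back: Column = (&proto).into();
    assert_eq!(back, logical);
}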
use crate::error::BallistaError; -use crate::serde::{proto_error, protobuf}; +use crate::serde::{from_proto_binary_op, proto_error, protobuf}; use crate::{convert_box_required, convert_required}; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::logical_plan::window_frames::{ @@ -26,7 +26,8 @@ use datafusion::logical_plan::window_frames::{ }; use datafusion::logical_plan::{ abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin, - sqrt, tan, trunc, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, + sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinType, LogicalPlan, + LogicalPlanBuilder, Operator, }; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::csv::CsvReadOptions; @@ -36,6 +37,7 @@ use protobuf::logical_plan_node::LogicalPlanType; use protobuf::{logical_expr_node::ExprType, scalar_type}; use std::{ convert::{From, TryInto}, + sync::Arc, unimplemented, }; @@ -115,8 +117,8 @@ impl TryInto for &protobuf::LogicalPlanNode { .has_header(scan.has_header); let mut projection = None; - if let Some(column_names) = &scan.projection { - let column_indices = column_names + if let Some(columns) = &scan.projection { + let column_indices = columns .columns .iter() .map(|name| schema.index_of(name)) @@ -234,10 +236,10 @@ impl TryInto for &protobuf::LogicalPlanNode { .map_err(|e| e.into()) } LogicalPlanType::Join(join) => { - let left_keys: Vec<&str> = - join.left_join_column.iter().map(|i| i.as_str()).collect(); - let right_keys: Vec<&str> = - join.right_join_column.iter().map(|i| i.as_str()).collect(); + let left_keys: Vec = + join.left_join_column.iter().map(|i| i.into()).collect(); + let right_keys: Vec = + join.right_join_column.iter().map(|i| i.into()).collect(); let join_type = protobuf::JoinType::from_i32(join.join_type).ok_or_else(|| { proto_error(format!( @@ -257,8 +259,8 @@ impl TryInto for &protobuf::LogicalPlanNode { .join( &convert_box_required!(join.right)?, join_type, - &left_keys, - &right_keys, + left_keys, + right_keys, )? .build() .map_err(|e| e.into()) @@ -267,22 +269,48 @@ impl TryInto for &protobuf::LogicalPlanNode { } } -impl TryInto for protobuf::Schema { +impl From<&protobuf::Column> for Column { + fn from(c: &protobuf::Column) -> Column { + let c = c.clone(); + Column { + relation: c.relation.map(|r| r.relation), + name: c.name, + } + } +} + +impl TryInto for &protobuf::DfSchema { type Error = BallistaError; - fn try_into(self) -> Result { - let schema: Schema = (&self).try_into()?; - schema.try_into().map_err(BallistaError::DataFusionError) + + fn try_into(self) -> Result { + let fields = self + .columns + .iter() + .map(|c| c.try_into()) + .collect::, _>>()?; + Ok(DFSchema::new(fields)?) 
} } -impl TryInto for protobuf::Schema { +impl TryInto for protobuf::DfSchema { type Error = BallistaError; + fn try_into(self) -> Result { - use datafusion::logical_plan::ToDFSchema; - let schema: Schema = (&self).try_into()?; - schema - .to_dfschema_ref() - .map_err(BallistaError::DataFusionError) + let dfschema: DFSchema = (&self).try_into()?; + Ok(Arc::new(dfschema)) + } +} + +impl TryInto for &protobuf::DfField { + type Error = BallistaError; + + fn try_into(self) -> Result { + let field: Field = convert_required!(self.field)?; + + Ok(match &self.qualifier { + Some(q) => DFField::from_qualified(&q.relation, field), + None => DFField::from(field), + }) } } @@ -339,149 +367,6 @@ impl TryInto for &protobuf::scalar_type::Datatype { } } -impl TryInto for &protobuf::arrow_type::ArrowTypeEnum { - type Error = BallistaError; - fn try_into(self) -> Result { - use protobuf::arrow_type; - Ok(match self { - arrow_type::ArrowTypeEnum::None(_) => DataType::Null, - arrow_type::ArrowTypeEnum::Bool(_) => DataType::Boolean, - arrow_type::ArrowTypeEnum::Uint8(_) => DataType::UInt8, - arrow_type::ArrowTypeEnum::Int8(_) => DataType::Int8, - arrow_type::ArrowTypeEnum::Uint16(_) => DataType::UInt16, - arrow_type::ArrowTypeEnum::Int16(_) => DataType::Int16, - arrow_type::ArrowTypeEnum::Uint32(_) => DataType::UInt32, - arrow_type::ArrowTypeEnum::Int32(_) => DataType::Int32, - arrow_type::ArrowTypeEnum::Uint64(_) => DataType::UInt64, - arrow_type::ArrowTypeEnum::Int64(_) => DataType::Int64, - arrow_type::ArrowTypeEnum::Float16(_) => DataType::Float16, - arrow_type::ArrowTypeEnum::Float32(_) => DataType::Float32, - arrow_type::ArrowTypeEnum::Float64(_) => DataType::Float64, - arrow_type::ArrowTypeEnum::Utf8(_) => DataType::Utf8, - arrow_type::ArrowTypeEnum::LargeUtf8(_) => DataType::LargeUtf8, - arrow_type::ArrowTypeEnum::Binary(_) => DataType::Binary, - arrow_type::ArrowTypeEnum::FixedSizeBinary(size) => { - DataType::FixedSizeBinary(*size) - } - arrow_type::ArrowTypeEnum::LargeBinary(_) => DataType::LargeBinary, - arrow_type::ArrowTypeEnum::Date32(_) => DataType::Date32, - arrow_type::ArrowTypeEnum::Date64(_) => DataType::Date64, - arrow_type::ArrowTypeEnum::Duration(time_unit) => { - DataType::Duration(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?) - } - arrow_type::ArrowTypeEnum::Timestamp(protobuf::Timestamp { - time_unit, - timezone, - }) => DataType::Timestamp( - protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?, - match timezone.len() { - 0 => None, - _ => Some(timezone.to_owned()), - }, - ), - arrow_type::ArrowTypeEnum::Time32(time_unit) => { - DataType::Time32(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?) - } - arrow_type::ArrowTypeEnum::Time64(time_unit) => { - DataType::Time64(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?) - } - arrow_type::ArrowTypeEnum::Interval(interval_unit) => DataType::Interval( - protobuf::IntervalUnit::from_i32_to_arrow(*interval_unit)?, - ), - arrow_type::ArrowTypeEnum::Decimal(protobuf::Decimal { - whole, - fractional, - }) => DataType::Decimal(*whole as usize, *fractional as usize), - arrow_type::ArrowTypeEnum::List(list) => { - let list_type: &protobuf::Field = list - .as_ref() - .field_type - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))? 
- .as_ref(); - DataType::List(Box::new(list_type.try_into()?)) - } - arrow_type::ArrowTypeEnum::LargeList(list) => { - let list_type: &protobuf::Field = list - .as_ref() - .field_type - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))? - .as_ref(); - DataType::LargeList(Box::new(list_type.try_into()?)) - } - arrow_type::ArrowTypeEnum::FixedSizeList(list) => { - let list_type: &protobuf::Field = list - .as_ref() - .field_type - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))? - .as_ref(); - let list_size = list.list_size; - DataType::FixedSizeList(Box::new(list_type.try_into()?), list_size) - } - arrow_type::ArrowTypeEnum::Struct(strct) => DataType::Struct( - strct - .sub_field_types - .iter() - .map(|field| field.try_into()) - .collect::, _>>()?, - ), - arrow_type::ArrowTypeEnum::Union(union) => DataType::Union( - union - .union_types - .iter() - .map(|field| field.try_into()) - .collect::, _>>()?, - ), - arrow_type::ArrowTypeEnum::Dictionary(dict) => { - let pb_key_datatype = dict - .as_ref() - .key - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?; - let pb_value_datatype = dict - .as_ref() - .value - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?; - let key_datatype: DataType = pb_key_datatype.as_ref().try_into()?; - let value_datatype: DataType = pb_value_datatype.as_ref().try_into()?; - DataType::Dictionary(Box::new(key_datatype), Box::new(value_datatype)) - } - }) - } -} - -#[allow(clippy::from_over_into)] -impl Into for protobuf::PrimitiveScalarType { - fn into(self) -> DataType { - match self { - protobuf::PrimitiveScalarType::Bool => DataType::Boolean, - protobuf::PrimitiveScalarType::Uint8 => DataType::UInt8, - protobuf::PrimitiveScalarType::Int8 => DataType::Int8, - protobuf::PrimitiveScalarType::Uint16 => DataType::UInt16, - protobuf::PrimitiveScalarType::Int16 => DataType::Int16, - protobuf::PrimitiveScalarType::Uint32 => DataType::UInt32, - protobuf::PrimitiveScalarType::Int32 => DataType::Int32, - protobuf::PrimitiveScalarType::Uint64 => DataType::UInt64, - protobuf::PrimitiveScalarType::Int64 => DataType::Int64, - protobuf::PrimitiveScalarType::Float32 => DataType::Float32, - protobuf::PrimitiveScalarType::Float64 => DataType::Float64, - protobuf::PrimitiveScalarType::Utf8 => DataType::Utf8, - protobuf::PrimitiveScalarType::LargeUtf8 => DataType::LargeUtf8, - protobuf::PrimitiveScalarType::Date32 => DataType::Date32, - protobuf::PrimitiveScalarType::TimeMicrosecond => { - DataType::Time64(TimeUnit::Microsecond) - } - protobuf::PrimitiveScalarType::TimeNanosecond => { - DataType::Time64(TimeUnit::Nanosecond) - } - protobuf::PrimitiveScalarType::Null => DataType::Null, - } - } -} - //Does not typecheck lists fn typechecked_scalar_value_conversion( tested_type: &protobuf::scalar_value::Value, @@ -899,7 +784,7 @@ impl TryInto for &protobuf::LogicalExprNode { op: from_proto_binary_op(&binary_expr.op)?, right: Box::new(parse_required_expr(&binary_expr.r)?), }), - ExprType::ColumnName(column_name) => Ok(Expr::Column(column_name.to_owned())), + ExprType::Column(column) => Ok(Expr::Column(column.into())), ExprType::Literal(literal) => { use datafusion::scalar::ScalarValue; let scalar_value: datafusion::scalar::ScalarValue = literal.try_into()?; @@ -1164,28 +1049,6 @@ impl TryInto for 
&protobuf::LogicalExprNode { } } -fn from_proto_binary_op(op: &str) -> Result { - match op { - "And" => Ok(Operator::And), - "Or" => Ok(Operator::Or), - "Eq" => Ok(Operator::Eq), - "NotEq" => Ok(Operator::NotEq), - "LtEq" => Ok(Operator::LtEq), - "Lt" => Ok(Operator::Lt), - "Gt" => Ok(Operator::Gt), - "GtEq" => Ok(Operator::GtEq), - "Plus" => Ok(Operator::Plus), - "Minus" => Ok(Operator::Minus), - "Multiply" => Ok(Operator::Multiply), - "Divide" => Ok(Operator::Divide), - "Like" => Ok(Operator::Like), - other => Err(proto_error(format!( - "Unsupported binary operator '{:?}'", - other - ))), - } -} - impl TryInto for &protobuf::ScalarType { type Error = BallistaError; fn try_into(self) -> Result { @@ -1361,43 +1224,3 @@ impl TryFrom for WindowFrame { }) } } - -impl From for AggregateFunction { - fn from(aggr_function: protobuf::AggregateFunction) -> Self { - match aggr_function { - protobuf::AggregateFunction::Min => AggregateFunction::Min, - protobuf::AggregateFunction::Max => AggregateFunction::Max, - protobuf::AggregateFunction::Sum => AggregateFunction::Sum, - protobuf::AggregateFunction::Avg => AggregateFunction::Avg, - protobuf::AggregateFunction::Count => AggregateFunction::Count, - } - } -} - -impl From for BuiltInWindowFunction { - fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self { - match built_in_function { - protobuf::BuiltInWindowFunction::RowNumber => { - BuiltInWindowFunction::RowNumber - } - protobuf::BuiltInWindowFunction::Rank => BuiltInWindowFunction::Rank, - protobuf::BuiltInWindowFunction::PercentRank => { - BuiltInWindowFunction::PercentRank - } - protobuf::BuiltInWindowFunction::DenseRank => { - BuiltInWindowFunction::DenseRank - } - protobuf::BuiltInWindowFunction::Lag => BuiltInWindowFunction::Lag, - protobuf::BuiltInWindowFunction::Lead => BuiltInWindowFunction::Lead, - protobuf::BuiltInWindowFunction::FirstValue => { - BuiltInWindowFunction::FirstValue - } - protobuf::BuiltInWindowFunction::CumeDist => BuiltInWindowFunction::CumeDist, - protobuf::BuiltInWindowFunction::Ntile => BuiltInWindowFunction::Ntile, - protobuf::BuiltInWindowFunction::NthValue => BuiltInWindowFunction::NthValue, - protobuf::BuiltInWindowFunction::LastValue => { - BuiltInWindowFunction::LastValue - } - } - } -} diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index d2792b09fa16..0d27c58ac292 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -26,7 +26,9 @@ mod roundtrip_tests { use core::panic; use datafusion::{ arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}, - logical_plan::{Expr, LogicalPlan, LogicalPlanBuilder, Partitioning, ToDFSchema}, + logical_plan::{ + col, Expr, LogicalPlan, LogicalPlanBuilder, Partitioning, ToDFSchema, + }, physical_plan::{csv::CsvReadOptions, functions::BuiltinScalarFunction::Sqrt}, prelude::*, scalar::ScalarValue, @@ -61,10 +63,8 @@ mod roundtrip_tests { let test_batch_sizes = [usize::MIN, usize::MAX, 43256]; - let test_expr: Vec = vec![ - Expr::Column("c1".to_string()) + Expr::Column("c2".to_string()), - Expr::Literal((4.0).into()), - ]; + let test_expr: Vec = + vec![col("c1") + col("c2"), Expr::Literal((4.0).into())]; let schema = Schema::new(vec![ Field::new("id", DataType::Int32, false), @@ -688,15 +688,20 @@ mod roundtrip_tests { Field::new("salary", DataType::Int32, false), ]); - let scan_plan = LogicalPlanBuilder::empty(false) - .build() - .map_err(BallistaError::DataFusionError)?; + 
let scan_plan = LogicalPlanBuilder::scan_csv( + "employee1", + CsvReadOptions::new().schema(&schema).has_header(true), + Some(vec![0, 3, 4]), + )? + .build() + .map_err(BallistaError::DataFusionError)?; + let plan = LogicalPlanBuilder::scan_csv( - "employee.csv", + "employee2", CsvReadOptions::new().schema(&schema).has_header(true), - Some(vec![3, 4]), + Some(vec![0, 3, 4]), ) - .and_then(|plan| plan.join(&scan_plan, JoinType::Inner, &["id"], &["id"])) + .and_then(|plan| plan.join(&scan_plan, JoinType::Inner, vec!["id"], vec!["id"])) .and_then(|plan| plan.build()) .map_err(BallistaError::DataFusionError)?; @@ -779,7 +784,7 @@ mod roundtrip_tests { #[test] fn roundtrip_is_null() -> Result<()> { - let test_expr = Expr::IsNull(Box::new(Expr::Column("id".into()))); + let test_expr = Expr::IsNull(Box::new(col("id"))); roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr); @@ -788,7 +793,7 @@ mod roundtrip_tests { #[test] fn roundtrip_is_not_null() -> Result<()> { - let test_expr = Expr::IsNotNull(Box::new(Expr::Column("id".into()))); + let test_expr = Expr::IsNotNull(Box::new(col("id"))); roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr); diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index c454d03257f0..24e2b56bad86 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -26,7 +26,7 @@ use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUn use datafusion::datasource::CsvFile; use datafusion::logical_plan::{ window_frames::{WindowFrame, WindowFrameBound, WindowFrameUnits}, - Expr, JoinType, LogicalPlan, + Column, Expr, JoinType, LogicalPlan, }; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::functions::BuiltinScalarFunction; @@ -816,8 +816,8 @@ impl TryInto for &LogicalPlan { JoinType::Semi => protobuf::JoinType::Semi, JoinType::Anti => protobuf::JoinType::Anti, }; - let left_join_column = on.iter().map(|on| on.0.to_owned()).collect(); - let right_join_column = on.iter().map(|on| on.1.to_owned()).collect(); + let (left_join_column, right_join_column) = + on.iter().map(|(l, r)| (l.into(), r.into())).unzip(); Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( protobuf::JoinNode { @@ -908,13 +908,6 @@ impl TryInto for &LogicalPlan { schema: df_schema, } => { use datafusion::sql::parser::FileType; - let schema: Schema = df_schema.as_ref().clone().into(); - let pb_schema: protobuf::Schema = (&schema).try_into().map_err(|e| { - BallistaError::General(format!( - "Could not convert schema into protobuf: {:?}", - e - )) - })?; let pb_file_type: protobuf::FileType = match file_type { FileType::NdJson => protobuf::FileType::NdJson, @@ -929,7 +922,7 @@ impl TryInto for &LogicalPlan { location: location.clone(), file_type: pb_file_type as i32, has_header: *has_header, - schema: Some(pb_schema), + schema: Some(df_schema.into()), }, )), }) @@ -971,9 +964,9 @@ impl TryInto for &Expr { use datafusion::scalar::ScalarValue; use protobuf::scalar_value::Value; match self { - Expr::Column(name) => { + Expr::Column(c) => { let expr = protobuf::LogicalExprNode { - expr_type: Some(ExprType::ColumnName(name.clone())), + expr_type: Some(ExprType::Column(c.into())), }; Ok(expr) } @@ -1214,6 +1207,23 @@ impl TryInto for &Expr { } } +impl From for protobuf::Column { + fn from(c: Column) -> protobuf::Column { + protobuf::Column { + relation: c + .relation + 
.map(|relation| protobuf::ColumnRelation { relation }), + name: c.name, + } + } +} + +impl From<&Column> for protobuf::Column { + fn from(c: &Column) -> protobuf::Column { + c.clone().into() + } +} + #[allow(clippy::from_over_into)] impl Into for &Schema { fn into(self) -> protobuf::Schema { @@ -1227,6 +1237,24 @@ impl Into for &Schema { } } +impl From<&datafusion::logical_plan::DFField> for protobuf::DfField { + fn from(f: &datafusion::logical_plan::DFField) -> protobuf::DfField { + protobuf::DfField { + field: Some(f.field().into()), + qualifier: f.qualifier().map(|r| protobuf::ColumnRelation { + relation: r.to_string(), + }), + } + } +} + +impl From<&datafusion::logical_plan::DFSchemaRef> for protobuf::DfSchema { + fn from(s: &datafusion::logical_plan::DFSchemaRef) -> protobuf::DfSchema { + let columns = s.fields().iter().map(|f| f.into()).collect::>(); + protobuf::DfSchema { columns } + } +} + impl From<&AggregateFunction> for protobuf::AggregateFunction { fn from(value: &AggregateFunction) -> Self { match value { diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index b96163999f39..af83660baab5 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -20,6 +20,10 @@ use std::{convert::TryInto, io::Cursor}; +use datafusion::logical_plan::Operator; +use datafusion::physical_plan::aggregates::AggregateFunction; +use datafusion::physical_plan::window_functions::BuiltInWindowFunction; + use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction}; use prost::Message; @@ -57,6 +61,17 @@ macro_rules! convert_required { }}; } +#[macro_export] +macro_rules! into_required { + ($PB:expr) => {{ + if let Some(field) = $PB.as_ref() { + Ok(field.into()) + } else { + Err(proto_error("Missing required field in protobuf")) + } + }}; +} + #[macro_export] macro_rules! convert_box_required { ($PB:expr) => {{ @@ -67,3 +82,212 @@ macro_rules! 
convert_box_required { } }}; } + +pub(crate) fn from_proto_binary_op(op: &str) -> Result { + match op { + "And" => Ok(Operator::And), + "Or" => Ok(Operator::Or), + "Eq" => Ok(Operator::Eq), + "NotEq" => Ok(Operator::NotEq), + "LtEq" => Ok(Operator::LtEq), + "Lt" => Ok(Operator::Lt), + "Gt" => Ok(Operator::Gt), + "GtEq" => Ok(Operator::GtEq), + "Plus" => Ok(Operator::Plus), + "Minus" => Ok(Operator::Minus), + "Multiply" => Ok(Operator::Multiply), + "Divide" => Ok(Operator::Divide), + "Like" => Ok(Operator::Like), + other => Err(proto_error(format!( + "Unsupported binary operator '{:?}'", + other + ))), + } +} + +impl From for AggregateFunction { + fn from(agg_fun: protobuf::AggregateFunction) -> AggregateFunction { + match agg_fun { + protobuf::AggregateFunction::Min => AggregateFunction::Min, + protobuf::AggregateFunction::Max => AggregateFunction::Max, + protobuf::AggregateFunction::Sum => AggregateFunction::Sum, + protobuf::AggregateFunction::Avg => AggregateFunction::Avg, + protobuf::AggregateFunction::Count => AggregateFunction::Count, + } + } +} + +impl From for BuiltInWindowFunction { + fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self { + match built_in_function { + protobuf::BuiltInWindowFunction::RowNumber => { + BuiltInWindowFunction::RowNumber + } + protobuf::BuiltInWindowFunction::Rank => BuiltInWindowFunction::Rank, + protobuf::BuiltInWindowFunction::PercentRank => { + BuiltInWindowFunction::PercentRank + } + protobuf::BuiltInWindowFunction::DenseRank => { + BuiltInWindowFunction::DenseRank + } + protobuf::BuiltInWindowFunction::Lag => BuiltInWindowFunction::Lag, + protobuf::BuiltInWindowFunction::Lead => BuiltInWindowFunction::Lead, + protobuf::BuiltInWindowFunction::FirstValue => { + BuiltInWindowFunction::FirstValue + } + protobuf::BuiltInWindowFunction::CumeDist => BuiltInWindowFunction::CumeDist, + protobuf::BuiltInWindowFunction::Ntile => BuiltInWindowFunction::Ntile, + protobuf::BuiltInWindowFunction::NthValue => BuiltInWindowFunction::NthValue, + protobuf::BuiltInWindowFunction::LastValue => { + BuiltInWindowFunction::LastValue + } + } + } +} + +impl TryInto + for &protobuf::arrow_type::ArrowTypeEnum +{ + type Error = BallistaError; + fn try_into(self) -> Result { + use datafusion::arrow::datatypes::DataType; + use protobuf::arrow_type; + Ok(match self { + arrow_type::ArrowTypeEnum::None(_) => DataType::Null, + arrow_type::ArrowTypeEnum::Bool(_) => DataType::Boolean, + arrow_type::ArrowTypeEnum::Uint8(_) => DataType::UInt8, + arrow_type::ArrowTypeEnum::Int8(_) => DataType::Int8, + arrow_type::ArrowTypeEnum::Uint16(_) => DataType::UInt16, + arrow_type::ArrowTypeEnum::Int16(_) => DataType::Int16, + arrow_type::ArrowTypeEnum::Uint32(_) => DataType::UInt32, + arrow_type::ArrowTypeEnum::Int32(_) => DataType::Int32, + arrow_type::ArrowTypeEnum::Uint64(_) => DataType::UInt64, + arrow_type::ArrowTypeEnum::Int64(_) => DataType::Int64, + arrow_type::ArrowTypeEnum::Float16(_) => DataType::Float16, + arrow_type::ArrowTypeEnum::Float32(_) => DataType::Float32, + arrow_type::ArrowTypeEnum::Float64(_) => DataType::Float64, + arrow_type::ArrowTypeEnum::Utf8(_) => DataType::Utf8, + arrow_type::ArrowTypeEnum::LargeUtf8(_) => DataType::LargeUtf8, + arrow_type::ArrowTypeEnum::Binary(_) => DataType::Binary, + arrow_type::ArrowTypeEnum::FixedSizeBinary(size) => { + DataType::FixedSizeBinary(*size) + } + arrow_type::ArrowTypeEnum::LargeBinary(_) => DataType::LargeBinary, + arrow_type::ArrowTypeEnum::Date32(_) => DataType::Date32, + arrow_type::ArrowTypeEnum::Date64(_) => 
DataType::Date64, + arrow_type::ArrowTypeEnum::Duration(time_unit) => { + DataType::Duration(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?) + } + arrow_type::ArrowTypeEnum::Timestamp(protobuf::Timestamp { + time_unit, + timezone, + }) => DataType::Timestamp( + protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?, + match timezone.len() { + 0 => None, + _ => Some(timezone.to_owned()), + }, + ), + arrow_type::ArrowTypeEnum::Time32(time_unit) => { + DataType::Time32(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?) + } + arrow_type::ArrowTypeEnum::Time64(time_unit) => { + DataType::Time64(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?) + } + arrow_type::ArrowTypeEnum::Interval(interval_unit) => DataType::Interval( + protobuf::IntervalUnit::from_i32_to_arrow(*interval_unit)?, + ), + arrow_type::ArrowTypeEnum::Decimal(protobuf::Decimal { + whole, + fractional, + }) => DataType::Decimal(*whole as usize, *fractional as usize), + arrow_type::ArrowTypeEnum::List(list) => { + let list_type: &protobuf::Field = list + .as_ref() + .field_type + .as_ref() + .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))? + .as_ref(); + DataType::List(Box::new(list_type.try_into()?)) + } + arrow_type::ArrowTypeEnum::LargeList(list) => { + let list_type: &protobuf::Field = list + .as_ref() + .field_type + .as_ref() + .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))? + .as_ref(); + DataType::LargeList(Box::new(list_type.try_into()?)) + } + arrow_type::ArrowTypeEnum::FixedSizeList(list) => { + let list_type: &protobuf::Field = list + .as_ref() + .field_type + .as_ref() + .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))? 
+ .as_ref(); + let list_size = list.list_size; + DataType::FixedSizeList(Box::new(list_type.try_into()?), list_size) + } + arrow_type::ArrowTypeEnum::Struct(strct) => DataType::Struct( + strct + .sub_field_types + .iter() + .map(|field| field.try_into()) + .collect::, _>>()?, + ), + arrow_type::ArrowTypeEnum::Union(union) => DataType::Union( + union + .union_types + .iter() + .map(|field| field.try_into()) + .collect::, _>>()?, + ), + arrow_type::ArrowTypeEnum::Dictionary(dict) => { + let pb_key_datatype = dict + .as_ref() + .key + .as_ref() + .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?; + let pb_value_datatype = dict + .as_ref() + .value + .as_ref() + .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?; + let key_datatype: DataType = pb_key_datatype.as_ref().try_into()?; + let value_datatype: DataType = pb_value_datatype.as_ref().try_into()?; + DataType::Dictionary(Box::new(key_datatype), Box::new(value_datatype)) + } + }) + } +} + +#[allow(clippy::from_over_into)] +impl Into for protobuf::PrimitiveScalarType { + fn into(self) -> datafusion::arrow::datatypes::DataType { + use datafusion::arrow::datatypes::{DataType, TimeUnit}; + match self { + protobuf::PrimitiveScalarType::Bool => DataType::Boolean, + protobuf::PrimitiveScalarType::Uint8 => DataType::UInt8, + protobuf::PrimitiveScalarType::Int8 => DataType::Int8, + protobuf::PrimitiveScalarType::Uint16 => DataType::UInt16, + protobuf::PrimitiveScalarType::Int16 => DataType::Int16, + protobuf::PrimitiveScalarType::Uint32 => DataType::UInt32, + protobuf::PrimitiveScalarType::Int32 => DataType::Int32, + protobuf::PrimitiveScalarType::Uint64 => DataType::UInt64, + protobuf::PrimitiveScalarType::Int64 => DataType::Int64, + protobuf::PrimitiveScalarType::Float32 => DataType::Float32, + protobuf::PrimitiveScalarType::Float64 => DataType::Float64, + protobuf::PrimitiveScalarType::Utf8 => DataType::Utf8, + protobuf::PrimitiveScalarType::LargeUtf8 => DataType::LargeUtf8, + protobuf::PrimitiveScalarType::Date32 => DataType::Date32, + protobuf::PrimitiveScalarType::TimeMicrosecond => { + DataType::Time64(TimeUnit::Microsecond) + } + protobuf::PrimitiveScalarType::TimeNanosecond => { + DataType::Time64(TimeUnit::Nanosecond) + } + protobuf::PrimitiveScalarType::Null => DataType::Null, + } + } +} diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index a2c9db9ecafb..4b87be4105be 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -18,17 +18,16 @@ //! Serde code to convert from protocol buffers to Rust data structures. 
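The helpers consolidated into serde/mod.rs above are shared by both deserializers: from_proto_binary_op and the AggregateFunction/BuiltInWindowFunction From impls move here out of logical_plan/from_proto.rs so that the new physical-expression decoding can reuse them. The new into_required! macro complements convert_required!: it targets required proto fields whose Rust-side conversion is an infallible From rather than a TryInto, as with the PhysicalColumn pairs in JoinOn. A usage sketch, written as if inside ballista-core; the join_keys helper is illustrative only:

// Sketch of into_required! at a call site. proto_error must be in scope
// because the macro body references it unqualified, mirroring the imports
// used by physical_plan/from_proto.rs below.
use crate::error::BallistaError;
use crate::into_required;
use crate::serde::{proto_error, protobuf};
use datafusion::physical_plan::expressions::Column;

fn join_keys(on: &protobuf::JoinOn) -> Result<(Column, Column), BallistaError> {
    // convert_required! goes through TryInto and can fail inside the
    // conversion; into_required! only fails when the optional field is
    // absent, since From<&protobuf::PhysicalColumn> for Column cannot fail.
    let left: Column = into_required!(on.left)?;
    let right: Column = into_required!(on.right)?;
    Ok((left, right))
}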
use std::collections::HashMap; -use std::convert::TryInto; +use std::convert::{TryFrom, TryInto}; use std::sync::Arc; use crate::error::BallistaError; use crate::execution_plans::{ShuffleReaderExec, UnresolvedShuffleExec}; use crate::serde::protobuf::repartition_exec_node::PartitionMethod; -use crate::serde::protobuf::LogicalExprNode; use crate::serde::protobuf::ShuffleReaderPartition; use crate::serde::scheduler::PartitionLocation; -use crate::serde::{proto_error, protobuf}; -use crate::{convert_box_required, convert_required}; +use crate::serde::{from_proto_binary_op, proto_error, protobuf}; +use crate::{convert_box_required, convert_required, into_required}; use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::catalog::catalog::{ CatalogList, CatalogProvider, MemoryCatalogList, MemoryCatalogProvider, @@ -36,9 +35,8 @@ use datafusion::catalog::catalog::{ use datafusion::execution::context::{ ExecutionConfig, ExecutionContextState, ExecutionProps, }; -use datafusion::logical_plan::{DFSchema, Expr}; -use datafusion::physical_plan::aggregates::AggregateFunction; -use datafusion::physical_plan::expressions::col; +use datafusion::logical_plan::{window_frames::WindowFrame, DFSchema, Expr}; +use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunction}; use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; use datafusion::physical_plan::hash_join::PartitionMode; use datafusion::physical_plan::merge::MergeExec; @@ -46,13 +44,18 @@ use datafusion::physical_plan::planner::DefaultPhysicalPlanner; use datafusion::physical_plan::window_functions::{ BuiltInWindowFunction, WindowFunction, }; -use datafusion::physical_plan::windows::WindowAggExec; +use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec}; use datafusion::physical_plan::{ coalesce_batches::CoalesceBatchesExec, csv::CsvExec, empty::EmptyExec, - expressions::{Avg, Column, PhysicalSortExpr}, + expressions::{ + col, Avg, BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, + IsNullExpr, Literal, NegativeExpr, NotExpr, PhysicalSortExpr, TryCastExpr, + DEFAULT_DATAFUSION_CAST_OPTIONS, + }, filter::FilterExec, + functions::{self, BuiltinScalarFunction, ScalarFunctionExpr}, hash_join::HashJoinExec, hash_utils::JoinType, limit::{GlobalLimitExec, LocalLimitExec}, @@ -65,7 +68,7 @@ use datafusion::physical_plan::{ use datafusion::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr}; use datafusion::prelude::CsvReadOptions; use log::debug; -use protobuf::logical_expr_node::ExprType; +use protobuf::physical_expr_node::ExprType; use protobuf::physical_plan_node::PhysicalPlanType; impl TryInto> for &protobuf::PhysicalPlanNode { @@ -86,23 +89,23 @@ impl TryInto> for &protobuf::PhysicalPlanNode { .expr .iter() .zip(projection.expr_name.iter()) - .map(|(expr, name)| { - compile_expr(expr, &input.schema()).map(|e| (e, name.to_string())) - }) - .collect::, _>>()?; + .map(|(expr, name)| Ok((expr.try_into()?, name.to_string()))) + .collect::, String)>, Self::Error>>( + )?; Ok(Arc::new(ProjectionExec::try_new(exprs, input)?)) } PhysicalPlanType::Filter(filter) => { let input: Arc = convert_box_required!(filter.input)?; - let predicate = compile_expr( - filter.expr.as_ref().ok_or_else(|| { + let predicate = filter + .expr + .as_ref() + .ok_or_else(|| { BallistaError::General( "filter (FilterExecNode) in PhysicalPlanNode is missing." .to_owned(), ) - })?, - &input.schema(), - )?; + })? 
+ .try_into()?; Ok(Arc::new(FilterExec::try_new(predicate, input)?)) } PhysicalPlanType::CsvScan(scan) => { @@ -153,7 +156,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { let expr = hash_part .hash_expr .iter() - .map(|e| compile_expr(e, &input.schema())) + .map(|e| e.try_into()) .collect::>, _>>()?; Ok(Arc::new(RepartitionExec::try_new( @@ -207,25 +210,33 @@ impl TryInto> for &protobuf::PhysicalPlanNode { .clone(); let physical_schema: SchemaRef = SchemaRef::new((&input_schema).try_into()?); - let ctx_state = ExecutionContextState::new(); - let window_agg_expr: Vec<(Expr, String)> = window_agg + + let physical_window_expr: Vec> = window_agg .window_expr .iter() .zip(window_agg.window_expr_name.iter()) - .map(|(expr, name)| expr.try_into().map(|expr| (expr, name.clone()))) - .collect::, _>>()?; - let df_planner = DefaultPhysicalPlanner::default(); - let physical_window_expr = window_agg_expr - .iter() .map(|(expr, name)| { - df_planner.create_window_expr_with_name( - expr, - name.to_string(), - &physical_schema, - &ctx_state, - ) + let expr_type = expr.expr_type.as_ref().ok_or_else(|| { + proto_error("Unexpected empty window physical expression") + })?; + + match expr_type { + ExprType::WindowExpr(window_node) => Ok(create_window_expr( + &convert_required!(window_node.window_function)?, + name.to_owned(), + &[convert_box_required!(window_node.expr)?], + &[], + &[], + Some(WindowFrame::default()), + &physical_schema, + )?), + _ => Err(BallistaError::General( + "Invalid expression for WindowAggrExec".to_string(), + )), + } }) .collect::, _>>()?; + Ok(Arc::new(WindowAggExec::try_new( physical_window_expr, input, @@ -253,16 +264,10 @@ impl TryInto> for &protobuf::PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - compile_expr(expr, &input.schema()).map(|e| (e, name.to_string())) + expr.try_into().map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; - let logical_agg_expr: Vec<(Expr, String)> = hash_agg - .aggr_expr - .iter() - .zip(hash_agg.aggr_expr_name.iter()) - .map(|(expr, name)| expr.try_into().map(|expr| (expr, name.clone()))) - .collect::, _>>()?; - let ctx_state = ExecutionContextState::new(); + let input_schema = hash_agg .input_schema .as_ref() @@ -274,18 +279,47 @@ impl TryInto> for &protobuf::PhysicalPlanNode { .clone(); let physical_schema: SchemaRef = SchemaRef::new((&input_schema).try_into()?); - let df_planner = DefaultPhysicalPlanner::default(); - let physical_aggr_expr = logical_agg_expr + + let physical_aggr_expr: Vec> = hash_agg + .aggr_expr .iter() + .zip(hash_agg.aggr_expr_name.iter()) .map(|(expr, name)| { - df_planner.create_aggregate_expr_with_name( - expr, - name.to_string(), - &physical_schema, - &ctx_state, - ) + let expr_type = expr.expr_type.as_ref().ok_or_else(|| { + proto_error("Unexpected empty aggregate physical expression") + })?; + + match expr_type { + ExprType::AggregateExpr(agg_node) => { + let aggr_function = + protobuf::AggregateFunction::from_i32( + agg_node.aggr_function, + ) + .ok_or_else( + || { + proto_error(format!( + "Received an unknown aggregate function: {}", + agg_node.aggr_function + )) + }, + )?; + + Ok(create_aggregate_expr( + &aggr_function.into(), + false, + &[convert_box_required!(agg_node.expr)?], + &physical_schema, + name.to_string(), + )?) 
+ } + _ => Err(BallistaError::General( + "Invalid aggregate expression for HashAggregateExec" + .to_string(), + )), + } }) .collect::, _>>()?; + Ok(Arc::new(HashAggregateExec::try_new( agg_mode, group, @@ -298,11 +332,15 @@ impl TryInto> for &protobuf::PhysicalPlanNode { let left: Arc = convert_box_required!(hashjoin.left)?; let right: Arc = convert_box_required!(hashjoin.right)?; - let on: Vec<(String, String)> = hashjoin + let on: Vec<(Column, Column)> = hashjoin .on .iter() - .map(|col| (col.left.clone(), col.right.clone())) - .collect(); + .map(|col| { + let left = into_required!(col.left)?; + let right = into_required!(col.right)?; + Ok((left, right)) + }) + .collect::>()?; let join_type = protobuf::JoinType::from_i32(hashjoin.join_type) .ok_or_else(|| { proto_error(format!( @@ -321,7 +359,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { Ok(Arc::new(HashJoinExec::try_new( left, right, - &on, + on, &join_type, PartitionMode::CollectLeft, )?)) @@ -358,7 +396,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { self )) })?; - if let protobuf::logical_expr_node::ExprType::Sort(sort_expr) = expr { + if let protobuf::physical_expr_node::ExprType::Sort(sort_expr) = expr { let expr = sort_expr .expr .as_ref() @@ -370,7 +408,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: compile_expr(expr, &input.schema())?, + expr: expr.try_into()?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -403,14 +441,210 @@ impl TryInto> for &protobuf::PhysicalPlanNode { } } -fn compile_expr( - expr: &protobuf::LogicalExprNode, - schema: &Schema, -) -> Result, BallistaError> { - let df_planner = DefaultPhysicalPlanner::default(); - let state = ExecutionContextState::new(); - let expr: Expr = expr.try_into()?; - df_planner - .create_physical_expr(&expr, schema, &state) - .map_err(|e| BallistaError::General(format!("{:?}", e))) +impl From<&protobuf::PhysicalColumn> for Column { + fn from(c: &protobuf::PhysicalColumn) -> Column { + Column::new(&c.name, c.index as usize) + } +} + +impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { + fn from(f: &protobuf::ScalarFunction) -> BuiltinScalarFunction { + use protobuf::ScalarFunction; + match f { + ScalarFunction::Sqrt => BuiltinScalarFunction::Sqrt, + ScalarFunction::Sin => BuiltinScalarFunction::Sin, + ScalarFunction::Cos => BuiltinScalarFunction::Cos, + ScalarFunction::Tan => BuiltinScalarFunction::Tan, + ScalarFunction::Asin => BuiltinScalarFunction::Asin, + ScalarFunction::Acos => BuiltinScalarFunction::Acos, + ScalarFunction::Atan => BuiltinScalarFunction::Atan, + ScalarFunction::Exp => BuiltinScalarFunction::Exp, + ScalarFunction::Log => BuiltinScalarFunction::Log, + ScalarFunction::Log2 => BuiltinScalarFunction::Log2, + ScalarFunction::Log10 => BuiltinScalarFunction::Log10, + ScalarFunction::Floor => BuiltinScalarFunction::Floor, + ScalarFunction::Ceil => BuiltinScalarFunction::Ceil, + ScalarFunction::Round => BuiltinScalarFunction::Round, + ScalarFunction::Trunc => BuiltinScalarFunction::Trunc, + ScalarFunction::Abs => BuiltinScalarFunction::Abs, + ScalarFunction::Signum => BuiltinScalarFunction::Signum, + ScalarFunction::Octetlength => BuiltinScalarFunction::OctetLength, + ScalarFunction::Concat => BuiltinScalarFunction::Concat, + ScalarFunction::Lower => BuiltinScalarFunction::Lower, + ScalarFunction::Upper => BuiltinScalarFunction::Upper, + ScalarFunction::Trim => BuiltinScalarFunction::Trim, + ScalarFunction::Ltrim => BuiltinScalarFunction::Ltrim, + 
ScalarFunction::Rtrim => BuiltinScalarFunction::Rtrim, + ScalarFunction::Totimestamp => BuiltinScalarFunction::ToTimestamp, + ScalarFunction::Array => BuiltinScalarFunction::Array, + ScalarFunction::Nullif => BuiltinScalarFunction::NullIf, + ScalarFunction::Datetrunc => BuiltinScalarFunction::DateTrunc, + ScalarFunction::Md5 => BuiltinScalarFunction::MD5, + ScalarFunction::Sha224 => BuiltinScalarFunction::SHA224, + ScalarFunction::Sha256 => BuiltinScalarFunction::SHA256, + ScalarFunction::Sha384 => BuiltinScalarFunction::SHA384, + ScalarFunction::Sha512 => BuiltinScalarFunction::SHA512, + ScalarFunction::Ln => BuiltinScalarFunction::Ln, + } + } +} + +impl TryFrom<&protobuf::PhysicalExprNode> for Arc { + type Error = BallistaError; + + fn try_from(expr: &protobuf::PhysicalExprNode) -> Result { + let expr_type = expr + .expr_type + .as_ref() + .ok_or_else(|| proto_error("Unexpected empty physical expression"))?; + + let pexpr: Arc = match expr_type { + ExprType::Column(c) => { + let pcol: Column = c.into(); + Arc::new(pcol) + } + ExprType::Literal(scalar) => { + Arc::new(Literal::new(convert_required!(scalar.value)?)) + } + ExprType::BinaryExpr(binary_expr) => Arc::new(BinaryExpr::new( + convert_box_required!(&binary_expr.l)?, + from_proto_binary_op(&binary_expr.op)?, + convert_box_required!(&binary_expr.r)?, + )), + ExprType::AggregateExpr(_) => { + return Err(BallistaError::General( + "Cannot convert aggregate expr node to physical expression" + .to_owned(), + )); + } + ExprType::WindowExpr(_) => { + return Err(BallistaError::General( + "Cannot convert window expr node to physical expression".to_owned(), + )); + } + ExprType::Sort(_) => { + return Err(BallistaError::General( + "Cannot convert sort expr node to physical expression".to_owned(), + )); + } + ExprType::IsNullExpr(e) => { + Arc::new(IsNullExpr::new(convert_box_required!(e.expr)?)) + } + ExprType::IsNotNullExpr(e) => { + Arc::new(IsNotNullExpr::new(convert_box_required!(e.expr)?)) + } + ExprType::NotExpr(e) => { + Arc::new(NotExpr::new(convert_box_required!(e.expr)?)) + } + ExprType::Negative(e) => { + Arc::new(NegativeExpr::new(convert_box_required!(e.expr)?)) + } + ExprType::InList(e) => Arc::new(InListExpr::new( + convert_box_required!(e.expr)?, + e.list + .iter() + .map(|x| x.try_into()) + .collect::, _>>()?, + e.negated, + )), + ExprType::Case(e) => Arc::new(CaseExpr::try_new( + e.expr.as_ref().map(|e| e.as_ref().try_into()).transpose()?, + e.when_then_expr + .iter() + .map(|e| { + Ok(( + convert_required!(e.when_expr)?, + convert_required!(e.then_expr)?, + )) + }) + .collect::, BallistaError>>()? 
+ .as_slice(), + e.else_expr + .as_ref() + .map(|e| e.as_ref().try_into()) + .transpose()?, + )?), + ExprType::Cast(e) => Arc::new(CastExpr::new( + convert_box_required!(e.expr)?, + convert_required!(e.arrow_type)?, + DEFAULT_DATAFUSION_CAST_OPTIONS, + )), + ExprType::TryCast(e) => Arc::new(TryCastExpr::new( + convert_box_required!(e.expr)?, + convert_required!(e.arrow_type)?, + )), + ExprType::ScalarFunction(e) => { + let scalar_function = protobuf::ScalarFunction::from_i32(e.fun) + .ok_or_else(|| { + proto_error(format!( + "Received an unknown scalar function: {}", + e.fun, + )) + })?; + + let args = e + .args + .iter() + .map(|x| x.try_into()) + .collect::, _>>()?; + + let catalog_list = + Arc::new(MemoryCatalogList::new()) as Arc; + let ctx_state = ExecutionContextState { + catalog_list, + scalar_functions: Default::default(), + var_provider: Default::default(), + aggregate_functions: Default::default(), + config: ExecutionConfig::new(), + execution_props: ExecutionProps::new(), + }; + + let fun_expr = functions::create_physical_fun( + &(&scalar_function).into(), + &ctx_state, + )?; + + Arc::new(ScalarFunctionExpr::new( + &e.name, + fun_expr, + args, + &convert_required!(e.return_type)?, + )) + } + }; + + Ok(pexpr) + } +} + +impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFunction { + type Error = BallistaError; + + fn try_from( + expr: &protobuf::physical_window_expr_node::WindowFunction, + ) -> Result { + match expr { + protobuf::physical_window_expr_node::WindowFunction::AggrFunction(n) => { + let f = protobuf::AggregateFunction::from_i32(*n).ok_or_else(|| { + proto_error(format!( + "Received an unknown window aggregate function: {}", + n + )) + })?; + + Ok(WindowFunction::AggregateFunction(f.into())) + } + protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => { + let f = + protobuf::BuiltInWindowFunction::from_i32(*n).ok_or_else(|| { + proto_error(format!( + "Received an unknown window builtin function: {}", + n + )) + })?; + + Ok(WindowFunction::BuiltInWindowFunction(f.into())) + } + } + } } diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs index fdba2152b7f8..c0fe81f0ffb9 100644 --- a/ballista/rust/core/src/serde/physical_plan/mod.rs +++ b/ballista/rust/core/src/serde/physical_plan/mod.rs @@ -30,7 +30,7 @@ mod roundtrip_tests { logical_plan::Operator, physical_plan::{ empty::EmptyExec, - expressions::{binary, lit, InListExpr, NotExpr}, + expressions::{binary, col, lit, InListExpr, NotExpr}, expressions::{Avg, Column, PhysicalSortExpr}, filter::FilterExec, hash_aggregate::{AggregateMode, HashAggregateExec}, @@ -83,35 +83,35 @@ mod roundtrip_tests { let field_a = Field::new("col", DataType::Int64, false); let schema_left = Schema::new(vec![field_a.clone()]); let schema_right = Schema::new(vec![field_a]); + let on = vec![( + Column::new("col", schema_left.index_of("col")?), + Column::new("col", schema_right.index_of("col")?), + )]; roundtrip_test(Arc::new(HashJoinExec::try_new( Arc::new(EmptyExec::new(false, Arc::new(schema_left))), Arc::new(EmptyExec::new(false, Arc::new(schema_right))), - &[("col".to_string(), "col".to_string())], + on, &JoinType::Inner, PartitionMode::CollectLeft, )?)) } - fn col(name: &str) -> Arc { - Arc::new(Column::new(name)) - } - #[test] fn rountrip_hash_aggregate() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, 
field_b])); + let groups: Vec<(Arc, String)> = - vec![(col("a"), "unused".to_string())]; + vec![(col("a", &schema)?, "unused".to_string())]; let aggregates: Vec> = vec![Arc::new(Avg::new( - col("b"), + col("b", &schema)?, "AVG(b)".to_string(), DataType::Float64, ))]; - let field_a = Field::new("a", DataType::Int64, false); - let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); - roundtrip_test(Arc::new(HashAggregateExec::try_new( AggregateMode::Final, groups.clone(), @@ -127,9 +127,9 @@ mod roundtrip_tests { let field_b = Field::new("b", DataType::Int64, false); let field_c = Field::new("c", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b, field_c])); - let not = Arc::new(NotExpr::new(col("a"))); + let not = Arc::new(NotExpr::new(col("a", &schema)?)); let in_list = Arc::new(InListExpr::new( - col("b"), + col("b", &schema)?, vec![ lit(ScalarValue::Int64(Some(1))), lit(ScalarValue::Int64(Some(2))), @@ -150,14 +150,14 @@ mod roundtrip_tests { let schema = Arc::new(Schema::new(vec![field_a, field_b])); let sort_exprs = vec![ PhysicalSortExpr { - expr: col("a"), + expr: col("a", &schema)?, options: SortOptions { descending: true, nulls_first: false, }, }, PhysicalSortExpr { - expr: col("b"), + expr: col("b", &schema)?, options: SortOptions { descending: false, nulls_first: true, diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index 15d5d4b931ff..cf5401b65019 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -125,8 +125,14 @@ impl TryInto for Arc { .on() .iter() .map(|tuple| protobuf::JoinOn { - left: tuple.0.to_owned(), - right: tuple.1.to_owned(), + left: Some(protobuf::PhysicalColumn { + name: tuple.0.name().to_string(), + index: tuple.0.index() as u32, + }), + right: Some(protobuf::PhysicalColumn { + name: tuple.1.name().to_string(), + index: tuple.1.index() as u32, + }), }) .collect(); let join_type = match exec.join_type() { @@ -300,7 +306,7 @@ impl TryInto for Arc { let pb_partition_method = match exec.partitioning() { Partitioning::Hash(exprs, partition_count) => { - PartitionMethod::Hash(protobuf::HashRepartition { + PartitionMethod::Hash(protobuf::PhysicalHashRepartition { hash_expr: exprs .iter() .map(|expr| expr.clone().try_into()) @@ -330,13 +336,13 @@ impl TryInto for Arc { .expr() .iter() .map(|expr| { - let sort_expr = Box::new(protobuf::SortExprNode { + let sort_expr = Box::new(protobuf::PhysicalSortExprNode { expr: Some(Box::new(expr.expr.to_owned().try_into()?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::Sort( + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Sort( sort_expr, )), }) @@ -373,10 +379,10 @@ impl TryInto for Arc { } } -impl TryInto for Arc { +impl TryInto for Arc { type Error = BallistaError; - fn try_into(self) -> Result { + fn try_into(self) -> Result { let aggr_function = if self.as_any().downcast_ref::().is_some() { Ok(protobuf::AggregateFunction::Avg.into()) } else if self.as_any().downcast_ref::().is_some() { @@ -389,14 +395,14 @@ impl TryInto for Arc { self ))) }?; - let expressions: Vec = self + let expressions: Vec = self .expressions() .iter() .map(|e| e.clone().try_into()) .collect::, BallistaError>>()?; - Ok(protobuf::LogicalExprNode { - 
expr_type: Some(protobuf::logical_expr_node::ExprType::AggregateExpr( - Box::new(protobuf::AggregateExprNode { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( + Box::new(protobuf::PhysicalAggregateExprNode { aggr_function, expr: Some(Box::new(expressions[0].clone())), }), @@ -405,90 +411,100 @@ impl TryInto for Arc { } } -impl TryFrom> for protobuf::LogicalExprNode { +impl TryFrom> for protobuf::PhysicalExprNode { type Error = BallistaError; fn try_from(value: Arc) -> Result { let expr = value.as_any(); if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::ColumnName( - expr.name().to_owned(), + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: expr.name().to_string(), + index: expr.index() as u32, + }, )), }) } else if let Some(expr) = expr.downcast_ref::() { - let binary_expr = Box::new(protobuf::BinaryExprNode { + let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { l: Some(Box::new(expr.left().to_owned().try_into()?)), r: Some(Box::new(expr.right().to_owned().try_into()?)), op: format!("{:?}", expr.op()), }); - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::BinaryExpr( + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( binary_expr, )), }) } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::Case(Box::new( - protobuf::CaseNode { - expr: expr - .expr() - .as_ref() - .map(|exp| exp.clone().try_into().map(Box::new)) - .transpose()?, - when_then_expr: expr - .when_then_expr() - .iter() - .map(|(when_expr, then_expr)| { - try_parse_when_then_expr(when_expr, then_expr) - }) - .collect::, Self::Error>>()?, - else_expr: expr - .else_expr() - .map(|a| a.clone().try_into().map(Box::new)) - .transpose()?, - }, - ))), + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::Case( + Box::new( + protobuf::PhysicalCaseNode { + expr: expr + .expr() + .as_ref() + .map(|exp| exp.clone().try_into().map(Box::new)) + .transpose()?, + when_then_expr: expr + .when_then_expr() + .iter() + .map(|(when_expr, then_expr)| { + try_parse_when_then_expr(when_expr, then_expr) + }) + .collect::, + Self::Error, + >>()?, + else_expr: expr + .else_expr() + .map(|a| a.clone().try_into().map(Box::new)) + .transpose()?, + }, + ), + ), + ), }) } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::NotExpr( - Box::new(protobuf::Not { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr( + Box::new(protobuf::PhysicalNot { expr: Some(Box::new(expr.arg().to_owned().try_into()?)), }), )), }) } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::IsNullExpr( - Box::new(protobuf::IsNull { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( + Box::new(protobuf::PhysicalIsNull { expr: Some(Box::new(expr.arg().to_owned().try_into()?)), }), )), }) } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::IsNotNullExpr( - 
Box::new(protobuf::IsNotNull { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( + Box::new(protobuf::PhysicalIsNotNull { expr: Some(Box::new(expr.arg().to_owned().try_into()?)), }), )), }) } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { + Ok(protobuf::PhysicalExprNode { expr_type: Some( - protobuf::logical_expr_node::ExprType::InList( + protobuf::physical_expr_node::ExprType::InList( Box::new( - protobuf::InListNode { + protobuf::PhysicalInListNode { expr: Some(Box::new(expr.expr().to_owned().try_into()?)), list: expr .list() .iter() .map(|a| a.clone().try_into()) .collect::, + Vec, Self::Error, >>()?, negated: expr.negated(), @@ -498,32 +514,32 @@ impl TryFrom> for protobuf::LogicalExprNode { ), }) } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::Negative( - Box::new(protobuf::NegativeNode { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Negative( + Box::new(protobuf::PhysicalNegativeNode { expr: Some(Box::new(expr.arg().to_owned().try_into()?)), }), )), }) } else if let Some(lit) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::Literal( + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( lit.value().try_into()?, )), }) } else if let Some(cast) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::Cast(Box::new( - protobuf::CastNode { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( + protobuf::PhysicalCastNode { expr: Some(Box::new(cast.expr().clone().try_into()?)), arrow_type: Some(cast.cast_type().into()), }, ))), }) } else if let Some(cast) = expr.downcast_ref::() { - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::TryCast( - Box::new(protobuf::TryCastNode { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast( + Box::new(protobuf::PhysicalTryCastNode { expr: Some(Box::new(cast.expr().clone().try_into()?)), arrow_type: Some(cast.cast_type().into()), }), @@ -533,16 +549,18 @@ impl TryFrom> for protobuf::LogicalExprNode { let fun: BuiltinScalarFunction = BuiltinScalarFunction::from_str(expr.name())?; let fun: protobuf::ScalarFunction = (&fun).try_into()?; - let expr: Vec = expr + let args: Vec = expr .args() .iter() .map(|e| e.to_owned().try_into()) .collect::, _>>()?; - Ok(protobuf::LogicalExprNode { - expr_type: Some(protobuf::logical_expr_node::ExprType::ScalarFunction( - protobuf::ScalarFunctionNode { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarFunction( + protobuf::PhysicalScalarFunctionNode { + name: expr.name().to_string(), fun: fun.into(), - expr, + args, + return_type: Some(expr.return_type().into()), }, )), }) @@ -558,8 +576,8 @@ impl TryFrom> for protobuf::LogicalExprNode { fn try_parse_when_then_expr( when_expr: &Arc, then_expr: &Arc, -) -> Result { - Ok(protobuf::WhenThen { +) -> Result { + Ok(protobuf::PhysicalWhenThen { when_expr: Some(when_expr.clone().try_into()?), then_expr: Some(then_expr.clone().try_into()?), }) diff --git a/benchmarks/run.sh b/benchmarks/run.sh index 8e36424da89f..21633d39c23a 100755 --- a/benchmarks/run.sh +++ b/benchmarks/run.sh @@ -20,7 +20,7 @@ set 
diff --git a/benchmarks/run.sh b/benchmarks/run.sh
index 8e36424da89f..21633d39c23a 100755
--- a/benchmarks/run.sh
+++ b/benchmarks/run.sh
@@ -20,7 +20,7 @@ set -e
 
 # This bash script is meant to be run inside the docker-compose environment. Check the README for instructions
 cd /
-for query in 1 3 5 6 10 12
+for query in 1 3 5 6 7 8 9 10 12
 do
   /tpch benchmark ballista --host ballista-scheduler --port 50050 --query $query --path /data --format tbl --iterations 1 --debug
 done
diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index 34b8d3a27b19..379c90c69e88 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -712,6 +712,16 @@ mod tests {
         run_query(6).await
     }
 
+    #[tokio::test]
+    async fn run_q7() -> Result<()> {
+        run_query(7).await
+    }
+
+    #[tokio::test]
+    async fn run_q8() -> Result<()> {
+        run_query(8).await
+    }
+
     #[tokio::test]
     async fn run_q9() -> Result<()> {
         run_query(9).await
diff --git a/datafusion/src/dataframe.rs b/datafusion/src/dataframe.rs
index 9c7c2ef96d6b..507a79861cd5 100644
--- a/datafusion/src/dataframe.rs
+++ b/datafusion/src/dataframe.rs
@@ -188,6 +188,8 @@ pub trait DataFrame: Send + Sync {
         right_cols: &[&str],
     ) -> Result<Arc<dyn DataFrame>>;
 
+    // TODO: add join_using
+
     /// Repartition a DataFrame based on a logical partitioning scheme.
     ///
     /// ```
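Aside: with named scans, plan displays now show `#table.column` everywhere. A sketch of the end-to-end effect through the builder API changed in this patch; the expected string mirrors the test updates further down:

// Sketch only; assumes the post-patch scan_empty(Option<&str>, ..) signature.
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::logical_plan::{col, lit, LogicalPlanBuilder};

fn main() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
    // a named scan qualifies every column reference in the plan display
    let plan = LogicalPlanBuilder::scan_empty(Some("test"), &schema, None)?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;
    assert_eq!(
        "Filter: #test.a Eq Int64(1)\
        \n  TableScan: test projection=None",
        format!("{:?}", plan)
    );
    Ok(())
}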
diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index b42695b0c4c6..926e2db9450a 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -52,7 +52,7 @@ use crate::datasource::TableProvider;
 use crate::error::{DataFusionError, Result};
 use crate::execution::dataframe_impl::DataFrameImpl;
 use crate::logical_plan::{
-    FunctionRegistry, LogicalPlan, LogicalPlanBuilder, ToDFSchema,
+    FunctionRegistry, LogicalPlan, LogicalPlanBuilder, UNNAMED_TABLE,
 };
 use crate::optimizer::constant_folding::ConstantFolding;
 use crate::optimizer::filter_push_down::FilterPushDown;
@@ -297,18 +297,9 @@ impl ExecutionContext {
         &mut self,
         provider: Arc<dyn TableProvider>,
     ) -> Result<Arc<dyn DataFrame>> {
-        let schema = provider.schema();
-        let table_scan = LogicalPlan::TableScan {
-            table_name: "".to_string(),
-            source: provider,
-            projected_schema: schema.to_dfschema_ref()?,
-            projection: None,
-            filters: vec![],
-            limit: None,
-        };
         Ok(Arc::new(DataFrameImpl::new(
             self.state.clone(),
-            &LogicalPlanBuilder::from(&table_scan).build()?,
+            &LogicalPlanBuilder::scan(UNNAMED_TABLE, provider, None)?.build()?,
         )))
     }
 
@@ -410,22 +401,15 @@ impl ExecutionContext {
     ) -> Result<Arc<dyn DataFrame>> {
         let table_ref = table_ref.into();
         let schema = self.state.lock().unwrap().schema_for_ref(table_ref)?;
-
         match schema.table(table_ref.table()) {
             Some(ref provider) => {
-                let schema = provider.schema();
-                let table_scan = LogicalPlan::TableScan {
-                    table_name: table_ref.table().to_owned(),
-                    source: Arc::clone(provider),
-                    projected_schema: schema.to_dfschema_ref()?,
-                    projection: None,
-                    filters: vec![],
-                    limit: None,
-                };
-                Ok(Arc::new(DataFrameImpl::new(
-                    self.state.clone(),
-                    &LogicalPlanBuilder::from(&table_scan).build()?,
-                )))
+                let plan = LogicalPlanBuilder::scan(
+                    table_ref.table(),
+                    Arc::clone(provider),
+                    None,
+                )?
+                .build()?;
+                Ok(Arc::new(DataFrameImpl::new(self.state.clone(), &plan)))
             }
             _ => Err(DataFusionError::Plan(format!(
                 "No table named '{}'",
@@ -1038,7 +1022,6 @@ mod tests {
 
         let logical_plan = ctx.optimize(&logical_plan)?;
         let physical_plan = ctx.create_physical_plan(&logical_plan)?;
-        println!("{:?}", physical_plan);
 
         let results = collect_partitioned(physical_plan).await?;
 
@@ -1110,7 +1093,7 @@ mod tests {
             _ => panic!("expect optimized_plan to be projection"),
         }
 
-        let expected = "Projection: #c2\
+        let expected = "Projection: #test.c2\
         \n  TableScan: test projection=Some([1])";
         assert_eq!(format!("{:?}", optimized_plan), expected);
 
@@ -1133,7 +1116,7 @@ mod tests {
         let schema: Schema = ctx.table("test").unwrap().schema().clone().into();
         assert!(!schema.field_with_name("c1")?.is_nullable());
 
-        let plan = LogicalPlanBuilder::scan_empty("", &schema, None)?
+        let plan = LogicalPlanBuilder::scan_empty(None, &schema, None)?
             .project(vec![col("c1")])?
             .build()?;
 
@@ -1183,8 +1166,11 @@ mod tests {
             _ => panic!("expect optimized_plan to be projection"),
         }
 
-        let expected = "Projection: #b\
-        \n  TableScan: projection=Some([1])";
+        let expected = format!(
+            "Projection: #{}.b\
+            \n  TableScan: {} projection=Some([1])",
+            UNNAMED_TABLE, UNNAMED_TABLE
+        );
         assert_eq!(format!("{:?}", optimized_plan), expected);
 
         let physical_plan = ctx.create_physical_plan(&optimized_plan)?;
@@ -2138,9 +2124,9 @@ mod tests {
             Field::new("c2", DataType::UInt32, false),
         ]));
 
-        let plan = LogicalPlanBuilder::scan_empty("", schema.as_ref(), None)?
+        let plan = LogicalPlanBuilder::scan_empty(None, schema.as_ref(), None)?
             .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
-            .project(vec![col("c1"), col("SUM(c2)").alias("total_salary")])?
+            .project(vec![col("c1"), sum(col("c2")).alias("total_salary")])?
             .build()?;
 
         let plan = ctx.optimize(&plan)?;
@@ -2590,7 +2576,7 @@ mod tests {
 
         assert_eq!(
             format!("{:?}", plan),
-            "Projection: #a, #b, my_add(#a, #b)\n  TableScan: t projection=None"
+            "Projection: #t.a, #t.b, my_add(#t.a, #t.b)\n  TableScan: t projection=None"
         );
 
         let plan = ctx.optimize(&plan)?;
diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs
index a674e3cdb0f1..99eb7f077c96 100644
--- a/datafusion/src/execution/dataframe_impl.rs
+++ b/datafusion/src/execution/dataframe_impl.rs
@@ -110,7 +110,12 @@ impl DataFrame for DataFrameImpl {
         right_cols: &[&str],
     ) -> Result<Arc<dyn DataFrame>> {
         let plan = LogicalPlanBuilder::from(&self.plan)
-            .join(&right.to_logical_plan(), join_type, left_cols, right_cols)?
+            .join(
+                &right.to_logical_plan(),
+                join_type,
+                left_cols.to_vec(),
+                right_cols.to_vec(),
+            )?
             .build()?;
         Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
     }
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index 6bd5181050fd..4b4ed0fb9d41 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -24,19 +24,27 @@ use arrow::{
     record_batch::RecordBatch,
 };
 
-use super::dfschema::ToDFSchema;
-use super::{
-    col, exprlist_to_fields, Expr, JoinType, LogicalPlan, PlanType, StringifiedPlan,
-};
 use crate::datasource::TableProvider;
 use crate::error::{DataFusionError, Result};
-use crate::logical_plan::{DFField, DFSchema, DFSchemaRef, Partitioning};
 use crate::{
     datasource::{empty::EmptyTable, parquet::ParquetTable, CsvFile, MemTable},
     prelude::CsvReadOptions,
 };
+
+use super::dfschema::ToDFSchema;
+use super::{
+    exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType,
+    StringifiedPlan,
+};
+use crate::logical_plan::{
+    columnize_expr, normalize_col, normalize_cols, Column, DFField, DFSchema,
+    DFSchemaRef, Partitioning,
+};
 use std::collections::HashSet;
 
+/// Default table name for unnamed table
+pub const UNNAMED_TABLE: &str = "?table?";
+
 /// Builder for logical plans
 ///
 /// ```
@@ -62,7 +70,7 @@ use std::collections::HashSet;
 /// // FROM employees
 /// // WHERE salary < 1000
 /// let plan = LogicalPlanBuilder::scan_empty(
-///     "employee.csv",
+///     Some("employee"),
 ///     &employee_schema(),
 ///     None,
 /// )?
@@ -102,7 +110,7 @@ impl LogicalPlanBuilder {
         projection: Option<Vec<usize>>,
     ) -> Result<Self> {
         let provider = Arc::new(MemTable::try_new(schema, partitions)?);
-        Self::scan("", provider, projection)
+        Self::scan(UNNAMED_TABLE, provider, projection)
     }
 
     /// Scan a CSV data source
@@ -112,7 +120,7 @@ impl LogicalPlanBuilder {
         projection: Option<Vec<usize>>,
     ) -> Result<Self> {
         let provider = Arc::new(CsvFile::try_new(path, options)?);
-        Self::scan("", provider, projection)
+        Self::scan(path, provider, projection)
     }
 
     /// Scan a Parquet data source
@@ -122,38 +130,53 @@ impl LogicalPlanBuilder {
         max_concurrency: usize,
     ) -> Result<Self> {
         let provider = Arc::new(ParquetTable::try_new(path, max_concurrency)?);
-        Self::scan("", provider, projection)
+        Self::scan(path, provider, projection)
     }
 
     /// Scan an empty data source, mainly used in tests
     pub fn scan_empty(
-        name: &str,
+        name: Option<&str>,
        table_schema: &Schema,
         projection: Option<Vec<usize>>,
     ) -> Result<Self> {
         let table_schema = Arc::new(table_schema.clone());
         let provider = Arc::new(EmptyTable::new(table_schema));
-        Self::scan(name, provider, projection)
+        Self::scan(name.unwrap_or(UNNAMED_TABLE), provider, projection)
     }
 
     /// Convert a table provider into a builder with a TableScan
     pub fn scan(
-        name: &str,
+        table_name: &str,
         provider: Arc<dyn TableProvider>,
         projection: Option<Vec<usize>>,
     ) -> Result<Self> {
+        if table_name.is_empty() {
+            return Err(DataFusionError::Plan(
+                "table_name cannot be empty".to_string(),
+            ));
+        }
+
         let schema = provider.schema();
 
         let projected_schema = projection
             .as_ref()
-            .map(|p| Schema::new(p.iter().map(|i| schema.field(*i).clone()).collect()))
-            .map_or(schema, SchemaRef::new)
-            .to_dfschema_ref()?;
+            .map(|p| {
+                DFSchema::new(
+                    p.iter()
+                        .map(|i| {
+                            DFField::from_qualified(table_name, schema.field(*i).clone())
+                        })
+                        .collect(),
+                )
+            })
+            .unwrap_or_else(|| {
+                DFSchema::try_from_qualified_schema(table_name, &schema)
+            })?;
 
         let table_scan = LogicalPlan::TableScan {
-            table_name: name.to_string(),
+            table_name: table_name.to_string(),
             source: provider,
-            projected_schema,
+            projected_schema: Arc::new(projected_schema),
             projection,
             filters: vec![],
             limit: None,
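Aside: `scan` now rejects empty table names up front instead of silently producing unqualifiable columns, and unnamed sources fall back to the `?table?` placeholder. A sketch, assuming `EmptyTable` is reachable at the public path implied by the `use` list above:

use std::sync::Arc;

use arrow::datatypes::Schema;
// path assumed from the crate-internal import above; adjust if re-exported elsewhere
use datafusion::datasource::empty::EmptyTable;
use datafusion::logical_plan::{LogicalPlanBuilder, UNNAMED_TABLE};

fn main() {
    let provider = Arc::new(EmptyTable::new(Arc::new(Schema::empty())));

    // empty table names are now a planning error ...
    assert!(LogicalPlanBuilder::scan("", provider.clone(), None).is_err());

    // ... and unnamed sources use the "?table?" placeholder qualifier instead
    assert_eq!("?table?", UNNAMED_TABLE);
    assert!(LogicalPlanBuilder::scan(UNNAMED_TABLE, provider, None).is_ok());
}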
@@ -170,16 +193,21 @@ impl LogicalPlanBuilder {
     /// * An invalid expression is used (e.g. a `sort` expression)
     pub fn project(&self, expr: impl IntoIterator<Item = Expr>) -> Result<Self> {
         let input_schema = self.plan.schema();
+        let all_schemas = self.plan.all_schemas();
         let mut projected_expr = vec![];
         for e in expr {
             match e {
                 Expr::Wildcard => {
                     (0..input_schema.fields().len()).for_each(|i| {
-                        projected_expr.push(col(input_schema.field(i).name()))
+                        projected_expr
+                            .push(Expr::Column(input_schema.field(i).qualified_column()))
                     });
                 }
-                _ => projected_expr.push(e),
-            };
+                _ => projected_expr.push(columnize_expr(
+                    normalize_col(e, &all_schemas)?,
+                    input_schema,
+                )),
+            }
         }
 
         validate_unique_names("Projections", projected_expr.iter(), input_schema)?;
@@ -195,6 +223,7 @@ impl LogicalPlanBuilder {
 
     /// Apply a filter
     pub fn filter(&self, expr: Expr) -> Result<Self> {
+        let expr = normalize_col(expr, &self.plan.all_schemas())?;
         Ok(Self::from(&LogicalPlan::Filter {
             predicate: expr,
             input: Arc::new(self.plan.clone()),
@@ -210,69 +239,103 @@ impl LogicalPlanBuilder {
     }
 
     /// Apply a sort
-    pub fn sort(&self, expr: impl IntoIterator<Item = Expr>) -> Result<Self> {
+    pub fn sort(&self, exprs: impl IntoIterator<Item = Expr>) -> Result<Self> {
+        let schemas = self.plan.all_schemas();
         Ok(Self::from(&LogicalPlan::Sort {
-            expr: expr.into_iter().collect(),
+            expr: normalize_cols(exprs, &schemas)?,
             input: Arc::new(self.plan.clone()),
         }))
     }
 
     /// Apply a union
     pub fn union(&self, plan: LogicalPlan) -> Result<Self> {
-        let schema = self.plan.schema();
+        Ok(Self::from(&union_with_alias(
+            self.plan.clone(),
+            plan,
+            None,
+        )?))
+    }
 
-        if plan.schema() != schema {
+    /// Apply a join with on constraint
+    pub fn join(
+        &self,
+        right: &LogicalPlan,
+        join_type: JoinType,
+        left_keys: Vec<impl Into<Column>>,
+        right_keys: Vec<impl Into<Column>>,
+    ) -> Result<Self> {
+        if left_keys.len() != right_keys.len() {
             return Err(DataFusionError::Plan(
-                "Schema's for union should be the same ".to_string(),
+                "left_keys and right_keys were not the same length".to_string(),
             ));
         }
 
-        // Add plan to existing union if possible
-        let mut inputs = match &self.plan {
-            LogicalPlan::Union { inputs, .. } => inputs.clone(),
-            _ => vec![self.plan.clone()],
-        };
-        inputs.push(plan);
-        Ok(Self::from(&LogicalPlan::Union {
-            inputs,
-            schema: schema.clone(),
-            alias: None,
+        let left_keys: Vec<Column> = left_keys
+            .into_iter()
+            .map(|c| c.into().normalize(&self.plan.all_schemas()))
+            .collect::<Result<_>>()?;
+        let right_keys: Vec<Column> = right_keys
+            .into_iter()
+            .map(|c| c.into().normalize(&right.all_schemas()))
+            .collect::<Result<_>>()?;
+        let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect();
+        let join_schema = build_join_schema(
+            self.plan.schema(),
+            right.schema(),
+            &on,
+            &join_type,
+            &JoinConstraint::On,
+        )?;
+
+        Ok(Self::from(&LogicalPlan::Join {
+            left: Arc::new(self.plan.clone()),
+            right: Arc::new(right.clone()),
+            on,
+            join_type,
+            join_constraint: JoinConstraint::On,
+            schema: DFSchemaRef::new(join_schema),
         }))
     }
 
-    /// Apply a join
-    pub fn join(
+    /// Apply a join with using constraint, which duplicates all join columns in output schema.
+    pub fn join_using(
         &self,
         right: &LogicalPlan,
         join_type: JoinType,
-        left_keys: &[&str],
-        right_keys: &[&str],
+        using_keys: Vec<impl Into<Column> + Clone>,
     ) -> Result<Self> {
-        if left_keys.len() != right_keys.len() {
-            Err(DataFusionError::Plan(
-                "left_keys and right_keys were not the same length".to_string(),
-            ))
-        } else {
-            let on: Vec<_> = left_keys
-                .iter()
-                .zip(right_keys.iter())
-                .map(|(x, y)| (x.to_string(), y.to_string()))
-                .collect::<Vec<_>>();
-            let join_schema =
-                build_join_schema(self.plan.schema(), right.schema(), &on, &join_type)?;
-            Ok(Self::from(&LogicalPlan::Join {
-                left: Arc::new(self.plan.clone()),
-                right: Arc::new(right.clone()),
-                on,
-                join_type,
-                schema: DFSchemaRef::new(join_schema),
-            }))
-        }
+        let left_keys: Vec<Column> = using_keys
+            .clone()
+            .into_iter()
+            .map(|c| c.into().normalize(&self.plan.all_schemas()))
+            .collect::<Result<_>>()?;
+        let right_keys: Vec<Column> = using_keys
+            .into_iter()
+            .map(|c| c.into().normalize(&right.all_schemas()))
+            .collect::<Result<_>>()?;
+
+        let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect();
+        let join_schema = build_join_schema(
+            self.plan.schema(),
+            right.schema(),
+            &on,
+            &join_type,
+            &JoinConstraint::Using,
+        )?;
+
+        Ok(Self::from(&LogicalPlan::Join {
+            left: Arc::new(self.plan.clone()),
+            right: Arc::new(right.clone()),
+            on,
+            join_type,
+            join_constraint: JoinConstraint::Using,
+            schema: DFSchemaRef::new(join_schema),
+        }))
     }
 
+    /// Apply a cross join
     pub fn cross_join(&self, right: &LogicalPlan) -> Result<Self> {
         let schema = self.plan.schema().join(right.schema())?;
-
         Ok(Self::from(&LogicalPlan::CrossJoin {
             left: Arc::new(self.plan.clone()),
             right: Arc::new(right.clone()),
@@ -320,9 +383,9 @@ impl LogicalPlanBuilder {
         group_expr: impl IntoIterator<Item = Expr>,
         aggr_expr: impl IntoIterator<Item = Expr>,
     ) -> Result<Self> {
-        let group_expr = group_expr.into_iter().collect::<Vec<Expr>>();
-        let aggr_expr = aggr_expr.into_iter().collect::<Vec<Expr>>();
-
+        let schemas = self.plan.all_schemas();
+        let group_expr = normalize_cols(group_expr, &schemas)?;
+        let aggr_expr = normalize_cols(aggr_expr, &schemas)?;
         let all_expr = group_expr.iter().chain(aggr_expr.iter());
 
         validate_unique_names("Aggregations", all_expr.clone(), self.plan.schema())?;
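Aside: `join` and `join_using` above feed `JoinConstraint::On` and `JoinConstraint::Using` into `build_join_schema` (next hunk), and the constraint decides which key columns survive in the output. A sketch of the resulting schemas, assuming the post-patch builder API; field lists use the `DFSchema` Display format seen in the dfschema tests:

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::logical_plan::{JoinType, LogicalPlanBuilder};

fn main() -> Result<()> {
    let schema = Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("x", DataType::Int32, false),
    ]);
    let t2 = LogicalPlanBuilder::scan_empty(Some("t2"), &schema, None)?.build()?;

    // ON: both keys are kept; the qualifiers keep them unambiguous
    let on = LogicalPlanBuilder::scan_empty(Some("t1"), &schema, None)?
        .join(&t2, JoinType::Inner, vec!["id"], vec!["id"])?
        .build()?;
    assert_eq!("t1.id, t1.x, t2.id, t2.x", on.schema().to_string());

    // USING: the right side's key is marked duplicate and dropped
    let using = LogicalPlanBuilder::scan_empty(Some("t1"), &schema, None)?
        .join_using(&t2, JoinType::Inner, vec!["id"])?
        .build()?;
    assert_eq!("t1.id, t1.x, t2.x", using.schema().to_string());
    Ok(())
}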
@@ -363,27 +426,35 @@ impl LogicalPlanBuilder {
 
 /// Creates a schema for a join operation.
 /// The fields from the left side are first
-fn build_join_schema(
+pub fn build_join_schema(
     left: &DFSchema,
     right: &DFSchema,
-    on: &[(String, String)],
+    on: &[(Column, Column)],
     join_type: &JoinType,
+    join_constraint: &JoinConstraint,
 ) -> Result<DFSchema> {
     let fields: Vec<DFField> = match join_type {
         JoinType::Inner | JoinType::Left | JoinType::Full => {
-            // remove right-side join keys if they have the same names as the left-side
-            let duplicate_keys = &on
-                .iter()
-                .filter(|(l, r)| l == r)
-                .map(|on| on.1.to_string())
-                .collect::<HashSet<String>>();
+            let duplicate_keys = match join_constraint {
+                JoinConstraint::On => on
+                    .iter()
+                    .filter(|(l, r)| l == r)
+                    .map(|on| on.1.clone())
+                    .collect::<HashSet<_>>(),
+                // using join requires unique join columns in the output schema, so we mark all
+                // right join keys as duplicate
+                JoinConstraint::Using => {
+                    on.iter().map(|on| on.1.clone()).collect::<HashSet<_>>()
+                }
+            };
 
             let left_fields = left.fields().iter();
 
+            // remove right-side join keys if they have the same names as the left-side
             let right_fields = right
                 .fields()
                 .iter()
-                .filter(|f| !duplicate_keys.contains(f.name()));
+                .filter(|f| !duplicate_keys.contains(&f.qualified_column()));
 
             // left then right
             left_fields.chain(right_fields).cloned().collect()
@@ -393,17 +464,24 @@ fn build_join_schema(
             left.fields().clone()
         }
         JoinType::Right => {
-            // remove left-side join keys if they have the same names as the right-side
-            let duplicate_keys = &on
-                .iter()
-                .filter(|(l, r)| l == r)
-                .map(|on| on.1.to_string())
-                .collect::<HashSet<String>>();
+            let duplicate_keys = match join_constraint {
+                JoinConstraint::On => on
+                    .iter()
+                    .filter(|(l, r)| l == r)
+                    .map(|on| on.1.clone())
+                    .collect::<HashSet<_>>(),
+                // using join requires unique join columns in the output schema, so we mark all
+                // left join keys as duplicate
+                JoinConstraint::Using => {
+                    on.iter().map(|on| on.0.clone()).collect::<HashSet<_>>()
+                }
+            };
+
+            // remove left-side join keys if they have the same names as the right-side
             let left_fields = left
                 .fields()
                 .iter()
-                .filter(|f| !duplicate_keys.contains(f.name()));
+                .filter(|f| !duplicate_keys.contains(&f.qualified_column()));
 
             let right_fields = right.fields().iter();
 
@@ -411,6 +489,7 @@ fn build_join_schema(
             left_fields.chain(right_fields).cloned().collect()
         }
     };
+
     DFSchema::new(fields)
 }
 
@@ -441,17 +520,56 @@ fn validate_unique_names<'a>(
     })
 }
 
+/// Union two logical plans with an optional alias.
+pub fn union_with_alias(
+    left_plan: LogicalPlan,
+    right_plan: LogicalPlan,
+    alias: Option<String>,
+) -> Result<LogicalPlan> {
+    let inputs = vec![left_plan, right_plan]
+        .into_iter()
+        .flat_map(|p| match p {
+            LogicalPlan::Union { inputs, .. } => inputs,
+            x => vec![x],
+        })
+        .collect::<Vec<_>>();
+    if inputs.is_empty() {
+        return Err(DataFusionError::Plan("Empty UNION".to_string()));
+    }
+
+    let union_schema = (**inputs[0].schema()).clone();
+    let union_schema = Arc::new(match alias {
+        Some(ref alias) => union_schema.replace_qualifier(alias.as_str()),
+        None => union_schema.strip_qualifiers(),
+    });
+    if !inputs.iter().skip(1).all(|input_plan| {
+        // union changes all qualifers in resulting schema, so we only need to
+        // match against arrow schema here, which doesn't include qualifiers
+        union_schema.matches_arrow_schema(&((**input_plan.schema()).clone().into()))
+    }) {
+        return Err(DataFusionError::Plan(
+            "UNION ALL schemas are expected to be the same".to_string(),
+        ));
+    }
+
+    Ok(LogicalPlan::Union {
+        schema: union_schema,
+        inputs,
+        alias,
+    })
+}
+
 #[cfg(test)]
 mod tests {
     use arrow::datatypes::{DataType, Field};
 
-    use super::super::{lit, sum};
+    use super::super::{col, lit, sum};
     use super::*;
 
     #[test]
     fn plan_builder_simple() -> Result<()> {
         let plan = LogicalPlanBuilder::scan_empty(
-            "employee.csv",
+            Some("employee_csv"),
             &employee_schema(),
             Some(vec![0, 3]),
         )?
@@ -459,9 +577,9 @@ mod tests {
         .project(vec![col("id")])?
         .build()?;
 
-        let expected = "Projection: #id\
-        \n  Filter: #state Eq Utf8(\"CO\")\
-        \n    TableScan: employee.csv projection=Some([0, 3])";
+        let expected = "Projection: #employee_csv.id\
+        \n  Filter: #employee_csv.state Eq Utf8(\"CO\")\
+        \n    TableScan: employee_csv projection=Some([0, 3])";
 
         assert_eq!(expected, format!("{:?}", plan));
 
@@ -471,7 +589,7 @@ mod tests {
     #[test]
     fn plan_builder_aggregate() -> Result<()> {
         let plan = LogicalPlanBuilder::scan_empty(
-            "employee.csv",
+            Some("employee_csv"),
             &employee_schema(),
             Some(vec![3, 4]),
         )?
@@ -482,9 +600,9 @@ mod tests {
         .project(vec![col("state"), col("total_salary")])?
         .build()?;
 
-        let expected = "Projection: #state, #total_salary\
-        \n  Aggregate: groupBy=[[#state]], aggr=[[SUM(#salary) AS total_salary]]\
-        \n    TableScan: employee.csv projection=Some([3, 4])";
+        let expected = "Projection: #employee_csv.state, #total_salary\
+        \n  Aggregate: groupBy=[[#employee_csv.state]], aggr=[[SUM(#employee_csv.salary) AS total_salary]]\
+        \n    TableScan: employee_csv projection=Some([3, 4])";
 
         assert_eq!(expected, format!("{:?}", plan));
 
@@ -494,7 +612,7 @@ mod tests {
     #[test]
     fn plan_builder_sort() -> Result<()> {
         let plan = LogicalPlanBuilder::scan_empty(
-            "employee.csv",
+            Some("employee_csv"),
             &employee_schema(),
             Some(vec![3, 4]),
         )?
@@ -505,15 +623,15 @@ mod tests {
                 nulls_first: true,
             },
             Expr::Sort {
-                expr: Box::new(col("total_salary")),
+                expr: Box::new(col("salary")),
                 asc: false,
                 nulls_first: false,
             },
         ])?
.build()?; - let expected = "Sort: #state ASC NULLS FIRST, #total_salary DESC NULLS LAST\ - \n TableScan: employee.csv projection=Some([3, 4])"; + let expected = "Sort: #employee_csv.state ASC NULLS FIRST, #employee_csv.salary DESC NULLS LAST\ + \n TableScan: employee_csv projection=Some([3, 4])"; assert_eq!(expected, format!("{:?}", plan)); @@ -523,7 +641,7 @@ mod tests { #[test] fn plan_builder_union_combined_single_union() -> Result<()> { let plan = LogicalPlanBuilder::scan_empty( - "employee.csv", + Some("employee_csv"), &employee_schema(), Some(vec![3, 4]), )?; @@ -536,10 +654,10 @@ mod tests { // output has only one union let expected = "Union\ - \n TableScan: employee.csv projection=Some([3, 4])\ - \n TableScan: employee.csv projection=Some([3, 4])\ - \n TableScan: employee.csv projection=Some([3, 4])\ - \n TableScan: employee.csv projection=Some([3, 4])"; + \n TableScan: employee_csv projection=Some([3, 4])\ + \n TableScan: employee_csv projection=Some([3, 4])\ + \n TableScan: employee_csv projection=Some([3, 4])\ + \n TableScan: employee_csv projection=Some([3, 4])"; assert_eq!(expected, format!("{:?}", plan)); @@ -549,9 +667,10 @@ mod tests { #[test] fn projection_non_unique_names() -> Result<()> { let plan = LogicalPlanBuilder::scan_empty( - "employee.csv", + Some("employee_csv"), &employee_schema(), - Some(vec![0, 3]), + // project id and first_name by column index + Some(vec![0, 1]), )? // two columns with the same name => error .project(vec![col("id"), col("first_name").alias("id")]); @@ -560,9 +679,8 @@ mod tests { Err(DataFusionError::Plan(e)) => { assert_eq!( e, - "Projections require unique expression names \ - but the expression \"#id\" at position 0 and \"#first_name AS id\" at \ - position 1 have the same name. Consider aliasing (\"AS\") one of them." + "Schema contains qualified field name 'employee_csv.id' \ + and unqualified field name 'id' which would be ambiguous" ); Ok(()) } @@ -575,9 +693,10 @@ mod tests { #[test] fn aggregate_non_unique_names() -> Result<()> { let plan = LogicalPlanBuilder::scan_empty( - "employee.csv", + Some("employee_csv"), &employee_schema(), - Some(vec![0, 3]), + // project state and salary by column index + Some(vec![3, 4]), )? // two columns with the same name => error .aggregate(vec![col("state")], vec![sum(col("salary")).alias("state")]); @@ -586,9 +705,8 @@ mod tests { Err(DataFusionError::Plan(e)) => { assert_eq!( e, - "Aggregations require unique expression names \ - but the expression \"#state\" at position 0 and \"SUM(#salary) AS state\" at \ - position 1 have the same name. Consider aliasing (\"AS\") one of them." 
+ "Schema contains qualified field name 'employee_csv.state' and \ + unqualified field name 'state' which would be ambiguous" ); Ok(()) } diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs index c5437b3af953..e754addb9da7 100644 --- a/datafusion/src/logical_plan/dfschema.rs +++ b/datafusion/src/logical_plan/dfschema.rs @@ -23,6 +23,7 @@ use std::convert::TryFrom; use std::sync::Arc; use crate::error::{DataFusionError, Result}; +use crate::logical_plan::Column; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use std::fmt::{Display, Formatter}; @@ -88,7 +89,7 @@ impl DFSchema { } /// Create a `DFSchema` from an Arrow schema - pub fn try_from_qualified(qualifier: &str, schema: &Schema) -> Result { + pub fn try_from_qualified_schema(qualifier: &str, schema: &Schema) -> Result { Self::new( schema .fields() @@ -108,6 +109,21 @@ impl DFSchema { Self::new(fields) } + /// Merge a schema into self + pub fn merge(&mut self, other_schema: &DFSchema) { + for field in other_schema.fields() { + // skip duplicate columns + let duplicated_field = match field.qualifier() { + Some(q) => self.field_with_name(Some(q.as_str()), field.name()).is_ok(), + // for unqualifed columns, check as unqualified name + None => self.field_with_unqualified_name(field.name()).is_ok(), + }; + if !duplicated_field { + self.fields.push(field.clone()); + } + } + } + /// Get a list of fields pub fn fields(&self) -> &Vec { &self.fields @@ -119,7 +135,7 @@ impl DFSchema { &self.fields[i] } - /// Find the index of the column with the given name + /// Find the index of the column with the given unqualifed name pub fn index_of(&self, name: &str) -> Result { for i in 0..self.fields.len() { if self.fields[i].name() == name { @@ -129,6 +145,20 @@ impl DFSchema { Err(DataFusionError::Plan(format!("No field named '{}'", name))) } + /// Find the index of the column with the given qualifer and name + pub fn index_of_column(&self, col: &Column) -> Result { + for i in 0..self.fields.len() { + let field = &self.fields[i]; + if field.qualifier() == col.relation.as_ref() && field.name() == &col.name { + return Ok(i); + } + } + Err(DataFusionError::Plan(format!( + "No field matches column '{}'", + col, + ))) + } + /// Find the field with the given name pub fn field_with_name( &self, @@ -150,7 +180,10 @@ impl DFSchema { .filter(|field| field.name() == name) .collect(); match matches.len() { - 0 => Err(DataFusionError::Plan(format!("No field named '{}'", name))), + 0 => Err(DataFusionError::Plan(format!( + "No field with unqualified name '{}'", + name + ))), 1 => Ok(matches[0].to_owned()), _ => Err(DataFusionError::Plan(format!( "Ambiguous reference to field named '{}'", @@ -184,6 +217,62 @@ impl DFSchema { ))), } } + + /// Find the field with the given qualified column + pub fn field_from_qualified_column(&self, column: &Column) -> Result { + match &column.relation { + Some(r) => self.field_with_qualified_name(r, &column.name), + None => self.field_with_unqualified_name(&column.name), + } + } + + /// Check to see if unqualified field names matches field names in Arrow schema + pub fn matches_arrow_schema(&self, arrow_schema: &Schema) -> bool { + self.fields + .iter() + .zip(arrow_schema.fields().iter()) + .all(|(dffield, arrowfield)| dffield.name() == arrowfield.name()) + } + + /// Strip all field qualifier in schema + pub fn strip_qualifiers(self) -> Self { + DFSchema { + fields: self + .fields + .into_iter() + .map(|f| { + if f.qualifier().is_some() { + DFField::new( + None, + f.name(), + 
+
+    /// Strip all field qualifier in schema
+    pub fn strip_qualifiers(self) -> Self {
+        DFSchema {
+            fields: self
+                .fields
+                .into_iter()
+                .map(|f| {
+                    if f.qualifier().is_some() {
+                        DFField::new(
+                            None,
+                            f.name(),
+                            f.data_type().to_owned(),
+                            f.is_nullable(),
+                        )
+                    } else {
+                        f
+                    }
+                })
+                .collect(),
+        }
+    }
+
+    /// Replace all field qualifier with new value in schema
+    pub fn replace_qualifier(self, qualifer: &str) -> Self {
+        DFSchema {
+            fields: self
+                .fields
+                .into_iter()
+                .map(|f| {
+                    DFField::new(
+                        Some(qualifer),
+                        f.name(),
+                        f.data_type().to_owned(),
+                        f.is_nullable(),
+                    )
+                })
+                .collect(),
+        }
+    }
 }
 
 impl Into<Schema> for DFSchema {
@@ -195,7 +284,7 @@ impl Into<Schema> for DFSchema {
             .map(|f| {
                 if f.qualifier().is_some() {
                     Field::new(
-                        f.qualified_name().as_str(),
+                        f.name().as_str(),
                         f.data_type().to_owned(),
                         f.is_nullable(),
                     )
@@ -208,6 +297,13 @@ impl Into<Schema> for DFSchema {
     }
 }
 
+impl Into<Schema> for &DFSchema {
+    /// Convert a schema into a DFSchema
+    fn into(self) -> Schema {
+        Schema::new(self.fields.iter().map(|f| f.field.clone()).collect())
+    }
+}
+
 /// Create a `DFSchema` from an Arrow schema
 impl TryFrom<Schema> for DFSchema {
     type Error = DataFusionError;
@@ -340,7 +436,7 @@ impl DFField {
         self.field.is_nullable()
     }
 
-    /// Returns a reference to the `DFField`'s qualified name
+    /// Returns a string to the `DFField`'s qualified name
     pub fn qualified_name(&self) -> String {
         if let Some(relation_name) = &self.qualifier {
             format!("{}.{}", relation_name, self.field.name())
@@ -349,10 +445,23 @@ impl DFField {
         }
     }
 
+    /// Builds a qualified column based on self
+    pub fn qualified_column(&self) -> Column {
+        Column {
+            relation: self.qualifier.clone(),
+            name: self.field.name().to_string(),
+        }
+    }
+
     /// Get the optional qualifier
     pub fn qualifier(&self) -> Option<&String> {
         self.qualifier.as_ref()
     }
+
+    /// Get the arrow field
+    pub fn field(&self) -> &Field {
+        &self.field
+    }
 }
 
 #[cfg(test)]
@@ -385,25 +494,25 @@ mod tests {
 
     #[test]
     fn from_qualified_schema() -> Result<()> {
-        let schema = DFSchema::try_from_qualified("t1", &test_schema_1())?;
+        let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
         assert_eq!("t1.c0, t1.c1", schema.to_string());
         Ok(())
     }
 
     #[test]
     fn from_qualified_schema_into_arrow_schema() -> Result<()> {
-        let schema = DFSchema::try_from_qualified("t1", &test_schema_1())?;
+        let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
         let arrow_schema: Schema = schema.into();
-        let expected = "Field { name: \"t1.c0\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }, \
-        Field { name: \"t1.c1\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }";
+        let expected = "Field { name: \"c0\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }, \
+        Field { name: \"c1\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }";
         assert_eq!(expected, arrow_schema.to_string());
         Ok(())
     }
 
     #[test]
     fn join_qualified() -> Result<()> {
-        let left = DFSchema::try_from_qualified("t1", &test_schema_1())?;
-        let right = DFSchema::try_from_qualified("t2", &test_schema_1())?;
+        let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
+        let right = DFSchema::try_from_qualified_schema("t2", &test_schema_1())?;
         let join = left.join(&right)?;
         assert_eq!("t1.c0, t1.c1, t2.c0, t2.c1", join.to_string());
         // test valid access
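Aside: `strip_qualifiers` and `replace_qualifier` above exist for UNION, whose output schema drops qualifiers or rewrites them to the union alias (see `union_with_alias` in builder.rs). A sketch, assuming `DFSchema` derives `Clone` as the builder code implies:

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::logical_plan::DFSchema;

fn main() -> Result<()> {
    let t1 = DFSchema::try_from_qualified_schema(
        "t1",
        &Schema::new(vec![Field::new("c0", DataType::Boolean, true)]),
    )?;
    // an aliased union rewrites every qualifier to the alias ...
    assert_eq!("u.c0", t1.clone().replace_qualifier("u").to_string());
    // ... and an unaliased one drops qualifiers entirely
    assert_eq!("c0", t1.strip_qualifiers().to_string());
    Ok(())
}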
@@ -418,8 +527,8 @@ mod tests {
 
     #[test]
     fn join_qualified_duplicate() -> Result<()> {
-        let left = DFSchema::try_from_qualified("t1", &test_schema_1())?;
-        let right = DFSchema::try_from_qualified("t1", &test_schema_1())?;
+        let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
+        let right = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
         let join = left.join(&right);
         assert!(join.is_err());
         assert_eq!(
@@ -446,7 +555,7 @@ mod tests {
 
     #[test]
     fn join_mixed() -> Result<()> {
-        let left = DFSchema::try_from_qualified("t1", &test_schema_1())?;
+        let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
         let right = DFSchema::try_from(test_schema_2())?;
         let join = left.join(&right)?;
         assert_eq!("t1.c0, t1.c1, c100, c101", join.to_string());
@@ -464,7 +573,7 @@ mod tests {
 
     #[test]
     fn join_mixed_duplicate() -> Result<()> {
-        let left = DFSchema::try_from_qualified("t1", &test_schema_1())?;
+        let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
         let right = DFSchema::try_from(test_schema_1())?;
         let join = left.join(&right);
         assert!(join.is_err());
diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index 58dba16f02ef..1c5cc770c94f 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -20,7 +20,7 @@
 pub use super::Operator;
 use crate::error::{DataFusionError, Result};
-use crate::logical_plan::{window_frames, DFField, DFSchema};
+use crate::logical_plan::{window_frames, DFField, DFSchema, DFSchemaRef};
 use crate::physical_plan::{
     aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF,
     window_functions,
@@ -33,6 +33,90 @@ use std::collections::HashSet;
 use std::fmt;
 use std::sync::Arc;
 
+/// A named reference to a qualified field in a schema.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct Column {
+    /// relation/table name.
+    pub relation: Option<String>,
+    /// field/column name.
+    pub name: String,
+}
+
+impl Column {
+    /// Create Column from unqualified name.
+    pub fn from_name(name: String) -> Self {
+        Self {
+            relation: None,
+            name,
+        }
+    }
+
+    /// Deserialize a fully qualified name string into a column
+    pub fn from_qualified_name(flat_name: &str) -> Self {
+        use sqlparser::tokenizer::Token;
+
+        let dialect = sqlparser::dialect::GenericDialect {};
+        let mut tokenizer = sqlparser::tokenizer::Tokenizer::new(&dialect, flat_name);
+        if let Ok(tokens) = tokenizer.tokenize() {
+            if let [Token::Word(relation), Token::Period, Token::Word(name)] =
+                tokens.as_slice()
+            {
+                return Column {
+                    relation: Some(relation.value.clone()),
+                    name: name.value.clone(),
+                };
+            }
+        }
+        // any expression that's not in the form of `foo.bar` will be treated as unqualified column
+        // name
+        Column {
+            relation: None,
+            name: String::from(flat_name),
+        }
+    }
+
+    /// Serialize column into a flat name string
+    pub fn flat_name(&self) -> String {
+        match &self.relation {
+            Some(r) => format!("{}.{}", r, self.name),
+            None => self.name.clone(),
+        }
+    }
+
+    /// Normalize Column with qualifier based on provided dataframe schemas.
+    pub fn normalize(self, schemas: &[&DFSchemaRef]) -> Result<Self> {
+        if self.relation.is_some() {
+            return Ok(self);
+        }
+
+        for schema in schemas {
+            if let Ok(field) = schema.field_with_unqualified_name(&self.name) {
+                return Ok(field.qualified_column());
+            }
+        }
+
+        Err(DataFusionError::Plan(format!(
+            "Column {} not found in provided schemas",
+            self
+        )))
+    }
+}
+
+impl From<&str> for Column {
+    fn from(c: &str) -> Self {
+        Self::from_qualified_name(c)
+    }
+}
+
+impl fmt::Display for Column {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match &self.relation {
+            Some(r) => write!(f, "#{}.{}", r, self.name),
+            None => write!(f, "#{}", self.name),
+        }
+    }
+}
+
 /// `Expr` is a central struct of DataFusion's query API, and
 /// represent logical expressions such as `A + 1`, or `CAST(c1 AS
 /// int)`.
@@ -47,7 +131,7 @@ use std::sync::Arc;
 /// ```
 /// # use datafusion::logical_plan::*;
 /// let expr = col("c1");
-/// assert_eq!(expr, Expr::Column("c1".to_string()));
+/// assert_eq!(expr, Expr::Column(Column::from_name("c1".to_string())));
 /// ```
 ///
 /// ## Create the expression `c1 + c2` to add columns "c1" and "c2" together
@@ -81,8 +165,8 @@ use std::sync::Arc;
 pub enum Expr {
     /// An expression with a specific name.
     Alias(Box<Expr>, String),
-    /// A named reference to a field in a schema.
-    Column(String),
+    /// A named reference to a qualified filed in a schema.
+    Column(Column),
     /// A named reference to a variable in a registry.
     ScalarVariable(Vec<String>),
     /// A constant value.
@@ -232,10 +316,9 @@ impl Expr {
     pub fn get_type(&self, schema: &DFSchema) -> Result<DataType> {
         match self {
             Expr::Alias(expr, _) => expr.get_type(schema),
-            Expr::Column(name) => Ok(schema
-                .field_with_unqualified_name(name)?
-                .data_type()
-                .clone()),
+            Expr::Column(c) => {
+                Ok(schema.field_from_qualified_column(c)?.data_type().clone())
+            }
             Expr::ScalarVariable(_) => Ok(DataType::Utf8),
             Expr::Literal(l) => Ok(l.get_datatype()),
             Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema),
@@ -307,9 +390,9 @@ impl Expr {
     pub fn nullable(&self, input_schema: &DFSchema) -> Result<bool> {
         match self {
             Expr::Alias(expr, _) => expr.nullable(input_schema),
-            Expr::Column(name) => Ok(input_schema
-                .field_with_unqualified_name(name)?
-                .is_nullable()),
+            Expr::Column(c) => {
+                Ok(input_schema.field_from_qualified_column(c)?.is_nullable())
+            }
             Expr::Literal(value) => Ok(value.is_null()),
             Expr::ScalarVariable(_) => Ok(true),
             Expr::Case {
@@ -355,7 +438,7 @@ impl Expr {
         }
     }
 
-    /// Returns the name of this expression based on [arrow::datatypes::Schema].
+    /// Returns the name of this expression based on [crate::logical_plan::DFSchema].
     ///
     /// This represents how a column with this expression is named when no alias is chosen
     pub fn name(&self, input_schema: &DFSchema) -> Result<String> {
@@ -364,12 +447,20 @@ impl Expr {
 
     /// Returns a [arrow::datatypes::Field] compatible with this expression.
     pub fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> {
-        Ok(DFField::new(
-            None, //TODO qualifier
-            &self.name(input_schema)?,
-            self.get_type(input_schema)?,
-            self.nullable(input_schema)?,
-        ))
+        match self {
+            Expr::Column(c) => Ok(DFField::new(
+                c.relation.as_deref(),
+                &c.name,
+                self.get_type(input_schema)?,
+                self.nullable(input_schema)?,
+            )),
+            _ => Ok(DFField::new(
+                None,
+                &self.name(input_schema)?,
+                self.get_type(input_schema)?,
+                self.nullable(input_schema)?,
+            )),
+        }
     }
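Aside: a round trip through the `Column` parsing and display code above. A sketch; everything here follows directly from `from_qualified_name`, `flat_name`, and the `Display` impl:

use datafusion::logical_plan::Column;

fn main() {
    // "t1.c0" tokenizes into qualifier + name ...
    let c = Column::from_qualified_name("t1.c0");
    assert_eq!(Some("t1".to_string()), c.relation);
    assert_eq!("c0", c.name);
    assert_eq!("t1.c0", c.flat_name());
    assert_eq!("#t1.c0", format!("{}", c));

    // ... anything not of the form `foo.bar` stays an unqualified name
    let bare = Column::from_qualified_name("c0");
    assert_eq!(None, bare.relation);
    assert_eq!("#c0", format!("{}", bare));
}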
     /// Wraps this expression in a cast to a target [arrow::datatypes::DataType].
@@ -540,7 +631,7 @@ impl Expr {
         // recurse (and cover all expression types)
         let visitor = match self {
             Expr::Alias(expr, _) => expr.accept(visitor),
-            Expr::Column(..) => Ok(visitor),
+            Expr::Column(_) => Ok(visitor),
             Expr::ScalarVariable(..) => Ok(visitor),
             Expr::Literal(..) => Ok(visitor),
             Expr::BinaryExpr { left, right, .. } => {
@@ -668,7 +759,7 @@ impl Expr {
         // recurse into all sub expressions(and cover all expression types)
         let expr = match self {
             Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name),
-            Expr::Column(name) => Expr::Column(name),
+            Expr::Column(_) => self.clone(),
             Expr::ScalarVariable(names) => Expr::ScalarVariable(names),
             Expr::Literal(value) => Expr::Literal(value),
             Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr {
@@ -985,9 +1076,72 @@ pub fn or(left: Expr, right: Expr) -> Expr {
     }
 }
 
-/// Create a column expression based on a column name
-pub fn col(name: &str) -> Expr {
-    Expr::Column(name.to_owned())
+/// Create a column expression based on a qualified or unqualified column name
+pub fn col(ident: &str) -> Expr {
+    Expr::Column(ident.into())
+}
+
+/// Convert an expression into Column expression if it's already provided as input plan.
+///
+/// For example, it rewrites:
+///
+/// ```ignore
+/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
+/// .project(vec![col("c1"), sum(col("c2"))?
+/// ```
+///
+/// Into:
+///
+/// ```ignore
+/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
+/// .project(vec![col("c1"), col("SUM(#c2)")?
+/// ```
+pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr {
+    match e {
+        Expr::Column(_) => e,
+        Expr::Alias(inner_expr, name) => {
+            Expr::Alias(Box::new(columnize_expr(*inner_expr, input_schema)), name)
+        }
+        _ => match e.name(input_schema) {
+            Ok(name) => match input_schema.field_with_unqualified_name(&name) {
+                Ok(field) => Expr::Column(field.qualified_column()),
+                // expression not provided as input, do not convert to a column reference
+                Err(_) => e,
+            },
+            Err(_) => e,
+        },
+    }
+}
+
+/// Recursively normalize all Column expressions in a given expression tree
+pub fn normalize_col(e: Expr, schemas: &[&DFSchemaRef]) -> Result<Expr> {
+    struct ColumnNormalizer<'a, 'b> {
+        schemas: &'a [&'b DFSchemaRef],
+    }
+
+    impl<'a, 'b> ExprRewriter for ColumnNormalizer<'a, 'b> {
+        fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+            if let Expr::Column(c) = expr {
+                Ok(Expr::Column(c.normalize(self.schemas)?))
+            } else {
+                Ok(expr)
+            }
+        }
+    }
+
+    e.rewrite(&mut ColumnNormalizer { schemas })
+}
+
+/// Recursively normalize all Column expressions in a list of expression trees
+#[inline]
+pub fn normalize_cols(
+    exprs: impl IntoIterator<Item = Expr>,
+    schemas: &[&DFSchemaRef],
+) -> Result<Vec<Expr>> {
+    exprs
+        .into_iter()
+        .map(|e| normalize_col(e, schemas))
+        .collect()
 }
 
 /// Create an expression to represent the min() aggregate function
@@ -1240,7 +1394,7 @@ impl fmt::Debug for Expr {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         match self {
             Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias),
-            Expr::Column(name) => write!(f, "#{}", name),
+            Expr::Column(c) => write!(f, "{}", c),
             Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")),
             Expr::Literal(v) => write!(f, "{:?}", v),
             Expr::Case {
@@ -1373,7 +1527,7 @@ fn create_function_name(
 fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
     match e {
         Expr::Alias(_, name) => Ok(name.clone()),
-        Expr::Column(name) => Ok(name.clone()),
+        Expr::Column(c) => Ok(c.flat_name()),
         Expr::ScalarVariable(variable_names) =>
Ok(variable_names.join(".")), Expr::Literal(value) => Ok(format!("{:?}", value)), Expr::BinaryExpr { left, op, right } => { @@ -1524,8 +1678,8 @@ mod tests { #[test] fn filter_is_null_and_is_not_null() { - let col_null = Expr::Column("col1".to_string()); - let col_not_null = Expr::Column("col2".to_string()); + let col_null = col("col1"); + let col_not_null = col("col2"); assert_eq!(format!("{:?}", col_null.is_null()), "#col1 IS NULL"); assert_eq!( format!("{:?}", col_not_null.is_not_null()), diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 4a39e114d53f..69d03d22bb21 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -30,22 +30,26 @@ mod operators; mod plan; mod registry; pub mod window_frames; -pub use builder::LogicalPlanBuilder; +pub use builder::{ + build_join_schema, union_with_alias, LogicalPlanBuilder, UNNAMED_TABLE, +}; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ abs, acos, and, array, ascii, asin, atan, avg, binary_expr, bit_length, btrim, case, - ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count, - count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, - initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, now, - octet_length, or, random, regexp_match, regexp_replace, repeat, replace, reverse, - right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, - sqrt, starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, - when, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, + ceil, character_length, chr, col, columnize_expr, combine_filters, concat, concat_ws, + cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, + in_list, initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, + min, normalize_col, normalize_cols, now, octet_length, or, random, regexp_match, + regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, + sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, + to_hex, translate, trim, trunc, upper, when, Column, Expr, ExprRewriter, + ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; pub use plan::{ - JoinType, LogicalPlan, Partitioning, PlanType, PlanVisitor, StringifiedPlan, + JoinConstraint, JoinType, LogicalPlan, Partitioning, PlanType, PlanVisitor, + StringifiedPlan, }; pub use registry::FunctionRegistry; diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index a80bc54b4a2f..256247228213 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -17,18 +17,14 @@ //! This module contains the `LogicalPlan` enum that describes queries //! via a logical query plan. 
-use super::expr::Expr;
+use super::display::{GraphvizVisitor, IndentVisitor};
+use super::expr::{Column, Expr};
 use super::extension::UserDefinedLogicalNode;
-use super::{
-    col,
-    display::{GraphvizVisitor, IndentVisitor},
-};
 use crate::datasource::TableProvider;
 use crate::logical_plan::dfschema::DFSchemaRef;
 use crate::sql::parser::FileType;
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use std::{
-    cmp::min,
     fmt::{self, Display},
     sync::Arc,
 };
@@ -50,6 +46,15 @@ pub enum JoinType {
     Anti,
 }
 
+/// Join constraint
+#[derive(Debug, Clone, Copy)]
+pub enum JoinConstraint {
+    /// Join ON
+    On,
+    /// Join USING
+    Using,
+}
+
 /// A LogicalPlan represents the different types of relational
 /// operators (such as Projection, Filter, etc) and can be created by
 /// the SQL query planner and the DataFrame API.
@@ -125,9 +130,11 @@ pub enum LogicalPlan {
         /// Right input
         right: Arc<LogicalPlan>,
         /// Equijoin clause expressed as pairs of (left, right) join columns
-        on: Vec<(String, String)>,
+        on: Vec<(Column, Column)>,
         /// Join type
         join_type: JoinType,
+        /// Join constraint
+        join_constraint: JoinConstraint,
         /// The output schema, containing fields from the left and right inputs
         schema: DFSchemaRef,
     },
@@ -312,9 +319,10 @@ impl LogicalPlan {
                 aggr_expr,
                 ..
             } => group_expr.iter().chain(aggr_expr.iter()).cloned().collect(),
-            LogicalPlan::Join { on, .. } => {
-                on.iter().flat_map(|(l, r)| vec![col(l), col(r)]).collect()
-            }
+            LogicalPlan::Join { on, .. } => on
+                .iter()
+                .flat_map(|(l, r)| vec![Expr::Column(l.clone()), Expr::Column(r.clone())])
+                .collect(),
             LogicalPlan::Sort { expr, .. } => expr.clone(),
             LogicalPlan::Extension { node } => node.expressions(),
             // plans without expressions
@@ -479,9 +487,9 @@ impl LogicalPlan {
     /// per node. For example:
     ///
     /// ```text
-    /// Projection: #id
-    ///    Filter: #state Eq Utf8(\"CO\")\
-    ///       CsvScan: employee.csv projection=Some([0, 3])
+    /// Projection: #employee.id
+    ///    Filter: #employee.state Eq Utf8(\"CO\")\
+    ///       CsvScan: employee projection=Some([0, 3])
     /// ```
     ///
     /// ```
@@ -490,15 +498,15 @@ impl LogicalPlan {
     /// let schema = Schema::new(vec![
     ///     Field::new("id", DataType::Int32, false),
     /// ]);
-    /// let plan = LogicalPlanBuilder::scan_empty("foo.csv", &schema, None).unwrap()
+    /// let plan = LogicalPlanBuilder::scan_empty(Some("foo_csv"), &schema, None).unwrap()
    ///     .filter(col("id").eq(lit(5))).unwrap()
     ///     .build().unwrap();
     ///
     /// // Format using display_indent
     /// let display_string = format!("{}", plan.display_indent());
     ///
-    /// assert_eq!("Filter: #id Eq Int32(5)\
-    ///             \n  TableScan: foo.csv projection=None",
+    /// assert_eq!("Filter: #foo_csv.id Eq Int32(5)\
+    ///             \n  TableScan: foo_csv projection=None",
    ///             display_string);
     /// ```
     pub fn display_indent(&self) -> impl fmt::Display + '_ {
@@ -520,9 +528,9 @@ impl LogicalPlan {
     /// per node that includes the output schema.
For example: /// /// ```text - /// Projection: #id [id:Int32]\ - /// Filter: #state Eq Utf8(\"CO\") [id:Int32, state:Utf8]\ - /// TableScan: employee.csv projection=Some([0, 3]) [id:Int32, state:Utf8]"; + /// Projection: #employee.id [id:Int32]\ + /// Filter: #employee.state Eq Utf8(\"CO\") [id:Int32, state:Utf8]\ + /// TableScan: employee projection=Some([0, 3]) [id:Int32, state:Utf8]"; /// ``` /// /// ``` @@ -531,15 +539,15 @@ impl LogicalPlan { /// let schema = Schema::new(vec![ /// Field::new("id", DataType::Int32, false), /// ]); - /// let plan = LogicalPlanBuilder::scan_empty("foo.csv", &schema, None).unwrap() + /// let plan = LogicalPlanBuilder::scan_empty(Some("foo_csv"), &schema, None).unwrap() /// .filter(col("id").eq(lit(5))).unwrap() /// .build().unwrap(); /// /// // Format using display_indent_schema /// let display_string = format!("{}", plan.display_indent_schema()); /// - /// assert_eq!("Filter: #id Eq Int32(5) [id:Int32]\ - /// \n TableScan: foo.csv projection=None [id:Int32]", + /// assert_eq!("Filter: #foo_csv.id Eq Int32(5) [id:Int32]\ + /// \n TableScan: foo_csv projection=None [id:Int32]", /// display_string); /// ``` pub fn display_indent_schema(&self) -> impl fmt::Display + '_ { @@ -571,7 +579,7 @@ impl LogicalPlan { /// let schema = Schema::new(vec![ /// Field::new("id", DataType::Int32, false), /// ]); - /// let plan = LogicalPlanBuilder::scan_empty("foo.csv", &schema, None).unwrap() + /// let plan = LogicalPlanBuilder::scan_empty(Some("foo.csv"), &schema, None).unwrap() /// .filter(col("id").eq(lit(5))).unwrap() /// .build().unwrap(); /// @@ -630,7 +638,7 @@ impl LogicalPlan { /// let schema = Schema::new(vec![ /// Field::new("id", DataType::Int32, false), /// ]); - /// let plan = LogicalPlanBuilder::scan_empty("foo.csv", &schema, None).unwrap() + /// let plan = LogicalPlanBuilder::scan_empty(Some("foo.csv"), &schema, None).unwrap() /// .build().unwrap(); /// /// // Format using display @@ -653,11 +661,10 @@ impl LogicalPlan { ref limit, .. 
} => { - let sep = " ".repeat(min(1, table_name.len())); write!( f, - "TableScan: {}{}projection={:?}", - table_name, sep, projection + "TableScan: {} projection={:?}", + table_name, projection )?; if !filters.is_empty() { @@ -826,7 +833,7 @@ mod tests { fn display_plan() -> LogicalPlan { LogicalPlanBuilder::scan_empty( - "employee.csv", + Some("employee_csv"), &employee_schema(), Some(vec![0, 3]), ) @@ -843,9 +850,9 @@ mod tests { fn test_display_indent() { let plan = display_plan(); - let expected = "Projection: #id\ - \n Filter: #state Eq Utf8(\"CO\")\ - \n TableScan: employee.csv projection=Some([0, 3])"; + let expected = "Projection: #employee_csv.id\ + \n Filter: #employee_csv.state Eq Utf8(\"CO\")\ + \n TableScan: employee_csv projection=Some([0, 3])"; assert_eq!(expected, format!("{}", plan.display_indent())); } @@ -854,9 +861,9 @@ mod tests { fn test_display_indent_schema() { let plan = display_plan(); - let expected = "Projection: #id [id:Int32]\ - \n Filter: #state Eq Utf8(\"CO\") [id:Int32, state:Utf8]\ - \n TableScan: employee.csv projection=Some([0, 3]) [id:Int32, state:Utf8]"; + let expected = "Projection: #employee_csv.id [id:Int32]\ + \n Filter: #employee_csv.state Eq Utf8(\"CO\") [id:Int32, state:Utf8]\ + \n TableScan: employee_csv projection=Some([0, 3]) [id:Int32, state:Utf8]"; assert_eq!(expected, format!("{}", plan.display_indent_schema())); } @@ -878,12 +885,12 @@ mod tests { ); assert!( graphviz.contains( - r#"[shape=box label="TableScan: employee.csv projection=Some([0, 3])"]"# + r#"[shape=box label="TableScan: employee_csv projection=Some([0, 3])"]"# ), "\n{}", plan.display_graphviz() ); - assert!(graphviz.contains(r#"[shape=box label="TableScan: employee.csv projection=Some([0, 3])\nSchema: [id:Int32, state:Utf8]"]"#), + assert!(graphviz.contains(r#"[shape=box label="TableScan: employee_csv projection=Some([0, 3])\nSchema: [id:Int32, state:Utf8]"]"#), "\n{}", plan.display_graphviz()); assert!( graphviz.contains(r#"// End DataFusion GraphViz Plan"#), @@ -1128,9 +1135,12 @@ mod tests { } fn test_plan() -> LogicalPlan { - let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("state", DataType::Utf8, false), + ]); - LogicalPlanBuilder::scan_empty("", &schema, Some(vec![0])) + LogicalPlanBuilder::scan_empty(None, &schema, Some(vec![0, 1])) .unwrap() .filter(col("state").eq(lit("CO"))) .unwrap() diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index d2ac5ce2f383..956f74adc28f 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -293,7 +293,7 @@ mod tests { Field::new("c", DataType::Boolean, false), Field::new("d", DataType::UInt32, false), ]); - LogicalPlanBuilder::scan_empty("test", &schema, None)?.build() + LogicalPlanBuilder::scan_empty(Some("test"), &schema, None)?.build() } fn expr_test_schema() -> DFSchemaRef { @@ -551,9 +551,9 @@ mod tests { .build()?; let expected = "\ - Projection: #a\ - \n Filter: NOT #c\ - \n Filter: #b\ + Projection: #test.a\ + \n Filter: NOT #test.c\ + \n Filter: #test.b\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -571,10 +571,10 @@ mod tests { .build()?; let expected = "\ - Projection: #a\ + Projection: #test.a\ \n Limit: 1\ - \n Filter: #c\ - \n Filter: NOT #b\ + \n Filter: #test.c\ + \n Filter: NOT #test.b\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, 
 expected);
@@ -590,8 +590,8 @@ mod tests {
             .build()?;
 
         let expected = "\
-            Projection: #a\
-            \n  Filter: NOT #b And #c\
+            Projection: #test.a\
+            \n  Filter: NOT #test.b And #test.c\
             \n    TableScan: test projection=None";
 
         assert_optimized_plan_eq(&plan, expected);
@@ -607,8 +607,8 @@ mod tests {
             .build()?;
 
         let expected = "\
-            Projection: #a\
-            \n  Filter: NOT #b Or NOT #c\
+            Projection: #test.a\
+            \n  Filter: NOT #test.b Or NOT #test.c\
             \n    TableScan: test projection=None";
 
         assert_optimized_plan_eq(&plan, expected);
@@ -624,8 +624,8 @@ mod tests {
             .build()?;
 
         let expected = "\
-            Projection: #a\
-            \n  Filter: #b\
+            Projection: #test.a\
+            \n  Filter: #test.b\
             \n    TableScan: test projection=None";
 
         assert_optimized_plan_eq(&plan, expected);
@@ -640,7 +640,7 @@ mod tests {
             .build()?;
 
         let expected = "\
-            Projection: #a, #d, NOT #b\
+            Projection: #test.a, #test.d, NOT #test.b\
             \n  TableScan: test projection=None";
 
         assert_optimized_plan_eq(&plan, expected);
@@ -659,8 +659,8 @@ mod tests {
             .build()?;
 
         let expected = "\
-            Aggregate: groupBy=[[#a, #c]], aggr=[[MAX(#b), MIN(#b)]]\
-            \n  Projection: #a, #c, #b\
+            Aggregate: groupBy=[[#test.a, #test.c]], aggr=[[MAX(#test.b), MIN(#test.b)]]\
+            \n  Projection: #test.a, #test.c, #test.b\
             \n    TableScan: test projection=None";
 
         assert_optimized_plan_eq(&plan, expected);
diff --git a/datafusion/src/optimizer/eliminate_limit.rs b/datafusion/src/optimizer/eliminate_limit.rs
index 1b965f1d02e4..4b5a634889a7 100644
--- a/datafusion/src/optimizer/eliminate_limit.rs
+++ b/datafusion/src/optimizer/eliminate_limit.rs
@@ -122,7 +122,7 @@ mod tests {
         // Left side is removed
         let expected = "Union\
             \n  EmptyRelation\
-            \n  Aggregate: groupBy=[[#a]], aggr=[[SUM(#b)]]\
+            \n  Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b)]]\
             \n    TableScan: test projection=None";
         assert_optimized_plan_eq(&plan, expected);
     }
diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs
index dc4d5e993a38..e5f8dcfbfffd 100644
--- a/datafusion/src/optimizer/filter_push_down.rs
+++ b/datafusion/src/optimizer/filter_push_down.rs
@@ -16,7 +16,7 @@
 
 use crate::datasource::datasource::TableProviderFilterPushDown;
 use crate::execution::context::ExecutionProps;
-use crate::logical_plan::{and, LogicalPlan};
+use crate::logical_plan::{and, Column, LogicalPlan};
 use crate::logical_plan::{DFSchema, Expr};
 use crate::optimizer::optimizer::OptimizerRule;
 use crate::optimizer::utils;
@@ -56,15 +56,15 @@ pub struct FilterPushDown {}
 #[derive(Debug, Clone, Default)]
 struct State {
     // (predicate, columns on the predicate)
-    filters: Vec<(Expr, HashSet<String>)>,
+    filters: Vec<(Expr, HashSet<Column>)>,
 }
 
-type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet<String>>);
+type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet<Column>>);
 
 /// returns all predicates in `state` that depend on any of `used_columns`
 fn get_predicates<'a>(
     state: &'a State,
-    used_columns: &HashSet<String>,
+    used_columns: &HashSet<Column>,
 ) -> Predicates<'a> {
     state
         .filters
@@ -89,19 +89,19 @@ fn get_join_predicates<'a>(
     left: &DFSchema,
     right: &DFSchema,
 ) -> (
-    Vec<&'a HashSet<String>>,
-    Vec<&'a HashSet<String>>,
+    Vec<&'a HashSet<Column>>,
+    Vec<&'a HashSet<Column>>,
     Predicates<'a>,
 ) {
     let left_columns = &left
         .fields()
         .iter()
-        .map(|f| f.name().clone())
+        .map(|f| f.qualified_column())
         .collect::<HashSet<_>>();
     let right_columns = &right
         .fields()
         .iter()
-        .map(|f| f.name().clone())
+        .map(|f| f.qualified_column())
         .collect::<HashSet<_>>();
 
     let filters = state
@@ -173,9 +173,9 @@ fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan {
 
 // remove all filters from `filters` that are in `predicate_columns`
 fn remove_filters(
-    filters: &[(Expr, HashSet<String>)],
-    predicate_columns: &[&HashSet<String>],
-) -> Vec<(Expr, HashSet<String>)> {
+    filters: &[(Expr, HashSet<Column>)],
+    predicate_columns: &[&HashSet<Column>],
+) -> Vec<(Expr, HashSet<Column>)> {
     filters
         .iter()
         .filter(|(_, columns)| !predicate_columns.contains(&columns))
@@ -185,9 +185,9 @@ fn remove_filters(
 
 // keeps all filters from `filters` that are in `predicate_columns`
 fn keep_filters(
-    filters: &[(Expr, HashSet<String>)],
-    predicate_columns: &[&HashSet<String>],
-) -> Vec<(Expr, HashSet<String>)> {
+    filters: &[(Expr, HashSet<Column>)],
+    predicate_columns: &[&HashSet<Column>],
+) -> Vec<(Expr, HashSet<Column>)> {
     filters
         .iter()
         .filter(|(_, columns)| predicate_columns.contains(&columns))
@@ -199,7 +199,7 @@ fn keep_filters(
 /// in `state` depend on the columns `used_columns`.
 fn issue_filters(
     mut state: State,
-    used_columns: HashSet<String>,
+    used_columns: HashSet<Column>,
     plan: &LogicalPlan,
 ) -> Result<LogicalPlan> {
     let (predicates, predicate_columns) = get_predicates(&state, &used_columns);
@@ -248,8 +248,8 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
             predicates
                 .into_iter()
                 .try_for_each::<_, Result<()>>(|predicate| {
-                    let mut columns: HashSet<String> = HashSet::new();
-                    utils::expr_to_column_names(predicate, &mut columns)?;
+                    let mut columns: HashSet<Column> = HashSet::new();
+                    utils::expr_to_columns(predicate, &mut columns)?;
                     if columns.is_empty() {
                         no_col_predicates.push(predicate)
                     } else {
@@ -282,7 +282,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
                     expr => expr.clone(),
                 };
 
-                projection.insert(field.name().clone(), expr);
+                projection.insert(field.qualified_name(), expr);
             });
 
             // re-write all filters based on this projection
@@ -291,7 +291,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
                 *predicate = rewrite(predicate, &projection)?;
 
                 columns.clear();
-                utils::expr_to_column_names(predicate, columns)?;
+                utils::expr_to_columns(predicate, columns)?;
             }
 
             // optimize inner
@@ -308,11 +308,11 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
 
             // construct set of columns that `aggr_expr` depends on
             let mut used_columns = HashSet::new();
-            utils::exprlist_to_column_names(aggr_expr, &mut used_columns)?;
+            utils::exprlist_to_columns(aggr_expr, &mut used_columns)?;
 
             let agg_columns = aggr_expr
                 .iter()
-                .map(|x| x.name(input.schema()))
+                .map(|x| Ok(Column::from_name(x.name(input.schema())?)))
                 .collect::<Result<Vec<_>>>()?;
             used_columns.extend(agg_columns);
 
@@ -332,7 +332,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
                 .schema()
                 .fields()
                 .iter()
-                .map(|f| f.name().clone())
+                .map(|f| f.qualified_column())
                 .collect::<HashSet<_>>();
             issue_filters(state, used_columns, plan)
         }
@@ -415,7 +415,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
                 .schema()
                 .fields()
                 .iter()
-                .map(|f| f.name().clone())
+                .map(|f| f.qualified_column())
                 .collect::<HashSet<_>>();
             issue_filters(state, used_columns, plan)
        }
@@ -448,8 +448,8 @@ fn rewrite(expr: &Expr, projection: &HashMap<String, Expr>) -> Result<Expr> {
         .map(|e| rewrite(e, projection))
         .collect::<Result<Vec<_>>>()?;
 
-    if let Expr::Column(name) = expr {
-        if let Some(expr) = projection.get(name) {
+    if let Expr::Column(c) = expr {
+        if let Some(expr) = projection.get(&c.flat_name()) {
             return Ok(expr.clone());
         }
     }
@@ -489,8 +489,8 @@ mod tests {
             .build()?;
         // filter is before projection
         let expected = "\
-            Projection: #a, #b\
-            \n  Filter: #a Eq Int64(1)\
+            Projection: #test.a, #test.b\
+            \n  Filter: #test.a Eq Int64(1)\
             \n    TableScan: test projection=None";
         assert_optimized_plan_eq(&plan, expected);
         Ok(())
     }
@@ -506,9 +506,9 @@ mod tests {
             .build()?;
        // filter is
before single projection let expected = "\ - Filter: #a Eq Int64(1)\ + Filter: #test.a Eq Int64(1)\ \n Limit: 10\ - \n Projection: #a, #b\ + \n Projection: #test.a, #test.b\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -537,9 +537,9 @@ mod tests { .build()?; // filter is before double projection let expected = "\ - Projection: #c, #b\ - \n Projection: #a, #b, #c\ - \n Filter: #a Eq Int64(1)\ + Projection: #test.c, #test.b\ + \n Projection: #test.a, #test.b, #test.c\ + \n Filter: #test.a Eq Int64(1)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -554,8 +554,8 @@ mod tests { .build()?; // filter of key aggregation is commutative let expected = "\ - Aggregate: groupBy=[[#a]], aggr=[[SUM(#b) AS total_salary]]\ - \n Filter: #a Gt Int64(10)\ + Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b) AS total_salary]]\ + \n Filter: #test.a Gt Int64(10)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -571,7 +571,7 @@ mod tests { // filter of aggregate is after aggregation since they are non-commutative let expected = "\ Filter: #b Gt Int64(10)\ - \n Aggregate: groupBy=[[#a]], aggr=[[SUM(#b) AS b]]\ + \n Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b) AS b]]\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -587,8 +587,8 @@ mod tests { .build()?; // filter is before projection let expected = "\ - Projection: #a AS b, #c\ - \n Filter: #a Eq Int64(1)\ + Projection: #test.a AS b, #test.c\ + \n Filter: #test.a Eq Int64(1)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -627,14 +627,14 @@ mod tests { format!("{:?}", plan), "\ Filter: #b Eq Int64(1)\ - \n Projection: #a Multiply Int32(2) Plus #c AS b, #c\ + \n Projection: #test.a Multiply Int32(2) Plus #test.c AS b, #test.c\ \n TableScan: test projection=None" ); // filter is before projection let expected = "\ - Projection: #a Multiply Int32(2) Plus #c AS b, #c\ - \n Filter: #a Multiply Int32(2) Plus #c Eq Int64(1)\ + Projection: #test.a Multiply Int32(2) Plus #test.c AS b, #test.c\ + \n Filter: #test.a Multiply Int32(2) Plus #test.c Eq Int64(1)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -659,16 +659,16 @@ mod tests { format!("{:?}", plan), "\ Filter: #a Eq Int64(1)\ - \n Projection: #b Multiply Int32(3) AS a, #c\ - \n Projection: #a Multiply Int32(2) Plus #c AS b, #c\ + \n Projection: #b Multiply Int32(3) AS a, #test.c\ + \n Projection: #test.a Multiply Int32(2) Plus #test.c AS b, #test.c\ \n TableScan: test projection=None" ); // filter is before the projections let expected = "\ - Projection: #b Multiply Int32(3) AS a, #c\ - \n Projection: #a Multiply Int32(2) Plus #c AS b, #c\ - \n Filter: #a Multiply Int32(2) Plus #c Multiply Int32(3) Eq Int64(1)\ + Projection: #b Multiply Int32(3) AS a, #test.c\ + \n Projection: #test.a Multiply Int32(2) Plus #test.c AS b, #test.c\ + \n Filter: #test.a Multiply Int32(2) Plus #test.c Multiply Int32(3) Eq Int64(1)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -684,26 +684,26 @@ mod tests { .project(vec![col("a").alias("b"), col("c")])? .aggregate(vec![col("b")], vec![sum(col("c"))])? .filter(col("b").gt(lit(10i64)))? - .filter(col("SUM(c)").gt(lit(10i64)))? + .filter(col("SUM(test.c)").gt(lit(10i64)))? 
.build()?; // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), "\ - Filter: #SUM(c) Gt Int64(10)\ + Filter: #SUM(test.c) Gt Int64(10)\ \n Filter: #b Gt Int64(10)\ - \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\ - \n Projection: #a AS b, #c\ + \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\ + \n Projection: #test.a AS b, #test.c\ \n TableScan: test projection=None" ); // filter is before the projections let expected = "\ - Filter: #SUM(c) Gt Int64(10)\ - \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\ - \n Projection: #a AS b, #c\ - \n Filter: #a Gt Int64(10)\ + Filter: #SUM(test.c) Gt Int64(10)\ + \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\ + \n Projection: #test.a AS b, #test.c\ + \n Filter: #test.a Gt Int64(10)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -720,8 +720,8 @@ mod tests { .project(vec![col("a").alias("b"), col("c")])? .aggregate(vec![col("b")], vec![sum(col("c"))])? .filter(and( - col("SUM(c)").gt(lit(10i64)), - and(col("b").gt(lit(10i64)), col("SUM(c)").lt(lit(20i64))), + col("SUM(test.c)").gt(lit(10i64)), + and(col("b").gt(lit(10i64)), col("SUM(test.c)").lt(lit(20i64))), ))? .build()?; // not part of the test assert_eq!( format!("{:?}", plan), "\ - Filter: #SUM(c) Gt Int64(10) And #b Gt Int64(10) And #SUM(c) Lt Int64(20)\ - \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\ - \n Projection: #a AS b, #c\ + Filter: #SUM(test.c) Gt Int64(10) And #b Gt Int64(10) And #SUM(test.c) Lt Int64(20)\ + \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\ + \n Projection: #test.a AS b, #test.c\ \n TableScan: test projection=None" ); // filter is before the projections let expected = "\ - Filter: #SUM(c) Gt Int64(10) And #SUM(c) Lt Int64(20)\ - \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\ - \n Projection: #a AS b, #c\ - \n Filter: #a Gt Int64(10)\ + Filter: #SUM(test.c) Gt Int64(10) And #SUM(test.c) Lt Int64(20)\ + \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\ + \n Projection: #test.a AS b, #test.c\ + \n Filter: #test.a Gt Int64(10)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -760,11 +760,11 @@ mod tests { .build()?; // filter is not pushed down past any of the limits let expected = "\ - Projection: #a, #b\ - \n Filter: #a Eq Int64(1)\ + Projection: #test.a, #test.b\ + \n Filter: #test.a Eq Int64(1)\ \n Limit: 10\ \n Limit: 20\ - \n Projection: #a, #b\ + \n Projection: #test.a, #test.b\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -804,20 +804,20 @@ mod tests { // not part of the test assert_eq!( format!("{:?}", plan), - "Filter: #a GtEq Int64(1)\ - \n Projection: #a\ \n Limit: 1\ - \n Filter: #a LtEq Int64(1)\ - \n Projection: #a\ + "Filter: #test.a GtEq Int64(1)\ + \n Projection: #test.a\ \n Limit: 1\ + \n Filter: #test.a LtEq Int64(1)\ + \n Projection: #test.a\ \n TableScan: test projection=None" ); let expected = "\ - Projection: #a\ - \n Filter: #a GtEq Int64(1)\ + Projection: #test.a\ + \n Filter: #test.a GtEq Int64(1)\ \n Limit: 1\ - \n Projection: #a\ - \n Filter: #a LtEq Int64(1)\ + \n Projection: #test.a\ + \n Filter: #test.a LtEq Int64(1)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -838,16 +838,16 @@ mod tests { // not part of the test assert_eq!( format!("{:?}", plan), - "Projection: #a\ - \n Filter: #a GtEq Int64(1)\ - \n Filter: #a LtEq Int64(1)\ + "Projection: #test.a\ + \n Filter: #test.a GtEq Int64(1)\ + \n Filter: #test.a LtEq Int64(1)\ \n Limit: 1\ \n TableScan: test
projection=None" ); let expected = "\ - Projection: #a\ - \n Filter: #a GtEq Int64(1) And #a LtEq Int64(1)\ + Projection: #test.a\ + \n Filter: #test.a GtEq Int64(1) And #test.a LtEq Int64(1)\ \n Limit: 1\ \n TableScan: test projection=None"; @@ -868,7 +868,7 @@ mod tests { let expected = "\ TestUserDefined\ - \n Filter: #a LtEq Int64(1)\ + \n Filter: #test.a LtEq Int64(1)\ \n TableScan: test projection=None"; // not part of the test @@ -887,7 +887,12 @@ mod tests { .project(vec![col("a")])? .build()?; let plan = LogicalPlanBuilder::from(&left) - .join(&right, JoinType::Inner, &["a"], &["a"])? + .join( + &right, + JoinType::Inner, + vec![Column::from_name("a".to_string())], + vec![Column::from_name("a".to_string())], + )? .filter(col("a").lt_eq(lit(1i64)))? .build()?; @@ -895,20 +900,20 @@ mod tests { assert_eq!( format!("{:?}", plan), "\ - Filter: #a LtEq Int64(1)\ - \n Join: a = a\ + Filter: #test.a LtEq Int64(1)\ + \n Join: #test.a = #test.a\ \n TableScan: test projection=None\ - \n Projection: #a\ + \n Projection: #test.a\ \n TableScan: test projection=None" ); // filter sent to side before the join let expected = "\ - Join: a = a\ - \n Filter: #a LtEq Int64(1)\ + Join: #test.a = #test.a\ + \n Filter: #test.a LtEq Int64(1)\ \n TableScan: test projection=None\ - \n Projection: #a\ - \n Filter: #a LtEq Int64(1)\ + \n Projection: #test.a\ + \n Filter: #test.a LtEq Int64(1)\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -925,7 +930,12 @@ mod tests { .project(vec![col("a"), col("b")])? .build()?; let plan = LogicalPlanBuilder::from(&left) - .join(&right, JoinType::Inner, &["a"], &["a"])? + .join( + &right, + JoinType::Inner, + vec![Column::from_name("a".to_string())], + vec![Column::from_name("a".to_string())], + )? // "b" and "c" are not shared by either side: they are only available together after the join .filter(col("c").lt_eq(col("b")))? .build()?; @@ -934,11 +944,11 @@ mod tests { assert_eq!( format!("{:?}", plan), "\ - Filter: #c LtEq #b\ - \n Join: a = a\ - \n Projection: #a, #c\ + Filter: #test.c LtEq #test.b\ + \n Join: #test.a = #test.a\ + \n Projection: #test.a, #test.c\ \n TableScan: test projection=None\ - \n Projection: #a, #b\ + \n Projection: #test.a, #test.b\ \n TableScan: test projection=None" ); @@ -959,7 +969,12 @@ mod tests { .project(vec![col("a"), col("c")])? .build()?; let plan = LogicalPlanBuilder::from(&left) - .join(&right, JoinType::Inner, &["a"], &["a"])? + .join( + &right, + JoinType::Inner, + vec![Column::from_name("a".to_string())], + vec![Column::from_name("a".to_string())], + )? .filter(col("b").lt_eq(lit(1i64)))? 
.build()?; @@ -967,20 +982,20 @@ mod tests { assert_eq!( format!("{:?}", plan), "\ - Filter: #b LtEq Int64(1)\ - \n Join: a = a\ - \n Projection: #a, #b\ + Filter: #test.b LtEq Int64(1)\ + \n Join: #test.a = #test.a\ + \n Projection: #test.a, #test.b\ \n TableScan: test projection=None\ - \n Projection: #a, #c\ + \n Projection: #test.a, #test.c\ \n TableScan: test projection=None" ); let expected = "\ - Join: a = a\ - \n Projection: #a, #b\ - \n Filter: #b LtEq Int64(1)\ + Join: #test.a = #test.a\ + \n Projection: #test.a, #test.b\ + \n Filter: #test.b LtEq Int64(1)\ \n TableScan: test projection=None\ - \n Projection: #a, #c\ + \n Projection: #test.a, #test.c\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -1030,14 +1045,15 @@ mod tests { fn table_scan_with_pushdown_provider( filter_support: TableProviderFilterPushDown, ) -> Result<LogicalPlan> { + use std::convert::TryFrom; + let test_provider = PushDownProvider { filter_support }; let table_scan = LogicalPlan::TableScan { - table_name: "".into(), + table_name: "test".to_string(), filters: vec![], - projected_schema: Arc::new(DFSchema::try_from_qualified( - "", - &*test_provider.schema(), + projected_schema: Arc::new(DFSchema::try_from( + (*test_provider.schema()).clone(), )?), projection: None, source: Arc::new(test_provider), @@ -1054,7 +1070,7 @@ mod tests { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?; let expected = "\ - TableScan: projection=None, filters=[#a Eq Int64(1)]"; + TableScan: test projection=None, filters=[#a Eq Int64(1)]"; assert_optimized_plan_eq(&plan, expected); Ok(()) } @@ -1066,7 +1082,7 @@ mod tests { let expected = "\ Filter: #a Eq Int64(1)\ - \n TableScan: projection=None, filters=[#a Eq Int64(1)]"; + \n TableScan: test projection=None, filters=[#a Eq Int64(1)]"; assert_optimized_plan_eq(&plan, expected); Ok(()) } @@ -1080,7 +1096,7 @@ mod tests { let expected = "\ Filter: #a Eq Int64(1)\ - \n TableScan: projection=None, filters=[#a Eq Int64(1)]"; + \n TableScan: test projection=None, filters=[#a Eq Int64(1)]"; // Optimizing the same plan multiple times should produce the same plan // each time.
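// [Editor's sketch, not part of the patch.] The join tests above switch the
// builder from string keys (&["a"]) to qualified Column keys. Assuming the
// LogicalPlanBuilder and Column APIs used throughout this diff, a join now
// reads:
//
//     let plan = LogicalPlanBuilder::from(&left)
//         .join(
//             &right,
//             JoinType::Inner,
//             vec![Column::from_name("a".to_string())], // left join keys
//             vec![Column::from_name("a".to_string())], // right join keys
//         )?
//         .build()?;
//
// Column::from_name builds an unqualified column; the planner resolves it
// against the input schemas, which is why the expected plans now print the
// fully qualified form (e.g. #test.a).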
@@ -1095,7 +1111,7 @@ mod tests { let expected = "\ Filter: #a Eq Int64(1)\ - \n TableScan: projection=None"; + \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) } diff --git a/datafusion/src/optimizer/hash_build_probe_order.rs b/datafusion/src/optimizer/hash_build_probe_order.rs index 74d2b0090194..a2a99ae364a7 100644 --- a/datafusion/src/optimizer/hash_build_probe_order.rs +++ b/datafusion/src/optimizer/hash_build_probe_order.rs @@ -22,7 +22,7 @@ use std::sync::Arc; -use crate::logical_plan::LogicalPlan; +use crate::logical_plan::{Expr, LogicalPlan, LogicalPlanBuilder}; use crate::optimizer::optimizer::OptimizerRule; use crate::{error::Result, prelude::JoinType}; @@ -131,6 +131,7 @@ impl OptimizerRule for HashBuildProbeOrder { right, on, join_type, + join_constraint, schema, } => { let left = self.optimize(left, execution_props)?; @@ -140,11 +141,9 @@ impl OptimizerRule for HashBuildProbeOrder { Ok(LogicalPlan::Join { left: Arc::new(right), right: Arc::new(left), - on: on - .iter() - .map(|(l, r)| (r.to_string(), l.to_string())) - .collect(), + on: on.iter().map(|(l, r)| (r.clone(), l.clone())).collect(), join_type: swap_join_type(*join_type), + join_constraint: *join_constraint, schema: schema.clone(), }) } else { @@ -154,6 +153,7 @@ impl OptimizerRule for HashBuildProbeOrder { right: Arc::new(right), on: on.clone(), join_type: *join_type, + join_constraint: *join_constraint, schema: schema.clone(), }) } @@ -166,12 +166,19 @@ impl OptimizerRule for HashBuildProbeOrder { let left = self.optimize(left, execution_props)?; let right = self.optimize(right, execution_props)?; if should_swap_join_order(&left, &right) { - // Swap left and right - Ok(LogicalPlan::CrossJoin { - left: Arc::new(right), - right: Arc::new(left), - schema: schema.clone(), - }) + let swapped = LogicalPlanBuilder::from(&right).cross_join(&left)?; + // wrap plan with projection to maintain column order + let left_cols = left + .schema() + .fields() + .iter() + .map(|f| Expr::Column(f.qualified_column())); + let right_cols = right + .schema() + .fields() + .iter() + .map(|f| Expr::Column(f.qualified_column())); + swapped.project(left_cols.chain(right_cols))?.build() } else { // Keep join as is Ok(LogicalPlan::CrossJoin { diff --git a/datafusion/src/optimizer/limit_push_down.rs b/datafusion/src/optimizer/limit_push_down.rs index e616869d7c4a..afd993710a5f 100644 --- a/datafusion/src/optimizer/limit_push_down.rs +++ b/datafusion/src/optimizer/limit_push_down.rs @@ -163,7 +163,7 @@ mod test { // Should push the limit down to table provider // When it has a select let expected = "Limit: 1000\ - \n Projection: #a\ + \n Projection: #test.a\ \n TableScan: test projection=None, limit=1000"; assert_optimized_plan_eq(&plan, expected); @@ -202,7 +202,7 @@ mod test { // Limit should *not* push down aggregate node let expected = "Limit: 1000\ - \n Aggregate: groupBy=[[#a]], aggr=[[MAX(#b)]]\ + \n Aggregate: groupBy=[[#test.a]], aggr=[[MAX(#test.b)]]\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); @@ -244,7 +244,7 @@ mod test { // Limit should use deeper LIMIT 1000, but Limit 10 shouldn't push down aggregation let expected = "Limit: 10\ - \n Aggregate: groupBy=[[#a]], aggr=[[MAX(#b)]]\ + \n Aggregate: groupBy=[[#test.a]], aggr=[[MAX(#test.b)]]\ \n Limit: 1000\ \n TableScan: test projection=None, limit=1000"; diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index ad795f5f5dd5..2544d89d0492 100644 --- 
a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -20,11 +20,14 @@ use crate::error::Result; use crate::execution::context::ExecutionProps; -use crate::logical_plan::{DFField, DFSchema, DFSchemaRef, LogicalPlan, ToDFSchema}; +use crate::logical_plan::{ + build_join_schema, Column, DFField, DFSchema, DFSchemaRef, LogicalPlan, + LogicalPlanBuilder, ToDFSchema, +}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use crate::sql::utils::find_sort_exprs; -use arrow::datatypes::Schema; +use arrow::datatypes::{Field, Schema}; use arrow::error::Result as ArrowResult; use std::{collections::HashSet, sync::Arc}; use utils::optimize_explain; @@ -44,8 +47,8 @@ impl OptimizerRule for ProjectionPushDown { .schema() .fields() .iter() - .map(|f| f.name().clone()) - .collect::<HashSet<String>>(); + .map(|f| f.qualified_column()) + .collect::<HashSet<Column>>(); optimize_plan(self, plan, &required_columns, false, execution_props) } @@ -62,8 +65,9 @@ impl ProjectionPushDown { } fn get_projected_schema( + table_name: Option<&String>, schema: &Schema, - required_columns: &HashSet<String>, + required_columns: &HashSet<Column>, has_projection: bool, ) -> Result<(Vec<usize>, DFSchemaRef)> { // once we reach the table scan, we can use the accumulated set of column @@ -73,7 +77,8 @@ fn get_projected_schema( // e.g. when the column derives from an aggregation let mut projection: Vec<usize> = required_columns .iter() - .map(|name| schema.index_of(name)) + .filter(|c| c.relation.as_ref() == table_name) + .map(|c| schema.index_of(&c.name)) .filter_map(ArrowResult::ok) .collect(); @@ -98,8 +103,20 @@ fn get_projected_schema( // create the projected schema let mut projected_fields: Vec<DFField> = Vec::with_capacity(projection.len()); - for i in &projection { - projected_fields.push(DFField::from(schema.fields()[*i].clone())); + match table_name { + Some(qualifier) => { + for i in &projection { + projected_fields.push(DFField::from_qualified( + qualifier, + schema.fields()[*i].clone(), + )); + } + } + None => { + for i in &projection { + projected_fields.push(DFField::from(schema.fields()[*i].clone())); + } + } } Ok((projection, projected_fields.to_dfschema_ref()?)) @@ -109,7 +126,7 @@ fn get_projected_schema( fn optimize_plan( optimizer: &ProjectionPushDown, plan: &LogicalPlan, - required_columns: &HashSet<String>, // set of columns required up to this step + required_columns: &HashSet<Column>, // set of columns required up to this step has_projection: bool, execution_props: &ExecutionProps, ) -> Result<LogicalPlan> { @@ -133,12 +150,12 @@ fn optimize_plan( .iter() .enumerate() .try_for_each(|(i, field)| { - if required_columns.contains(field.name()) { + if required_columns.contains(&field.qualified_column()) { new_expr.push(expr[i].clone()); new_fields.push(field.clone()); // gather the new set of required columns - utils::expr_to_column_names(&expr[i], &mut new_required_columns) + utils::expr_to_columns(&expr[i], &mut new_required_columns) } else { Ok(()) } @@ -167,31 +184,45 @@ fn optimize_plan( right, on, join_type, - schema, + join_constraint, + ..
} => { for (l, r) in on { - new_required_columns.insert(l.to_owned()); - new_required_columns.insert(r.to_owned()); + new_required_columns.insert(l.clone()); + new_required_columns.insert(r.clone()); } - Ok(LogicalPlan::Join { - left: Arc::new(optimize_plan( - optimizer, - left, - &new_required_columns, - true, - execution_props, - )?), - right: Arc::new(optimize_plan( - optimizer, - right, - &new_required_columns, - true, - execution_props, - )?), + let optimized_left = Arc::new(optimize_plan( + optimizer, + left, + &new_required_columns, + true, + execution_props, + )?); + + let optimized_right = Arc::new(optimize_plan( + optimizer, + right, + &new_required_columns, + true, + execution_props, + )?); + + let schema = build_join_schema( + optimized_left.schema(), + optimized_right.schema(), + on, + join_type, + join_constraint, + )?; + + Ok(LogicalPlan::Join { + left: optimized_left, + right: optimized_right, join_type: *join_type, + join_constraint: *join_constraint, on: on.clone(), - schema: schema.clone(), + schema: DFSchemaRef::new(schema), }) } LogicalPlan::Window { @@ -205,11 +236,12 @@ fn optimize_plan( { window_expr.iter().try_for_each(|expr| { let name = &expr.name(schema)?; - if required_columns.contains(name) { + let column = Column::from_name(name.to_string()); + if required_columns.contains(&column) { new_window_expr.push(expr.clone()); - new_required_columns.insert(name.clone()); + new_required_columns.insert(column); // add to the new set of required columns - utils::expr_to_column_names(expr, &mut new_required_columns) + utils::expr_to_columns(expr, &mut new_required_columns) } else { Ok(()) } @@ -217,31 +249,20 @@ fn optimize_plan( } // for all the retained window expr, find their sort expressions if any, and retain these - utils::exprlist_to_column_names( + utils::exprlist_to_columns( &find_sort_exprs(&new_window_expr), &mut new_required_columns, )?; - let new_schema = DFSchema::new( - schema - .fields() - .iter() - .filter(|x| new_required_columns.contains(x.name())) - .cloned() - .collect(), - )?; - - Ok(LogicalPlan::Window { - window_expr: new_window_expr, - input: Arc::new(optimize_plan( - optimizer, - input, - &new_required_columns, - true, - execution_props, - )?), - schema: DFSchemaRef::new(new_schema), - }) + LogicalPlanBuilder::from(&optimize_plan( + optimizer, + input, + &new_required_columns, + true, + execution_props, + )?) + .window(new_window_expr)? 
+ .build() } LogicalPlan::Aggregate { schema, input, group_expr, aggr_expr, .. } => { // aggregate: // * remove any aggregate expression that is not required // * construct the new set of required columns - utils::exprlist_to_column_names(group_expr, &mut new_required_columns)?; + utils::exprlist_to_columns(group_expr, &mut new_required_columns)?; // Gather all columns needed for expressions in this Aggregate let mut new_aggr_expr = Vec::new(); aggr_expr.iter().try_for_each(|expr| { let name = &expr.name(schema)?; + let column = Column::from_name(name.to_string()); - if required_columns.contains(name) { + if required_columns.contains(&column) { new_aggr_expr.push(expr.clone()); - new_required_columns.insert(name.clone()); + new_required_columns.insert(column); // add to the new set of required columns - utils::expr_to_column_names(expr, &mut new_required_columns) + utils::expr_to_columns(expr, &mut new_required_columns) } else { Ok(()) } @@ -276,7 +298,7 @@ fn optimize_plan( schema .fields() .iter() - .filter(|x| new_required_columns.contains(x.name())) + .filter(|x| new_required_columns.contains(&x.qualified_column())) .cloned() .collect(), )?; @@ -303,12 +325,15 @@ fn optimize_plan( limit, .. } => { - let (projection, projected_schema) = - get_projected_schema(&source.schema(), required_columns, has_projection)?; - + let (projection, projected_schema) = get_projected_schema( + Some(table_name), + &source.schema(), + required_columns, + has_projection, + )?; // return the table scan with projection Ok(LogicalPlan::TableScan { - table_name: table_name.to_string(), + table_name: table_name.clone(), source: source.clone(), projection: Some(projection), projected_schema, @@ -332,6 +357,48 @@ fn optimize_plan( execution_props, ) } + LogicalPlan::Union { + inputs, + schema, + alias, + } => { + // UNION inputs will reference the same column with different identifiers, so we need + // to populate new_required_columns by unqualified column name based on required fields + // from the resulting UNION output + let union_required_fields = schema + .fields() + .iter() + .filter(|f| new_required_columns.contains(&f.qualified_column())) + .map(|f| f.field()) + .collect::<HashSet<&Field>>(); + + let new_inputs = inputs + .iter() + .map(|input_plan| { + input_plan + .schema() + .fields() + .iter() + .filter(|f| union_required_fields.contains(f.field())) + .for_each(|f| { + new_required_columns.insert(f.qualified_column()); + }); + optimize_plan( + optimizer, + input_plan, + &new_required_columns, + has_projection, + execution_props, + ) + }) + .collect::<Result<Vec<_>>>()?; + + Ok(LogicalPlan::Union { + inputs: new_inputs, + schema: schema.clone(), + alias: alias.clone(), + }) + } // all other nodes: Add any additional columns used by // expressions in this node to the list of required columns LogicalPlan::Limit { .. } | LogicalPlan::Filter { .. } | LogicalPlan::Repartition { .. } | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Sort { .. } | LogicalPlan::CreateExternalTable { .. } - | LogicalPlan::Union { .. } | LogicalPlan::CrossJoin { .. } | LogicalPlan::Extension { ..
} => { let expr = plan.expressions(); // collect all required columns by this plan - utils::exprlist_to_column_names(&expr, &mut new_required_columns)?; + utils::exprlist_to_columns(&expr, &mut new_required_columns)?; // apply the optimization to all inputs of the plan let inputs = plan.inputs(); let new_inputs = inputs .iter() - .map(|plan| { + .map(|input_plan| { optimize_plan( optimizer, - plan, + input_plan, &new_required_columns, has_projection, execution_props, @@ -371,8 +437,7 @@ fn optimize_plan( mod tests { use super::*; - use crate::logical_plan::{col, lit}; - use crate::logical_plan::{max, min, Expr, LogicalPlanBuilder}; + use crate::logical_plan::{col, lit, max, min, Expr, JoinType, LogicalPlanBuilder}; use crate::test::*; use arrow::datatypes::DataType; @@ -384,7 +449,7 @@ mod tests { .aggregate(vec![], vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#test.b)]]\ \n TableScan: test projection=Some([1])"; assert_optimized_plan_eq(&plan, expected); @@ -400,7 +465,7 @@ mod tests { .aggregate(vec![col("c")], vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[#c]], aggr=[[MAX(#b)]]\ + let expected = "Aggregate: groupBy=[[#test.c]], aggr=[[MAX(#test.b)]]\ \n TableScan: test projection=Some([1, 2])"; assert_optimized_plan_eq(&plan, expected); @@ -417,8 +482,8 @@ mod tests { .aggregate(vec![], vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\ - \n Filter: #c\ + let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#test.b)]]\ + \n Filter: #test.c\ \n TableScan: test projection=Some([1, 2])"; assert_optimized_plan_eq(&plan, expected); @@ -426,6 +491,43 @@ mod tests { Ok(()) } + #[test] + fn join_schema_trim() -> Result<()> { + let table_scan = test_table_scan()?; + + let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]); + let table2_scan = + LogicalPlanBuilder::scan_empty(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(&table_scan) + .join(&table2_scan, JoinType::Left, vec!["a"], vec!["c1"])? + .project(vec![col("a"), col("b"), col("c1")])? + .build()?; + + // make sure projections are pushed down to table scan + let expected = "Projection: #test.a, #test.b, #test2.c1\ + \n Join: #test.a = #test2.c1\ + \n TableScan: test projection=Some([0, 1])\ + \n TableScan: test2 projection=Some([0])"; + + let optimized_plan = optimize(&plan)?; + let formatted_plan = format!("{:?}", optimized_plan); + assert_eq!(formatted_plan, expected); + + // make sure schema for join node doesn't include c1 column + let optimized_join = optimized_plan.inputs()[0]; + assert_eq!( + **optimized_join.schema(), + DFSchema::new(vec![ + DFField::new(Some("test"), "a", DataType::UInt32, false), + DFField::new(Some("test"), "b", DataType::UInt32, false), + DFField::new(Some("test2"), "c1", DataType::UInt32, false), + ])?, + ); + + Ok(()) + } + #[test] fn cast() -> Result<()> { let table_scan = test_table_scan()?; @@ -437,7 +539,7 @@ mod tests { }])? 
.build()?; - let expected = "Projection: CAST(#c AS Float64)\ + let expected = "Projection: CAST(#test.c AS Float64)\ \n TableScan: test projection=Some([2])"; assert_optimized_plan_eq(&projection, expected); @@ -457,7 +559,7 @@ mod tests { assert_fields_eq(&plan, vec!["a", "b"]); - let expected = "Projection: #a, #b\ + let expected = "Projection: #test.a, #test.b\ \n TableScan: test projection=Some([0, 1])"; assert_optimized_plan_eq(&plan, expected); @@ -479,7 +581,7 @@ mod tests { assert_fields_eq(&plan, vec!["c", "a"]); let expected = "Limit: 5\ - \n Projection: #c, #a\ + \n Projection: #test.c, #test.a\ \n TableScan: test projection=Some([0, 2])"; assert_optimized_plan_eq(&plan, expected); @@ -523,12 +625,12 @@ mod tests { .aggregate(vec![col("c")], vec![max(col("a"))])? .build()?; - assert_fields_eq(&plan, vec!["c", "MAX(a)"]); + assert_fields_eq(&plan, vec!["c", "MAX(test.a)"]); let expected = "\ - Aggregate: groupBy=[[#c]], aggr=[[MAX(#a)]]\ - \n Filter: #c Gt Int32(1)\ - \n Projection: #c, #a\ + Aggregate: groupBy=[[#test.c]], aggr=[[MAX(#test.a)]]\ + \n Filter: #test.c Gt Int32(1)\ + \n Projection: #test.c, #test.a\ \n TableScan: test projection=Some([0, 2])"; assert_optimized_plan_eq(&plan, expected); @@ -591,15 +693,15 @@ mod tests { let plan = LogicalPlanBuilder::from(&table_scan) .aggregate(vec![col("a"), col("c")], vec![max(col("b")), min(col("b"))])? .filter(col("c").gt(lit(1)))? - .project(vec![col("c"), col("a"), col("MAX(b)")])? + .project(vec![col("c"), col("a"), col("MAX(test.b)")])? .build()?; - assert_fields_eq(&plan, vec!["c", "a", "MAX(b)"]); + assert_fields_eq(&plan, vec!["c", "a", "MAX(test.b)"]); let expected = "\ - Projection: #c, #a, #MAX(b)\ - \n Filter: #c Gt Int32(1)\ - \n Aggregate: groupBy=[[#a, #c]], aggr=[[MAX(#b)]]\ + Projection: #test.c, #test.a, #MAX(test.b)\ + \n Filter: #test.c Gt Int32(1)\ + \n Aggregate: groupBy=[[#test.a, #test.c]], aggr=[[MAX(#test.b)]]\ \n TableScan: test projection=Some([0, 1, 2])"; assert_optimized_plan_eq(&plan, expected); diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index 9ad7a94d8bfe..4253d2fd4f00 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -510,8 +510,8 @@ mod tests { assert_optimized_plan_eq( &plan, "\ - Filter: #b Gt Int32(1)\ - \n Projection: #a\ + Filter: #test.b Gt Int32(1)\ + \n Projection: #test.a\ \n TableScan: test projection=None", ); Ok(()) @@ -532,8 +532,8 @@ mod tests { assert_optimized_plan_eq( &plan, "\ - Filter: #a Gt Int32(5) And #b Lt Int32(6)\ - \n Projection: #a\ + Filter: #test.a Gt Int32(5) And #test.b Lt Int32(6)\ + \n Projection: #test.a\ \n TableScan: test projection=None", ); Ok(()) diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 014ec74a0bfe..76f44b84657c 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -24,8 +24,8 @@ use arrow::datatypes::Schema; use super::optimizer::OptimizerRule; use crate::execution::context::ExecutionProps; use crate::logical_plan::{ - Expr, LogicalPlan, Operator, Partitioning, PlanType, Recursion, StringifiedPlan, - ToDFSchema, + build_join_schema, Column, DFSchemaRef, Expr, LogicalPlan, LogicalPlanBuilder, + Operator, Partitioning, PlanType, Recursion, StringifiedPlan, ToDFSchema, }; use crate::prelude::lit; use crate::scalar::ScalarValue; @@ -39,14 +39,11 @@ const CASE_ELSE_MARKER: &str = "__DATAFUSION_CASE_ELSE__"; const WINDOW_PARTITION_MARKER: 
&str = "__DATAFUSION_WINDOW_PARTITION__"; const WINDOW_SORT_MARKER: &str = "__DATAFUSION_WINDOW_SORT__"; -/// Recursively walk a list of expression trees, collecting the unique set of column -/// names referenced in the expression -pub fn exprlist_to_column_names( - expr: &[Expr], - accum: &mut HashSet<String>, -) -> Result<()> { +/// Recursively walk a list of expression trees, collecting the unique set of columns +/// referenced in the expression +pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result<()> { for e in expr { - expr_to_column_names(e, accum)?; + expr_to_columns(e, accum)?; } Ok(()) } @@ -54,17 +51,17 @@ pub fn exprlist_to_column_names( /// Recursively walk an expression tree, collecting the unique set of column names /// referenced in the expression struct ColumnNameVisitor<'a> { - accum: &'a mut HashSet<String>, + accum: &'a mut HashSet<Column>, } impl ExpressionVisitor for ColumnNameVisitor<'_> { fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> { match expr { - Expr::Column(name) => { - self.accum.insert(name.clone()); + Expr::Column(qc) => { + self.accum.insert(qc.clone()); } Expr::ScalarVariable(var_names) => { - self.accum.insert(var_names.join(".")); + self.accum.insert(Column::from_name(var_names.join("."))); } Expr::Alias(_, _) => {} Expr::Literal(_) => {} @@ -90,9 +87,9 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> { } } -/// Recursively walk an expression tree, collecting the unique set of column names +/// Recursively walk an expression tree, collecting the unique set of columns /// referenced in the expression -pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet<String>) -> Result<()> { +pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> { expr.accept(ColumnNameVisitor { accum })?; Ok(()) }
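// [Editor's sketch, not part of the patch.] Callers of the renamed helper now
// collect Column values rather than name strings; this mirrors the
// test_collect_expr test further down:
//
//     let mut accum: HashSet<Column> = HashSet::new();
//     expr_to_columns(&col("a"), &mut accum)?;
//     assert!(accum.contains(&Column::from_name("a".to_string())));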
@@ -214,21 +211,31 @@ pub fn from_plan( }), LogicalPlan::Join { join_type, + join_constraint, on, - schema, .. - } => Ok(LogicalPlan::Join { - left: Arc::new(inputs[0].clone()), - right: Arc::new(inputs[1].clone()), - join_type: *join_type, - on: on.clone(), - schema: schema.clone(), - }), - LogicalPlan::CrossJoin { schema, .. } => Ok(LogicalPlan::CrossJoin { - left: Arc::new(inputs[0].clone()), - right: Arc::new(inputs[1].clone()), - schema: schema.clone(), - }), + } => { + let schema = build_join_schema( + inputs[0].schema(), + inputs[1].schema(), + on, + join_type, + join_constraint, + )?; + Ok(LogicalPlan::Join { + left: Arc::new(inputs[0].clone()), + right: Arc::new(inputs[1].clone()), + join_type: *join_type, + join_constraint: *join_constraint, + on: on.clone(), + schema: DFSchemaRef::new(schema), + }) + } + LogicalPlan::CrossJoin { .. } => { + let left = &inputs[0]; + let right = &inputs[1]; + LogicalPlanBuilder::from(left).cross_join(right)?.build() + } LogicalPlan::Limit { n, .. } => Ok(LogicalPlan::Limit { n: *n, input: Arc::new(inputs[0].clone()), @@ -493,15 +500,15 @@ mod tests { #[test] fn test_collect_expr() -> Result<()> { - let mut accum: HashSet<String> = HashSet::new(); - expr_to_column_names( + let mut accum: HashSet<Column> = HashSet::new(); + expr_to_columns( &Expr::Cast { expr: Box::new(col("a")), data_type: DataType::Float64, }, &mut accum, )?; - expr_to_column_names( + expr_to_columns( &Expr::Cast { expr: Box::new(col("a")), data_type: DataType::Float64, }, &mut accum, )?; assert_eq!(1, accum.len()); - assert!(accum.contains("a")); + assert!(accum.contains(&Column::from_name("a".to_string()))); Ok(()) } diff --git a/datafusion/src/physical_optimizer/pruning.rs b/datafusion/src/physical_optimizer/pruning.rs index 9e8d9fa77858..5585c4d08140 100644 --- a/datafusion/src/physical_optimizer/pruning.rs +++ b/datafusion/src/physical_optimizer/pruning.rs @@ -28,6 +28,7 @@ //! https://github.com/apache/arrow-datafusion/issues/363 it will //! be genericized. +use std::convert::TryFrom; use std::{collections::HashSet, sync::Arc}; use arrow::{ @@ -39,7 +40,7 @@ use arrow::{ use crate::{ error::{DataFusionError, Result}, execution::context::ExecutionContextState, - logical_plan::{Expr, Operator}, + logical_plan::{Column, DFSchema, Expr, Operator}, optimizer::utils, physical_plan::{planner::DefaultPhysicalPlanner, ColumnarValue, PhysicalExpr}, }; @@ -65,11 +66,11 @@ use crate::{ pub trait PruningStatistics { /// return the minimum values for the named column, if known. /// Note: the returned array must contain `num_containers()` rows - fn min_values(&self, column: &str) -> Option<ArrayRef>; + fn min_values(&self, column: &Column) -> Option<ArrayRef>; /// return the maximum values for the named column, if known. /// Note: the returned array must contain `num_containers()` rows. - fn max_values(&self, column: &str) -> Option<ArrayRef>; + fn max_values(&self, column: &Column) -> Option<ArrayRef>; /// return the number of containers (e.g. row groups) being /// pruned with these statistics @@ -120,9 +121,11 @@ impl PruningPredicate { .map(|(_, _, f)| f.clone()) .collect::<Vec<_>>(); let stat_schema = Schema::new(stat_fields); + let stat_dfschema = DFSchema::try_from(stat_schema.clone())?; let execution_context_state = ExecutionContextState::new(); let predicate_expr = DefaultPhysicalPlanner::default().create_physical_expr( &logical_predicate_expr, + &stat_dfschema, &stat_schema, &execution_context_state, )?;
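// [Editor's sketch, not part of the patch.] With the trait change above, a
// statistics provider is keyed by &Column; the TestStatistics mock in the
// tests below does exactly this. Schematically (MyStats and its fields are
// hypothetical):
//
//     impl PruningStatistics for MyStats {
//         fn min_values(&self, column: &Column) -> Option<ArrayRef> {
//             self.stats.get(column).and_then(|s| s.min()) // hypothetical lookup
//         }
//         fn max_values(&self, column: &Column) -> Option<ArrayRef> {
//             self.stats.get(column).and_then(|s| s.max()) // hypothetical lookup
//         }
//         fn num_containers(&self) -> usize {
//             self.num_row_groups // hypothetical field
//         }
//     }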
@@ -196,11 +199,11 @@ impl PruningPredicate { #[derive(Debug, Default, Clone)] struct RequiredStatColumns { /// The statistics required to evaluate this predicate: - /// * The column name in the input schema + /// * The unqualified column in the input schema /// * Statistics type (e.g. Min or Max) /// * The field the statistics value should be placed in for /// pruning predicate evaluation - columns: Vec<(String, StatisticsType, Field)>, + columns: Vec<(Column, StatisticsType, Field)>, } impl RequiredStatColumns { @@ -210,22 +213,22 @@ impl RequiredStatColumns { /// Return an iterator over items in columns (see doc on /// `self.columns` for details) - fn iter(&self) -> impl Iterator<Item = &(String, StatisticsType, Field)> { + fn iter(&self) -> impl Iterator<Item = &(Column, StatisticsType, Field)> { self.columns.iter() } fn is_stat_column_missing( &self, - column_name: &str, + column: &Column, statistics_type: StatisticsType, ) -> bool { !self .columns .iter() - .any(|(c, t, _f)| c == column_name && t == &statistics_type) + .any(|(c, t, _f)| c == column && t == &statistics_type) } - /// Rewrites column_expr so that all appearances of column_name + /// Rewrites column_expr so that all appearances of column /// are replaced with a reference to either the min or max /// statistics column, while keeping track that a reference to the statistics /// column is required /// 5` with the appropriate entry noted in self.columns fn stat_column_expr( &mut self, - column_name: &str, + column: &Column, column_expr: &Expr, field: &Field, stat_type: StatisticsType, suffix: &str, ) -> Result<Expr> { - let stat_column_name = format!("{}_{}", column_name, suffix); + let stat_column = Column { + relation: column.relation.clone(), + name: format!("{}_{}", column.flat_name(), suffix), + }; + let stat_field = Field::new( - stat_column_name.as_str(), + stat_column.flat_name().as_str(), field.data_type().clone(), field.is_nullable(), ); - if self.is_stat_column_missing(column_name, stat_type) { + + if self.is_stat_column_missing(column, stat_type) { // only add statistics column if not previously added - self.columns - .push((column_name.to_string(), stat_type, stat_field)); + self.columns.push((column.clone(), stat_type, stat_field)); } - rewrite_column_expr(column_expr, column_name, stat_column_name.as_str()) + rewrite_column_expr(column_expr, column, &stat_column) } /// rewrite col --> col_min fn min_column_expr( &mut self, - column_name: &str, + column: &Column, column_expr: &Expr, field: &Field, ) -> Result<Expr> { - self.stat_column_expr(column_name, column_expr, field, StatisticsType::Min, "min") + self.stat_column_expr(column, column_expr, field, StatisticsType::Min, "min") } /// rewrite col --> col_max fn max_column_expr( &mut self, - column_name: &str, + column: &Column, column_expr: &Expr, field: &Field, ) -> Result<Expr> { - self.stat_column_expr(column_name, column_expr, field, StatisticsType::Max, "max") + self.stat_column_expr(column, column_expr, field, StatisticsType::Max, "max") } } -impl From<Vec<(String, StatisticsType, Field)>> for RequiredStatColumns { - fn from(columns: Vec<(String, StatisticsType, Field)>) -> Self { +impl From<Vec<(Column, StatisticsType, Field)>> for RequiredStatColumns { + fn from(columns: Vec<(Column, StatisticsType, Field)>) -> Self { Self { columns } } } @@ -314,14 +321,14 @@ fn build_statistics_record_batch( let mut fields = Vec::<Field>::new(); let mut arrays = Vec::<ArrayRef>::new(); // For each needed statistics column: - for (column_name, statistics_type, stat_field) in required_columns.iter() { + for (column, statistics_type, stat_field) in required_columns.iter() { let data_type = stat_field.data_type(); let num_containers = statistics.num_containers(); let array = match statistics_type { - StatisticsType::Min => statistics.min_values(column_name), - StatisticsType::Max => statistics.max_values(column_name), + StatisticsType::Min => statistics.min_values(column), + StatisticsType::Max
=> statistics.max_values(column), }; let array = array.unwrap_or_else(|| new_null_array(data_type, num_containers)); @@ -347,7 +354,7 @@ } struct PruningExpressionBuilder<'a> { - column_name: String, + column: Column, column_expr: &'a Expr, scalar_expr: &'a Expr, field: &'a Field, @@ -363,11 +370,11 @@ impl<'a> PruningExpressionBuilder<'a> { required_columns: &'a mut RequiredStatColumns, ) -> Result<Self> { // find column name; input could be a more complicated expression - let mut left_columns = HashSet::<String>::new(); - utils::expr_to_column_names(left, &mut left_columns)?; - let mut right_columns = HashSet::<String>::new(); - utils::expr_to_column_names(right, &mut right_columns)?; - let (column_expr, scalar_expr, column_names, reverse_operator) = + let mut left_columns = HashSet::<Column>::new(); + utils::expr_to_columns(left, &mut left_columns)?; + let mut right_columns = HashSet::<Column>::new(); + utils::expr_to_columns(right, &mut right_columns)?; + let (column_expr, scalar_expr, columns, reverse_operator) = match (left_columns.len(), right_columns.len()) { (1, 0) => (left, right, left_columns, false), (0, 1) => (right, left, right_columns, true), @@ -379,8 +386,8 @@ )); } }; - let column_name = column_names.iter().next().unwrap().clone(); - let field = match schema.column_with_name(&column_name) { + let column = columns.iter().next().unwrap().clone(); + let field = match schema.column_with_name(&column.flat_name()) { Some((_, f)) => f, _ => { return Err(DataFusionError::Plan( @@ -390,7 +397,7 @@ )); } }; Ok(Self { - column_name, + column, column_expr, scalar_expr, field, @@ -418,63 +425,56 @@ } fn min_column_expr(&mut self) -> Result<Expr> { - self.required_columns.min_column_expr( - &self.column_name, - self.column_expr, - self.field, - ) + self.required_columns .min_column_expr(&self.column, self.column_expr, self.field) } fn max_column_expr(&mut self) -> Result<Expr> { - self.required_columns.max_column_expr( - &self.column_name, - self.column_expr, - self.field, - ) + self.required_columns .max_column_expr(&self.column, self.column_expr, self.field) } } /// replaces a column with an old name with a new name in an expression fn rewrite_column_expr( expr: &Expr, - column_old_name: &str, - column_new_name: &str, + column_old: &Column, + column_new: &Column, ) -> Result<Expr> { let expressions = utils::expr_sub_expressions(expr)?; let expressions = expressions .iter() - .map(|e| rewrite_column_expr(e, column_old_name, column_new_name)) + .map(|e| rewrite_column_expr(e, column_old, column_new)) .collect::<Result<Vec<_>>>()?; - if let Expr::Column(name) = expr { - if name == column_old_name { - return Ok(Expr::Column(column_new_name.to_string())); + if let Expr::Column(c) = expr { + if c == column_old { + return Ok(Expr::Column(column_new.clone())); } } utils::rewrite_expression(expr, &expressions) }
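// [Editor's sketch, not part of the patch.] rewrite_column_expr swaps one
// Column for another throughout an expression tree; the pruning code above
// uses it to turn a predicate such as #c Gt Int64(5) into #c_max Gt Int64(5):
//
//     let old = Column::from_name("c".to_string());
//     let new = Column::from_name("c_max".to_string());
//     let rewritten = rewrite_column_expr(&col("c").gt(lit(5i64)), &old, &new)?;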
-/// Given a column reference to `column_name`, returns a pruning +/// Given a column reference to `column`, returns a pruning /// expression in terms of the min and max that will evaluate to true /// if the column may contain values, and false if definitely does not /// contain values fn build_single_column_expr( - column_name: &str, + column: &Column, schema: &Schema, required_columns: &mut RequiredStatColumns, is_not: bool, // if true, treat as !col ) -> Option<Expr> { - use crate::logical_plan; - let field = schema.field_with_name(column_name).ok()?; + let field = schema.field_with_name(&column.name).ok()?; if matches!(field.data_type(), &DataType::Boolean) { - let col_ref = logical_plan::col(column_name); + let col_ref = Expr::Column(column.clone()); let min = required_columns - .min_column_expr(column_name, &col_ref, field) + .min_column_expr(column, &col_ref, field) .ok()?; let max = required_columns - .max_column_expr(column_name, &col_ref, field) + .max_column_expr(column, &col_ref, field) .ok()?; // remember -- we want an expression that is: @@ -514,15 +514,15 @@ fn build_predicate_expression( // predicate expression can only be a binary expression let (left, op, right) = match expr { Expr::BinaryExpr { left, op, right } => (left, *op, right), - Expr::Column(name) => { - let expr = build_single_column_expr(name, schema, required_columns, false) + Expr::Column(col) => { + let expr = build_single_column_expr(col, schema, required_columns, false) .unwrap_or(unhandled); return Ok(expr); } // match !col (don't do so recursively) Expr::Not(input) => { - if let Expr::Column(name) = input.as_ref() { - let expr = build_single_column_expr(name, schema, required_columns, true) + if let Expr::Column(col) = input.as_ref() { + let expr = build_single_column_expr(col, schema, required_columns, true) .unwrap_or(unhandled); return Ok(expr); } else { @@ -674,7 +674,7 @@ mod tests { #[derive(Debug, Default)] struct TestStatistics { // key: column name - stats: HashMap<String, ContainerStats>, + stats: HashMap<Column, ContainerStats>, } impl TestStatistics { @@ -687,20 +687,21 @@ name: impl Into<String>, container_stats: ContainerStats, ) -> Self { - self.stats.insert(name.into(), container_stats); + self.stats .insert(Column::from_name(name.into()), container_stats); self } } impl PruningStatistics for TestStatistics { - fn min_values(&self, column: &str) -> Option<ArrayRef> { + fn min_values(&self, column: &Column) -> Option<ArrayRef> { self.stats .get(column) .map(|container_stats| container_stats.min()) .unwrap_or(None) } - fn max_values(&self, column: &str) -> Option<ArrayRef> { + fn max_values(&self, column: &Column) -> Option<ArrayRef> { self.stats .get(column) .map(|container_stats| container_stats.max()) .unwrap_or(None) } @@ -724,11 +725,11 @@ } impl PruningStatistics for OneContainerStats { - fn min_values(&self, _column: &str) -> Option<ArrayRef> { + fn min_values(&self, _column: &Column) -> Option<ArrayRef> { self.min_values.clone() } - fn max_values(&self, _column: &str) -> Option<ArrayRef> { + fn max_values(&self, _column: &Column) -> Option<ArrayRef> { self.max_values.clone() } @@ -743,25 +744,25 @@ let required_columns = RequiredStatColumns::from(vec![ // min of original column s1, named s1_min ( - "s1".to_string(), + "s1".into(), StatisticsType::Min, Field::new("s1_min", DataType::Int32, true), ), // max of original column s2, named s2_max ( - "s2".to_string(), + "s2".into(), StatisticsType::Max, Field::new("s2_max", DataType::Int32, true), ), // max of original column s3, named s3_max ( - "s3".to_string(), + "s3".into(), StatisticsType::Max, Field::new("s3_max", DataType::Utf8, true), ), // min of original column s3, named s3_min ( - "s3".to_string(), + "s3".into(), StatisticsType::Min, Field::new("s3_min", DataType::Utf8, true), ), @@ -813,7 +814,7 @@ // Request a record batch with s1_min as a timestamp let required_columns = RequiredStatColumns::from(vec![( - "s1".to_string(), + "s3".into(), StatisticsType::Min, Field::new( "s1_min", @@ -867,7 +868,7 @@ // Request a record batch with s1_min as a timestamp let required_columns = RequiredStatColumns::from(vec![( - "s1".to_string(), + "s3".into(), StatisticsType::Min,
Field::new("s1_min", DataType::Utf8, true), )]); @@ -896,7 +897,7 @@ mod tests { fn test_build_statistics_inconsistent_length() { // return an inconsistent length to the actual statistics arrays let required_columns = RequiredStatColumns::from(vec![( - "s1".to_string(), + "s1".into(), StatisticsType::Min, Field::new("s1_min", DataType::Int64, true), )]); @@ -1143,18 +1144,18 @@ mod tests { let c1_min_field = Field::new("c1_min", DataType::Int32, false); assert_eq!( required_columns.columns[0], - ("c1".to_owned(), StatisticsType::Min, c1_min_field) + ("c1".into(), StatisticsType::Min, c1_min_field) ); // c2 = 2 should add c2_min and c2_max let c2_min_field = Field::new("c2_min", DataType::Int32, false); assert_eq!( required_columns.columns[1], - ("c2".to_owned(), StatisticsType::Min, c2_min_field) + ("c2".into(), StatisticsType::Min, c2_min_field) ); let c2_max_field = Field::new("c2_max", DataType::Int32, false); assert_eq!( required_columns.columns[2], - ("c2".to_owned(), StatisticsType::Max, c2_max_field) + ("c2".into(), StatisticsType::Max, c2_max_field) ); // c2 = 3 shouldn't add any new statistics fields assert_eq!(required_columns.columns.len(), 3); diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs index 5ed0c74463a6..a69b776e74bb 100644 --- a/datafusion/src/physical_plan/expressions/binary.rs +++ b/datafusion/src/physical_plan/expressions/binary.rs @@ -611,11 +611,12 @@ mod tests { ]); let a = Int32Array::from(vec![1, 2, 3, 4, 5]); let b = Int32Array::from(vec![1, 2, 4, 8, 16]); + + // expression: "a < b" + let lt = binary_simple(col("a", &schema)?, Operator::Lt, col("b", &schema)?); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?; - // expression: "a < b" - let lt = binary_simple(col("a"), Operator::Lt, col("b")); let result = lt.evaluate(&batch)?.into_array(batch.num_rows()); assert_eq!(result.len(), 5); @@ -639,16 +640,17 @@ mod tests { ]); let a = Int32Array::from(vec![2, 4, 6, 8, 10]); let b = Int32Array::from(vec![2, 5, 4, 8, 8]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?; // expression: "a < b OR a == b" let expr = binary_simple( - binary_simple(col("a"), Operator::Lt, col("b")), + binary_simple(col("a", &schema)?, Operator::Lt, col("b", &schema)?), Operator::Or, - binary_simple(col("a"), Operator::Eq, col("b")), + binary_simple(col("a", &schema)?, Operator::Eq, col("b", &schema)?), ); - assert_eq!("a < b OR a = b", format!("{}", expr)); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?; + + assert_eq!("a@0 < b@1 OR a@0 = b@1", format!("{}", expr)); let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); assert_eq!(result.len(), 5); @@ -680,14 +682,15 @@ mod tests { ]); let a = $A_ARRAY::from($A_VEC); let b = $B_ARRAY::from($B_VEC); + + // verify that we can construct the expression + let expression = + binary(col("a", &schema)?, $OP, col("b", &schema)?, &schema)?; let batch = RecordBatch::try_new( Arc::new(schema.clone()), vec![Arc::new(a), Arc::new(b)], )?; - // verify that we can construct the expression - let expression = binary(col("a"), $OP, col("b"), &schema)?; - // verify that the expression's type is correct assert_eq!(expression.data_type(&schema)?, $C_TYPE); @@ -863,7 +866,12 @@ mod tests { // Test 1: dict = str // verify that we can construct the expression - let expression = binary(col("dict"), Operator::Eq, col("str"), &schema)?; + let expression = binary( + 
col("dict", &schema)?, + Operator::Eq, + col("str", &schema)?, + &schema, + )?; assert_eq!(expression.data_type(&schema)?, DataType::Boolean); // evaluate and verify the result type matched @@ -877,7 +885,12 @@ mod tests { // str = dict // verify that we can construct the expression - let expression = binary(col("str"), Operator::Eq, col("dict"), &schema)?; + let expression = binary( + col("str", &schema)?, + Operator::Eq, + col("dict", &schema)?, + &schema, + )?; assert_eq!(expression.data_type(&schema)?, DataType::Boolean); // evaluate and verify the result type matched @@ -989,7 +1002,7 @@ mod tests { op: Operator, expected: PrimitiveArray, ) -> Result<()> { - let arithmetic_op = binary_simple(col("a"), op, col("b")); + let arithmetic_op = binary_simple(col("a", &schema)?, op, col("b", &schema)?); let batch = RecordBatch::try_new(schema, data)?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); @@ -1004,7 +1017,7 @@ mod tests { op: Operator, expected: BooleanArray, ) -> Result<()> { - let arithmetic_op = binary_simple(col("a"), op, col("b")); + let arithmetic_op = binary_simple(col("a", &schema)?, op, col("b", &schema)?); let data: Vec = vec![Arc::new(left), Arc::new(right)]; let batch = RecordBatch::try_new(schema, data)?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); diff --git a/datafusion/src/physical_plan/expressions/case.rs b/datafusion/src/physical_plan/expressions/case.rs index f89ea8d1e296..a46522d69deb 100644 --- a/datafusion/src/physical_plan/expressions/case.rs +++ b/datafusion/src/physical_plan/expressions/case.rs @@ -451,6 +451,7 @@ mod tests { #[test] fn case_with_expr() -> Result<()> { let batch = case_test_batch()?; + let schema = batch.schema(); // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END let when1 = lit(ScalarValue::Utf8(Some("foo".to_string()))); @@ -458,7 +459,11 @@ mod tests { let when2 = lit(ScalarValue::Utf8(Some("bar".to_string()))); let then2 = lit(ScalarValue::Int32(Some(456))); - let expr = case(Some(col("a")), &[(when1, then1), (when2, then2)], None)?; + let expr = case( + Some(col("a", &schema)?), + &[(when1, then1), (when2, then2)], + None, + )?; let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); let result = result .as_any() @@ -475,6 +480,7 @@ mod tests { #[test] fn case_with_expr_else() -> Result<()> { let batch = case_test_batch()?; + let schema = batch.schema(); // CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END let when1 = lit(ScalarValue::Utf8(Some("foo".to_string()))); @@ -484,7 +490,7 @@ mod tests { let else_value = lit(ScalarValue::Int32(Some(999))); let expr = case( - Some(col("a")), + Some(col("a", &schema)?), &[(when1, then1), (when2, then2)], Some(else_value), )?; @@ -505,17 +511,18 @@ mod tests { #[test] fn case_without_expr() -> Result<()> { let batch = case_test_batch()?; + let schema = batch.schema(); // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END let when1 = binary( - col("a"), + col("a", &schema)?, Operator::Eq, lit(ScalarValue::Utf8(Some("foo".to_string()))), &batch.schema(), )?; let then1 = lit(ScalarValue::Int32(Some(123))); let when2 = binary( - col("a"), + col("a", &schema)?, Operator::Eq, lit(ScalarValue::Utf8(Some("bar".to_string()))), &batch.schema(), @@ -539,17 +546,18 @@ mod tests { #[test] fn case_without_expr_else() -> Result<()> { let batch = case_test_batch()?; + let schema = batch.schema(); // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END let when1 = binary( - col("a"), + col("a", &schema)?, 
diff --git a/datafusion/src/physical_plan/expressions/cast.rs b/datafusion/src/physical_plan/expressions/cast.rs index 558b1e5d7e8b..bba125ebdcc9 100644 --- a/datafusion/src/physical_plan/expressions/cast.rs +++ b/datafusion/src/physical_plan/expressions/cast.rs @@ -180,10 +180,14 @@ mod tests { RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; // verify that we can construct the expression - let expression = cast_with_options(col("a"), &schema, $TYPE, $CAST_OPTIONS)?; + let expression = + cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?; // verify that its display is correct - assert_eq!(format!("CAST(a AS {:?})", $TYPE), format!("{}", expression)); + assert_eq!( + format!("CAST(a@0 AS {:?})", $TYPE), + format!("{}", expression) + ); // verify that the expression's type is correct assert_eq!(expression.data_type(&schema)?, $TYPE); @@ -272,7 +276,7 @@ mod tests { // Ensure a useful error happens at plan time if invalid casts are used let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let result = cast(col("a"), &schema, DataType::LargeBinary); + let result = cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary); result.expect_err("expected Invalid CAST"); } @@ -283,7 +287,7 @@ mod tests { let a = StringArray::from(vec!["9.1"]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; let expression = cast_with_options( - col("a"), + col("a", &schema)?, &schema, DataType::Int32, DEFAULT_DATAFUSION_CAST_OPTIONS, diff --git a/datafusion/src/physical_plan/expressions/column.rs b/datafusion/src/physical_plan/expressions/column.rs index 7e0304e51fe7..d6eafbb05384 100644 --- a/datafusion/src/physical_plan/expressions/column.rs +++ b/datafusion/src/physical_plan/expressions/column.rs @@ -28,28 +28,40 @@ use crate::error::Result; use crate::physical_plan::{ColumnarValue, PhysicalExpr}; /// Represents the column at a given index in a RecordBatch -#[derive(Debug)] +#[derive(Debug, Hash, PartialEq, Eq, Clone)] pub struct Column { name: String, + index: usize, } impl Column { /// Create a new column expression - pub fn new(name: &str) -> Self { + pub fn new(name: &str, index: usize) -> Self { Self { name: name.to_owned(), + index, } } + /// Create a new column expression based on column name and schema + pub fn new_with_schema(name: &str, schema: &Schema) -> Result<Self> { + Ok(Column::new(name, schema.index_of(name)?)) + } + /// Get the column name pub fn name(&self) -> &str { &self.name } + + /// Get the column index + pub fn index(&self) -> usize { + self.index + } } impl std::fmt::Display for Column { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", self.name) + write!(f, "{}@{}", self.name, self.index) } } @@ -61,26 +73,21 @@ impl PhysicalExpr for Column { /// Get the data type of this expression, given the schema of the input fn data_type(&self, input_schema: &Schema) -> Result<DataType> { - Ok(input_schema - .field_with_name(&self.name)?
- .data_type() - .clone()) + Ok(input_schema.field(self.index).data_type().clone()) } /// Decide whether this expression is nullable, given the schema of the input fn nullable(&self, input_schema: &Schema) -> Result { - Ok(input_schema.field_with_name(&self.name)?.is_nullable()) + Ok(input_schema.field(self.index).is_nullable()) } /// Evaluate the expression fn evaluate(&self, batch: &RecordBatch) -> Result { - Ok(ColumnarValue::Array( - batch.column(batch.schema().index_of(&self.name)?).clone(), - )) + Ok(ColumnarValue::Array(batch.column(self.index).clone())) } } /// Create a column expression -pub fn col(name: &str) -> Arc { - Arc::new(Column::new(name)) +pub fn col(name: &str, schema: &Schema) -> Result> { + Ok(Arc::new(Column::new_with_schema(name, schema)?)) } diff --git a/datafusion/src/physical_plan/expressions/in_list.rs b/datafusion/src/physical_plan/expressions/in_list.rs index 41f111006ea2..38b2b9d45b9b 100644 --- a/datafusion/src/physical_plan/expressions/in_list.rs +++ b/datafusion/src/physical_plan/expressions/in_list.rs @@ -296,8 +296,8 @@ mod tests { // applies the in_list expr to an input batch and list macro_rules! in_list { - ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr) => {{ - let expr = in_list(col("a"), $LIST, $NEGATED).unwrap(); + ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr) => {{ + let expr = in_list($COL, $LIST, $NEGATED).unwrap(); let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows()); let result = result .as_any() @@ -312,6 +312,7 @@ mod tests { fn in_list_utf8() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); let a = StringArray::from(vec![Some("a"), Some("d"), None]); + let col_a = col("a", &schema)?; let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; // expression: "a in ("a", "b")" let list = vec![ lit(ScalarValue::Utf8(Some("a".to_string()))), lit(ScalarValue::Utf8(Some("b".to_string()))), ]; - in_list!(batch, list, &false, vec![Some(true), Some(false), None]); + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + col_a.clone() + ); // expression: "a not in ("a", "b")" let list = vec![ lit(ScalarValue::Utf8(Some("a".to_string()))), lit(ScalarValue::Utf8(Some("b".to_string()))), ]; - in_list!(batch, list, &true, vec![Some(false), Some(true), None]); + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + col_a.clone() + ); // expression: "a not in ("a", "b")" let list = vec![ lit(ScalarValue::Utf8(Some("a".to_string()))), lit(ScalarValue::Utf8(Some("b".to_string()))), lit(ScalarValue::Utf8(None)), ]; - in_list!(batch, list, &false, vec![Some(true), None, None]); + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + col_a.clone() + ); // expression: "a not in ("a", "b")" let list = vec![ lit(ScalarValue::Utf8(Some("a".to_string()))), lit(ScalarValue::Utf8(Some("b".to_string()))), lit(ScalarValue::Utf8(None)), ]; - in_list!(batch, list, &true, vec![Some(false), None, None]); + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + col_a.clone() + ); Ok(()) } @@ -351,6 +376,7 @@ mod tests { fn in_list_int64() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); let a = Int64Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; // expression: "a in (0, 1)" let list = vec![ lit(ScalarValue::Int64(Some(0))), lit(ScalarValue::Int64(Some(1))),
]; - in_list!(batch, list, &false, vec![Some(true), Some(false), None]); + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + col_a.clone() + ); // expression: "a not in (0, 1)" let list = vec![ lit(ScalarValue::Int64(Some(0))), lit(ScalarValue::Int64(Some(1))), ]; - in_list!(batch, list, &true, vec![Some(false), Some(true), None]); + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + col_a.clone() + ); // expression: "a in (0, 1, NULL)" let list = vec![ @@ -373,7 +411,13 @@ mod tests { lit(ScalarValue::Int64(Some(1))), lit(ScalarValue::Utf8(None)), ]; - in_list!(batch, list, &false, vec![Some(true), None, None]); + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + col_a.clone() + ); // expression: "a not in (0, 1, NULL)" let list = vec![ @@ -381,7 +425,13 @@ mod tests { lit(ScalarValue::Int64(Some(1))), lit(ScalarValue::Utf8(None)), ]; - in_list!(batch, list, &true, vec![Some(false), None, None]); + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + col_a.clone() + ); Ok(()) } @@ -390,6 +440,7 @@ mod tests { fn in_list_float64() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); let a = Float64Array::from(vec![Some(0.0), Some(0.2), None]); + let col_a = col("a", &schema)?; let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; // expression: "a in (0.0, 0.2)" @@ -397,14 +448,26 @@ mod tests { lit(ScalarValue::Float64(Some(0.0))), lit(ScalarValue::Float64(Some(0.1))), ]; - in_list!(batch, list, &false, vec![Some(true), Some(false), None]); + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + col_a.clone() + ); // expression: "a not in (0.0, 0.2)" let list = vec![ lit(ScalarValue::Float64(Some(0.0))), lit(ScalarValue::Float64(Some(0.1))), ]; - in_list!(batch, list, &true, vec![Some(false), Some(true), None]); + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + col_a.clone() + ); // expression: "a in (0.0, 0.2, NULL)" let list = vec![ @@ -412,7 +475,13 @@ mod tests { lit(ScalarValue::Float64(Some(0.1))), lit(ScalarValue::Utf8(None)), ]; - in_list!(batch, list, &false, vec![Some(true), None, None]); + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + col_a.clone() + ); // expression: "a not in (0.0, 0.2, NULL)" let list = vec![ @@ -420,7 +489,13 @@ mod tests { lit(ScalarValue::Float64(Some(0.1))), lit(ScalarValue::Utf8(None)), ]; - in_list!(batch, list, &true, vec![Some(false), None, None]); + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + col_a.clone() + ); Ok(()) } @@ -429,29 +504,30 @@ mod tests { fn in_list_bool() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); let a = BooleanArray::from(vec![Some(true), None]); + let col_a = col("a", &schema)?; let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; // expression: "a in (true)" let list = vec![lit(ScalarValue::Boolean(Some(true)))]; - in_list!(batch, list, &false, vec![Some(true), None]); + in_list!(batch, list, &false, vec![Some(true), None], col_a.clone()); // expression: "a not in (true)" let list = vec![lit(ScalarValue::Boolean(Some(true)))]; - in_list!(batch, list, &true, vec![Some(false), None]); + in_list!(batch, list, &true, vec![Some(false), None], col_a.clone()); // expression: "a in (true, NULL)" let list = vec![ lit(ScalarValue::Boolean(Some(true))), lit(ScalarValue::Utf8(None)), ]; - in_list!(batch, list, &false, 
vec![Some(true), None]); + in_list!(batch, list, &false, vec![Some(true), None], col_a.clone()); // expression: "a not in (true, NULL)" let list = vec![ lit(ScalarValue::Boolean(Some(true))), lit(ScalarValue::Utf8(None)), ]; - in_list!(batch, list, &true, vec![Some(false), None]); + in_list!(batch, list, &true, vec![Some(false), None], col_a.clone()); Ok(()) } diff --git a/datafusion/src/physical_plan/expressions/is_not_null.rs b/datafusion/src/physical_plan/expressions/is_not_null.rs index 7ac2110b5022..cce27e36a68c 100644 --- a/datafusion/src/physical_plan/expressions/is_not_null.rs +++ b/datafusion/src/physical_plan/expressions/is_not_null.rs @@ -100,10 +100,10 @@ mod tests { fn is_not_null_op() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); let a = StringArray::from(vec![Some("foo"), None]); + let expr = is_not_null(col("a", &schema)?).unwrap(); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; // expression: "a is not null" - let expr = is_not_null(col("a")).unwrap(); let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); let result = result .as_any() diff --git a/datafusion/src/physical_plan/expressions/is_null.rs b/datafusion/src/physical_plan/expressions/is_null.rs index dfa53f3f7d26..dbb57dfa5f8b 100644 --- a/datafusion/src/physical_plan/expressions/is_null.rs +++ b/datafusion/src/physical_plan/expressions/is_null.rs @@ -100,10 +100,11 @@ mod tests { fn is_null_op() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); let a = StringArray::from(vec![Some("foo"), None]); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; // expression: "a is null" - let expr = is_null(col("a")).unwrap(); + let expr = is_null(col("a", &schema)?).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); let result = result .as_any() diff --git a/datafusion/src/physical_plan/expressions/min_max.rs b/datafusion/src/physical_plan/expressions/min_max.rs index ea917d30d940..680e739cbf29 100644 --- a/datafusion/src/physical_plan/expressions/min_max.rs +++ b/datafusion/src/physical_plan/expressions/min_max.rs @@ -278,7 +278,7 @@ macro_rules! 
min_max { } e => { return Err(DataFusionError::Internal(format!( - "MIN/MAX is not expected to receive a scalar {:?}", + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", e ))) } diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index f8cb40cbacbd..0b32dca0467d 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -66,6 +66,7 @@ pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use row_number::RowNumber; pub use sum::{sum_return_type, Sum}; pub use try_cast::{try_cast, TryCastExpr}; + /// returns the name of the state pub fn format_state_name(name: &str, state_name: &str) -> String { format!("{}[{}]", name, state_name) @@ -126,8 +127,11 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - let agg = - Arc::new(<$OP>::new(col("a"), "bla".to_string(), $EXPECTED_DATATYPE)); + let agg = Arc::new(<$OP>::new( + col("a", &schema)?, + "bla".to_string(), + $EXPECTED_DATATYPE, + )); let actual = aggregate(&batch, agg)?; let expected = ScalarValue::from($EXPECTED); diff --git a/datafusion/src/physical_plan/expressions/not.rs b/datafusion/src/physical_plan/expressions/not.rs index 7a997b61b488..341d38a10aa1 100644 --- a/datafusion/src/physical_plan/expressions/not.rs +++ b/datafusion/src/physical_plan/expressions/not.rs @@ -127,7 +127,7 @@ mod tests { fn neg_op() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); - let expr = not(col("a"), &schema)?; + let expr = not(col("a", &schema)?, &schema)?; assert_eq!(expr.data_type(&schema)?, DataType::Boolean); assert!(expr.nullable(&schema)?); @@ -152,7 +152,7 @@ mod tests { fn neg_op_not_null() { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); - let expr = not(col("a"), &schema); + let expr = not(col("a", &schema).unwrap(), &schema); assert!(expr.is_err()); } } diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs index 16897d45119f..577c19b54ade 100644 --- a/datafusion/src/physical_plan/expressions/nth_value.rs +++ b/datafusion/src/physical_plan/expressions/nth_value.rs @@ -148,7 +148,7 @@ impl BuiltInWindowFunctionExpr for NthValue { mod tests { use super::*; use crate::error::Result; - use crate::physical_plan::expressions::col; + use crate::physical_plan::expressions::Column; use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; @@ -166,32 +166,46 @@ mod tests { #[test] fn first_value() -> Result<()> { - let first_value = - NthValue::first_value("first_value".to_owned(), col("arr"), DataType::Int32); + let first_value = NthValue::first_value( + "first_value".to_owned(), + Arc::new(Column::new("arr", 0)), + DataType::Int32, + ); test_i32_result(first_value, vec![1; 8])?; Ok(()) } #[test] fn last_value() -> Result<()> { - let last_value = - NthValue::last_value("last_value".to_owned(), col("arr"), DataType::Int32); + let last_value = NthValue::last_value( + "last_value".to_owned(), + Arc::new(Column::new("arr", 0)), + DataType::Int32, + ); test_i32_result(last_value, vec![8; 8])?; Ok(()) } #[test] fn nth_value_1() -> Result<()> { - let nth_value = - NthValue::nth_value("nth_value".to_owned(), col("arr"), DataType::Int32, 1)?; + let nth_value = NthValue::nth_value( + "nth_value".to_owned(), + Arc::new(Column::new("arr", 0)), + DataType::Int32, + 1, + )?; test_i32_result(nth_value, vec![1; 8])?; Ok(()) } #[test] fn 
nth_value_2() -> Result<()> { - let nth_value = - NthValue::nth_value("nth_value".to_owned(), col("arr"), DataType::Int32, 2)?; + let nth_value = NthValue::nth_value( + "nth_value".to_owned(), + Arc::new(Column::new("arr", 0)), + DataType::Int32, + 2, + )?; test_i32_result(nth_value, vec![-2; 8])?; Ok(()) } diff --git a/datafusion/src/physical_plan/expressions/try_cast.rs b/datafusion/src/physical_plan/expressions/try_cast.rs index 5e402fdea28a..1ba4a50260d4 100644 --- a/datafusion/src/physical_plan/expressions/try_cast.rs +++ b/datafusion/src/physical_plan/expressions/try_cast.rs @@ -139,10 +139,13 @@ mod tests { RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; // verify that we can construct the expression - let expression = try_cast(col("a"), &schema, $TYPE)?; + let expression = try_cast(col("a", &schema)?, &schema, $TYPE)?; // verify that its display is correct - assert_eq!(format!("CAST(a AS {:?})", $TYPE), format!("{}", expression)); + assert_eq!( + format!("CAST(a@0 AS {:?})", $TYPE), + format!("{}", expression) + ); // verify that the expression's type is correct assert_eq!(expression.data_type(&schema)?, $TYPE); @@ -241,7 +244,7 @@ mod tests { // Ensure a useful error happens at plan time if invalid casts are used let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let result = try_cast(col("a"), &schema, DataType::LargeBinary); + let result = try_cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary); result.expect_err("expected Invalid CAST"); } } diff --git a/datafusion/src/physical_plan/filter.rs b/datafusion/src/physical_plan/filter.rs index 0a8c825aba1a..9e7fa9df9711 100644 --- a/datafusion/src/physical_plan/filter.rs +++ b/datafusion/src/physical_plan/filter.rs @@ -223,14 +223,14 @@ mod tests { let predicate: Arc = binary( binary( - col("c2"), + col("c2", &schema)?, Operator::Gt, lit(ScalarValue::from(1u32)), &schema, )?, Operator::And, binary( - col("c2"), + col("c2", &schema)?, Operator::Lt, lit(ScalarValue::from(4u32)), &schema, diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 0e2be51d3ebc..01f7e95a0ee9 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -3651,7 +3651,7 @@ mod tests { let expr = create_physical_expr( &BuiltinScalarFunction::Array, - &[col("a"), col("b")], + &[col("a", &schema)?, col("b", &schema)?], &schema, &ctx_state, )?; @@ -3718,7 +3718,7 @@ mod tests { let columns: Vec = vec![col_value]; let expr = create_physical_expr( &BuiltinScalarFunction::RegexpMatch, - &[col("a"), pattern], + &[col("a", &schema)?, pattern], &schema, &ctx_state, )?; diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index f1611ebd7a77..250ba2b08306 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -663,9 +663,12 @@ async fn compute_grouped_hash_aggregate( aggr_expr: Vec>, mut input: SendableRecordBatchStream, ) -> ArrowResult { - // the expressions to evaluate the batch, one vec of expressions per aggregation - let aggregate_expressions = aggregate_expressions(&aggr_expr, &mode) - .map_err(DataFusionError::into_arrow_external_error)?; + // The expressions to evaluate the batch, one vec of expressions per aggregation. + // Assume create_schema() always puts group columns in front of the aggregate columns, + // so we set col_idx_base to the group expression count.
+ let aggregate_expressions = + aggregate_expressions(&aggr_expr, &mode, group_expr.len()) + .map_err(DataFusionError::into_arrow_external_error)?; // mapping key -> (set of accumulators, indices of the key in the batch) // * the indexes are updated at each row @@ -794,14 +797,21 @@ fn evaluate_many( .collect::>>() } -/// uses `state_fields` to build a vec of expressions required to merge the AggregateExpr' accumulator's state. +/// uses `state_fields` to build a vec of physical column expressions required to merge the +/// AggregateExpr's accumulator state. +/// +/// `index_base` is the starting physical column index for the next expanded state field. fn merge_expressions( + index_base: usize, expr: &Arc, ) -> Result>> { Ok(expr .state_fields()? .iter() - .map(|f| Arc::new(Column::new(f.name())) as Arc) + .enumerate() + .map(|(idx, f)| { + Arc::new(Column::new(f.name(), index_base + idx)) as Arc + }) .collect::>()) } @@ -809,22 +819,27 @@ fn merge_expressions( /// The expressions are different depending on `mode`: /// * Partial: AggregateExpr::expressions /// * Final: columns of `AggregateExpr::state_fields()` -/// The return value is to be understood as: -/// * index 0 is the aggregation -/// * index 1 is the expression i of the aggregation fn aggregate_expressions( aggr_expr: &[Arc], mode: &AggregateMode, + col_idx_base: usize, ) -> Result>>> { match mode { AggregateMode::Partial => { Ok(aggr_expr.iter().map(|agg| agg.expressions()).collect()) } // in this mode, we build the merge expressions of the aggregation - AggregateMode::Final | AggregateMode::FinalPartitioned => Ok(aggr_expr - .iter() - .map(|agg| merge_expressions(agg)) - .collect::>>()?), + AggregateMode::Final | AggregateMode::FinalPartitioned => { + let mut col_idx_base = col_idx_base; + Ok(aggr_expr + .iter() + .map(|agg| { + let exprs = merge_expressions(col_idx_base, agg)?; + col_idx_base += exprs.len(); + Ok(exprs) + }) + .collect::>>()?)
+ } } } @@ -846,10 +861,8 @@ async fn compute_hash_aggregate( ) -> ArrowResult { let mut accumulators = create_accumulators(&aggr_expr) .map_err(DataFusionError::into_arrow_external_error)?; - - let expressions = aggregate_expressions(&aggr_expr, &mode) + let expressions = aggregate_expressions(&aggr_expr, &mode, 0) .map_err(DataFusionError::into_arrow_external_error)?; - let expressions = Arc::new(expressions); // 1 for each batch, update / merge accumulators with the expressions' values @@ -1253,16 +1266,17 @@ mod tests { /// build the aggregates on the data from some_data() and check the results async fn check_aggregates(input: Arc) -> Result<()> { + let input_schema = input.schema(); + let groups: Vec<(Arc, String)> = - vec![(col("a"), "a".to_string())]; + vec![(col("a", &input_schema)?, "a".to_string())]; let aggregates: Vec> = vec![Arc::new(Avg::new( - col("b"), + col("b", &input_schema)?, "AVG(b)".to_string(), DataType::Float64, ))]; - let input_schema = input.schema(); let partial_aggregate = Arc::new(HashAggregateExec::try_new( AggregateMode::Partial, groups.clone(), @@ -1286,8 +1300,9 @@ mod tests { let merge = Arc::new(MergeExec::new(partial_aggregate)); - let final_group: Vec> = - (0..groups.len()).map(|i| col(&groups[i].1)).collect(); + let final_group: Vec> = (0..groups.len()) + .map(|i| col(&groups[i].1, &input_schema)) + .collect::>()?; let merged_aggregate = Arc::new(HashAggregateExec::try_new( AggregateMode::Final, diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index 928392a84433..ad356079387a 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -52,7 +52,7 @@ use arrow::array::{ UInt64Array, UInt8Array, }; -use super::expressions::col; +use super::expressions::Column; use super::{ hash_utils::{build_join_schema, check_join_is_valid, JoinOn, JoinType}, merge::MergeExec, @@ -64,6 +64,7 @@ use super::{ SendableRecordBatchStream, }; use crate::physical_plan::coalesce_batches::concat_batches; +use crate::physical_plan::PhysicalExpr; use log::debug; // Maps a `u64` hash value based on the left ["on" values] to a list of indices with this key's value. 
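The col_idx_base plumbing in the hash aggregate hunks above is easier to see in isolation. Below is a minimal plain-Rust sketch (stand-in names, not DataFusion's types) of the layout those hunks assume: in Final mode the partial aggregate's output is [group columns..., expanded state fields...], so the merge expressions for each aggregate index state columns starting at group_expr.len(), and each aggregate advances the base by the number of state fields it expands into.

```rust
// Sketch only: models the Final-mode column layout assumed by
// aggregate_expressions()/merge_expressions() above, without DataFusion types.
fn main() {
    let group_cols = vec!["a"]; // group expressions come first in the schema
    // Each aggregate expands into one or more state fields, e.g. AVG -> [count, sum].
    let aggr_state_fields: Vec<Vec<&str>> = vec![vec!["AVG(b)[count]", "AVG(b)[sum]"]];

    // col_idx_base starts at the group expression count (see compute_grouped_hash_aggregate).
    let mut col_idx_base = group_cols.len();
    for state_fields in &aggr_state_fields {
        let merge_exprs: Vec<(&str, usize)> = state_fields
            .iter()
            .enumerate()
            .map(|(idx, name)| (*name, col_idx_base + idx))
            .collect();
        col_idx_base += merge_exprs.len();
        // Prints [("AVG(b)[count]", 1), ("AVG(b)[sum]", 2)] for a single group column.
        println!("{:?}", merge_exprs);
    }
}
```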
@@ -90,7 +91,7 @@ pub struct HashJoinExec { /// right (probe) side which are filtered by the hash table right: Arc, /// Set of common columns used to join on - on: Vec<(String, String)>, + on: Vec<(Column, Column)>, /// How the join is performed join_type: JoinType, /// The schema once the join is applied @@ -127,26 +128,21 @@ impl HashJoinExec { pub fn try_new( left: Arc, right: Arc, - on: &JoinOn, + on: JoinOn, join_type: &JoinType, partition_mode: PartitionMode, ) -> Result { let left_schema = left.schema(); let right_schema = right.schema(); - check_join_is_valid(&left_schema, &right_schema, on)?; + check_join_is_valid(&left_schema, &right_schema, &on)?; let schema = Arc::new(build_join_schema( &left_schema, &right_schema, - on, + &on, join_type, )); - let on = on - .iter() - .map(|(l, r)| (l.to_string(), r.to_string())) - .collect(); - let random_state = RandomState::with_seeds(0, 0, 0, 0); Ok(HashJoinExec { @@ -172,7 +168,7 @@ impl HashJoinExec { } /// Set of common columns used to join on - pub fn on(&self) -> &[(String, String)] { + pub fn on(&self) -> &[(Column, Column)] { &self.on } @@ -236,7 +232,7 @@ impl ExecutionPlan for HashJoinExec { 2 => Ok(Arc::new(HashJoinExec::try_new( children[0].clone(), children[1].clone(), - &self.on, + self.on.clone(), &self.join_type, self.mode, )?)), @@ -307,10 +303,10 @@ impl ExecutionPlan for HashJoinExec { *build_side = Some(left_side.clone()); debug!( - "Built build-side of hash join containing {} rows in {} ms", - num_rows, - start.elapsed().as_millis() - ); + "Built build-side of hash join containing {} rows in {} ms", + num_rows, + start.elapsed().as_millis() + ); left_side } @@ -372,7 +368,7 @@ impl ExecutionPlan for HashJoinExec { // we have the batches and the hash map with their keys. We can now create a stream // over the right that uses this information to issue new batches.
- let stream = self.right.execute(partition).await?; + let right_stream = self.right.execute(partition).await?; let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); let column_indices = self.column_indices_from_schema()?; @@ -383,23 +379,17 @@ impl ExecutionPlan for HashJoinExec { } JoinType::Inner | JoinType::Right => vec![], }; - Ok(Box::pin(HashJoinStream { - schema: self.schema.clone(), + Ok(Box::pin(HashJoinStream::new( + self.schema.clone(), on_left, on_right, - join_type: self.join_type, + self.join_type, left_data, - right: stream, + right_stream, column_indices, - num_input_batches: 0, - num_input_rows: 0, - num_output_batches: 0, - num_output_rows: 0, - join_time: 0, - random_state: self.random_state.clone(), + self.random_state.clone(), visited_left_side, - is_exhausted: false, - })) + ))) } fn fmt_as( @@ -422,7 +412,7 @@ impl ExecutionPlan for HashJoinExec { /// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, /// assuming that the [RecordBatch] corresponds to the `index`th fn update_hash( - on: &[String], + on: &[Column], batch: &RecordBatch, hash: &mut JoinHashMap, offset: usize, @@ -432,7 +422,7 @@ fn update_hash( // evaluate the keys let keys_values = on .iter() - .map(|name| Ok(col(name).evaluate(batch)?.into_array(batch.num_rows()))) + .map(|c| Ok(c.evaluate(batch)?.into_array(batch.num_rows()))) .collect::>>()?; // calculate the hash values @@ -461,9 +451,9 @@ struct HashJoinStream { /// Input schema schema: Arc, /// columns from the left - on_left: Vec, + on_left: Vec, /// columns from the right used to compute the hash - on_right: Vec, + on_right: Vec, /// type of the join join_type: JoinType, /// information from the left @@ -490,6 +480,39 @@ struct HashJoinStream { is_exhausted: bool, } +#[allow(clippy::too_many_arguments)] +impl HashJoinStream { + fn new( + schema: Arc, + on_left: Vec, + on_right: Vec, + join_type: JoinType, + left_data: JoinLeftData, + right: SendableRecordBatchStream, + column_indices: Vec, + random_state: RandomState, + visited_left_side: Vec, + ) -> Self { + HashJoinStream { + schema, + on_left, + on_right, + join_type, + left_data, + right, + column_indices, + num_input_batches: 0, + num_input_rows: 0, + num_output_batches: 0, + num_output_rows: 0, + join_time: 0, + random_state, + visited_left_side, + is_exhausted: false, + } + } +} + impl RecordBatchStream for HashJoinStream { fn schema(&self) -> SchemaRef { self.schema.clone() @@ -531,8 +554,8 @@ fn build_batch_from_indices( fn build_batch( batch: &RecordBatch, left_data: &JoinLeftData, - on_left: &[String], - on_right: &[String], + on_left: &[Column], + on_right: &[Column], join_type: JoinType, schema: &Schema, column_indices: &[ColumnIndex], @@ -590,21 +613,17 @@ fn build_join_indexes( left_data: &JoinLeftData, right: &RecordBatch, join_type: JoinType, - left_on: &[String], - right_on: &[String], + left_on: &[Column], + right_on: &[Column], random_state: &RandomState, ) -> Result<(UInt64Array, UInt32Array)> { let keys_values = right_on .iter() - .map(|name| Ok(col(name).evaluate(right)?.into_array(right.num_rows()))) + .map(|c| Ok(c.evaluate(right)?.into_array(right.num_rows()))) .collect::>>()?; let left_join_values = left_on .iter() - .map(|name| { - Ok(col(name) - .evaluate(&left_data.1)? 
- .into_array(left_data.1.num_rows())) - }) + .map(|c| Ok(c.evaluate(&left_data.1)?.into_array(left_data.1.num_rows()))) .collect::>>()?; let hashes_buffer = &mut vec![0; keys_values[0].len()]; let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; @@ -1250,6 +1269,7 @@ impl Stream for HashJoinStream { | JoinType::Right => {} } + // End of right batch, print stats in debug mode debug!( "Processed {} probe-side input batches containing {} rows and \ produced {} output batches containing {} rows in {} ms", @@ -1269,7 +1289,9 @@ impl Stream for HashJoinStream { mod tests { use crate::{ assert_batches_sorted_eq, - physical_plan::{common, memory::MemoryExec}, + physical_plan::{ + common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, + }, test::{build_table_i32, columns}, }; @@ -1289,14 +1311,74 @@ mod tests { fn join( left: Arc, right: Arc, - on: &[(&str, &str)], + on: JoinOn, join_type: &JoinType, ) -> Result { - let on: Vec<_> = on + HashJoinExec::try_new(left, right, on, join_type, PartitionMode::CollectLeft) + } + + async fn join_collect( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + ) -> Result<(Vec, Vec)> { + let join = join(left, right, on, join_type)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + Ok((columns, batches)) + } + + async fn partitioned_join_collect( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + ) -> Result<(Vec, Vec)> { + let partition_count = 4; + + let (left_expr, right_expr) = on .iter() - .map(|(l, r)| (l.to_string(), r.to_string())) - .collect(); - HashJoinExec::try_new(left, right, &on, join_type, PartitionMode::CollectLeft) + .map(|(l, r)| { + ( + Arc::new(l.clone()) as Arc, + Arc::new(r.clone()) as Arc, + ) + }) + .unzip(); + + let join = HashJoinExec::try_new( + Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, partition_count), + )?), + Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), + on, + join_type, + PartitionMode::Partitioned, + )?; + + let columns = columns(&join.schema()); + + let mut batches = vec![]; + for i in 0..partition_count { + let stream = join.execute(i).await?; + let more_batches = common::collect(stream).await?; + batches.extend( + more_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(), + ); + } + + Ok((columns, batches)) } #[tokio::test] @@ -1311,15 +1393,58 @@ mod tests { ("b1", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), ); - let on = &[("b1", "b1")]; - let join = join(left, right, on, &JoinType::Inner)?; + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = + join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Inner) + .await?; - let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); - let stream = join.execute(0).await?; - let batches = common::collect(stream).await?; + let expected = vec![ + "+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | c2 |", + "+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 70 |", + "| 2 | 5 | 8 | 20 | 80 |", + "| 3 | 5 | 9 | 20 | 80 |", + "+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn partitioned_join_inner_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 
5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = partitioned_join_collect( + left.clone(), + right.clone(), + on.clone(), + &JoinType::Inner, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); let expected = vec![ "+----+----+----+----+----+", @@ -1347,16 +1472,15 @@ mod tests { ("b2", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), ); - let on = &[("b1", "b2")]; + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; - let join = join(left, right, on, &JoinType::Inner)?; + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; - let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let stream = join.execute(0).await?; - let batches = common::collect(stream).await?; - let expected = vec![ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -1384,15 +1508,21 @@ mod tests { ("b2", &vec![1, 2, 2]), ("c2", &vec![70, 80, 90]), ); - let on = &[("a1", "a1"), ("b2", "b2")]; + let on = vec![ + ( + Column::new_with_schema("a1", &left.schema())?, + Column::new_with_schema("a1", &right.schema())?, + ), + ( + Column::new_with_schema("b2", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + ), + ]; - let join = join(left, right, on, &JoinType::Inner)?; + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; - let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); - let stream = join.execute(0).await?; - let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); let expected = vec![ @@ -1430,15 +1560,21 @@ mod tests { ("b2", &vec![1, 2, 2]), ("c2", &vec![70, 80, 90]), ); - let on = &[("a1", "a1"), ("b2", "b2")]; + let on = vec![ + ( + Column::new_with_schema("a1", &left.schema())?, + Column::new_with_schema("a1", &right.schema())?, + ), + ( + Column::new_with_schema("b2", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + ), + ]; - let join = join(left, right, on, &JoinType::Inner)?; + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; - let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); - let stream = join.execute(0).await?; - let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); let expected = vec![ @@ -1477,7 +1613,10 @@ mod tests { MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), ); - let on = &[("b1", "b1")]; + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; let join = join(left, right, on, &JoinType::Inner)?; @@ -1540,7 +1679,10 @@ mod tests { ("b1", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), ); - let on = &[("b1", "b1")]; + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + )]; let join = join(left, right, on, &JoinType::Left).unwrap(); @@ -1578,7 +1720,10 @@ mod tests { ("b2", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), ); - let on = &[("b1", "b2")]; + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + 
Column::new_with_schema("b2", &right.schema()).unwrap(), + )]; let join = join(left, right, on, &JoinType::Full).unwrap(); @@ -1613,7 +1758,10 @@ mod tests { ("c1", &vec![7, 8, 9]), ); let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![])); - let on = &[("b1", "b1")]; + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + )]; let schema = right.schema(); let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); let join = join(left, right, on, &JoinType::Left).unwrap(); @@ -1645,7 +1793,10 @@ mod tests { ("c1", &vec![7, 8, 9]), ); let right = build_table_i32(("a2", &vec![]), ("b2", &vec![]), ("c2", &vec![])); - let on = &[("b1", "b2")]; + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), + )]; let schema = right.schema(); let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); let join = join(left, right, on, &JoinType::Full).unwrap(); @@ -1681,15 +1832,55 @@ mod tests { ("b1", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), ); - let on = &[("b1", "b1")]; + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = + join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Left) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); - let join = join(left, right, on, &JoinType::Left)?; + let expected = vec![ + "+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | c2 |", + "+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 70 |", + "| 2 | 5 | 8 | 20 | 80 |", + "| 3 | 7 | 9 | | |", + "+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); - let columns = columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + Ok(()) + } - let stream = join.execute(0).await?; - let batches = common::collect(stream).await?; + #[tokio::test] + async fn partitioned_join_left_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = partitioned_join_collect( + left.clone(), + right.clone(), + on.clone(), + &JoinType::Left, + ) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); let expected = vec![ "+----+----+----+----+----+", @@ -1717,7 +1908,10 @@ mod tests { ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right ("c2", &vec![70, 80, 90, 100]), ); - let on = &[("b1", "b1")]; + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; let join = join(left, right, on, &JoinType::Semi)?; @@ -1753,7 +1947,10 @@ mod tests { ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right ("c2", &vec![70, 80, 90, 100]), ); - let on = &[("b1", "b1")]; + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; let join = join(left, right, on, &JoinType::Anti)?; @@ -1787,15 +1984,51 @@ mod tests { ("b1", &vec![4, 5, 6]), // 6 does not exist on the left ("c2", &vec![70, 80, 90]), ); - let on = &[("b1", "b1")]; + 
let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; - let join = join(left, right, on, &JoinType::Right)?; + let (columns, batches) = join_collect(left, right, on, &JoinType::Right).await?; - let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "c1", "a2", "b1", "c2"]); - let stream = join.execute(0).await?; - let batches = common::collect(stream).await?; + let expected = vec![ + "+----+----+----+----+----+", + "| a1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+", + "| | | 30 | 6 | 90 |", + "| 1 | 7 | 10 | 4 | 70 |", + "| 2 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn partitioned_join_right_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = + partitioned_join_collect(left, right, on, &JoinType::Right).await?; + + assert_eq!(columns, vec!["a1", "c1", "a2", "b1", "c2"]); let expected = vec![ "+----+----+----+----+----+", @@ -1824,7 +2057,10 @@ mod tests { ("b2", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), ); - let on = &[("b1", "b2")]; + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), + )]; let join = join(left, right, on, &JoinType::Full)?; @@ -1904,8 +2140,8 @@ mod tests { &left_data, &right, JoinType::Inner, - &["a".to_string()], - &["a".to_string()], + &[Column::new("a", 0)], + &[Column::new("a", 0)], &random_state, )?; @@ -1914,7 +2150,6 @@ mod tests { left_ids.append_value(1)?; let mut right_ids = UInt32Builder::new(0); - right_ids.append_value(0)?; right_ids.append_value(1)?; diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index a48710bfbfc3..0cf0b9212cd2 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ -21,6 +21,8 @@ use crate::error::{DataFusionError, Result}; use arrow::datatypes::{Field, Schema}; use std::collections::HashSet; +use crate::physical_plan::expressions::Column; + /// All valid types of joins. #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum JoinType { @@ -39,14 +41,25 @@ pub enum JoinType { } /// The on clause of the join, as vector of (left, right) columns. -pub type JoinOn = [(String, String)]; +pub type JoinOn = Vec<(Column, Column)>; +/// Reference for JoinOn. +pub type JoinOnRef<'a> = &'a [(Column, Column)]; /// Checks whether the schemas "left" and "right" and columns "on" represent a valid join. 
/// They are valid whenever their columns' intersection equals the set `on` -pub fn check_join_is_valid(left: &Schema, right: &Schema, on: &JoinOn) -> Result<()> { - let left: HashSet = left.fields().iter().map(|f| f.name().clone()).collect(); - let right: HashSet = - right.fields().iter().map(|f| f.name().clone()).collect(); +pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> { + let left: HashSet = left + .fields() + .iter() + .enumerate() + .map(|(idx, f)| Column::new(f.name(), idx)) + .collect(); + let right: HashSet = right + .fields() + .iter() + .enumerate() + .map(|(idx, f)| Column::new(f.name(), idx)) + .collect(); check_join_set_is_valid(&left, &right, on) } @@ -54,14 +67,14 @@ pub fn check_join_is_valid(left: &Schema, right: &Schema, on: &JoinOn) -> Result /// Checks whether the sets left, right and on compose a valid join. /// They are valid whenever their intersection equals the set `on` fn check_join_set_is_valid( - left: &HashSet, - right: &HashSet, - on: &JoinOn, + left: &HashSet, + right: &HashSet, + on: &[(Column, Column)], ) -> Result<()> { - let on_left = &on.iter().map(|on| on.0.to_string()).collect::>(); + let on_left = &on.iter().map(|on| on.0.clone()).collect::>(); let left_missing = on_left.difference(left).collect::>(); - let on_right = &on.iter().map(|on| on.1.to_string()).collect::>(); + let on_right = &on.iter().map(|on| on.1.clone()).collect::>(); let right_missing = on_right.difference(right).collect::>(); if !left_missing.is_empty() | !right_missing.is_empty() { @@ -75,7 +88,7 @@ fn check_join_set_is_valid( let remaining = right .difference(on_right) .cloned() - .collect::>(); + .collect::>(); let collisions = left.intersection(&remaining).collect::>(); @@ -94,7 +107,7 @@ fn check_join_set_is_valid( pub fn build_join_schema( left: &Schema, right: &Schema, - on: &JoinOn, + on: JoinOnRef, join_type: &JoinType, ) -> Schema { let fields: Vec = match join_type { @@ -102,8 +115,8 @@ pub fn build_join_schema( // remove right-side join keys if they have the same names as the left-side let duplicate_keys = &on .iter() - .filter(|(l, r)| l == r) - .map(|on| on.1.to_string()) + .filter(|(l, r)| l.name() == r.name()) + .map(|on| on.1.name()) .collect::>(); let left_fields = left.fields().iter(); @@ -111,7 +124,7 @@ pub fn build_join_schema( let right_fields = right .fields() .iter() - .filter(|f| !duplicate_keys.contains(f.name())); + .filter(|f| !duplicate_keys.contains(f.name().as_str())); // left then right left_fields.chain(right_fields).cloned().collect() @@ -120,14 +133,14 @@ pub fn build_join_schema( // remove left-side join keys if they have the same names as the right-side let duplicate_keys = &on .iter() - .filter(|(l, r)| l == r) - .map(|on| on.1.to_string()) + .filter(|(l, r)| l.name() == r.name()) + .map(|on| on.1.name()) .collect::>(); let left_fields = left .fields() .iter() - .filter(|f| !duplicate_keys.contains(f.name())); + .filter(|f| !duplicate_keys.contains(f.name().as_str())); let right_fields = right.fields().iter(); @@ -141,24 +154,25 @@ pub fn build_join_schema( #[cfg(test)] mod tests { - use super::*; - fn check(left: &[&str], right: &[&str], on: &[(&str, &str)]) -> Result<()> { - let left = left.iter().map(|x| x.to_string()).collect::>(); - let right = right.iter().map(|x| x.to_string()).collect::>(); - let on: Vec<_> = on + fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { + let left = left + .iter() + .map(|x| x.to_owned()) + .collect::>(); + let right = right .iter() - 
.map(|(l, r)| (l.to_string(), r.to_string())) - .collect(); - check_join_set_is_valid(&left, &right, &on) + .map(|x| x.to_owned()) + .collect::>(); + check_join_set_is_valid(&left, &right, on) } #[test] fn check_valid() -> Result<()> { - let left = vec!["a", "b1"]; - let right = vec!["a", "b2"]; - let on = &[("a", "a")]; + let left = vec![Column::new("a", 0), Column::new("b1", 1)]; + let right = vec![Column::new("a", 0), Column::new("b2", 1)]; + let on = &[(Column::new("a", 0), Column::new("a", 0))]; check(&left, &right, on)?; Ok(()) @@ -166,18 +180,18 @@ mod tests { #[test] fn check_not_in_right() { - let left = vec!["a", "b"]; - let right = vec!["b"]; - let on = &[("a", "a")]; + let left = vec![Column::new("a", 0), Column::new("b", 1)]; + let right = vec![Column::new("b", 0)]; + let on = &[(Column::new("a", 0), Column::new("a", 0))]; assert!(check(&left, &right, on).is_err()); } #[test] fn check_not_in_left() { - let left = vec!["b"]; - let right = vec!["a"]; - let on = &[("a", "a")]; + let left = vec![Column::new("b", 0)]; + let right = vec![Column::new("a", 0)]; + let on = &[(Column::new("a", 0), Column::new("a", 0))]; assert!(check(&left, &right, on).is_err()); } @@ -185,18 +199,18 @@ mod tests { #[test] fn check_collision() { // column "a" would appear both in left and right - let left = vec!["a", "c"]; - let right = vec!["a", "b"]; - let on = &[("a", "b")]; + let left = vec![Column::new("a", 0), Column::new("c", 1)]; + let right = vec![Column::new("a", 0), Column::new("b", 1)]; + let on = &[(Column::new("a", 0), Column::new("b", 1))]; assert!(check(&left, &right, on).is_err()); } #[test] fn check_in_right() { - let left = vec!["a", "c"]; - let right = vec!["b"]; - let on = &[("a", "b")]; + let left = vec![Column::new("a", 0), Column::new("c", 1)]; + let right = vec![Column::new("b", 0)]; + let on = &[(Column::new("a", 0), Column::new("b", 0))]; assert!(check(&left, &right, on).is_ok()); } diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 50c30a57b5fe..7b26d7b3ab6e 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -211,9 +211,9 @@ pub trait ExecutionPlan: Debug + Send + Sync { /// let displayable_plan = displayable(physical_plan.as_ref()); /// let plan_string = format!("{}", displayable_plan.indent()); /// -/// assert_eq!("ProjectionExec: expr=[a]\ +/// assert_eq!("ProjectionExec: expr=[a@0 as a]\ /// \n CoalesceBatchesExec: target_batch_size=4096\ -/// \n FilterExec: a < 5\ +/// \n FilterExec: a@0 < 5\ /// \n RepartitionExec: partitioning=RoundRobinBatch(3)\ /// \n CsvExec: source=Path(tests/example.csv: [tests/example.csv]), has_header=true", /// plan_string.trim()); diff --git a/datafusion/src/physical_plan/parquet.rs b/datafusion/src/physical_plan/parquet.rs index 2bea94aee1e5..3d20a9bf98c1 100644 --- a/datafusion/src/physical_plan/parquet.rs +++ b/datafusion/src/physical_plan/parquet.rs @@ -25,7 +25,7 @@ use std::{any::Any, convert::TryInto}; use crate::{ error::{DataFusionError, Result}, - logical_plan::Expr, + logical_plan::{Column, Expr}, physical_optimizer::pruning::{PruningPredicate, PruningStatistics}, physical_plan::{ common, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, @@ -497,7 +497,7 @@ macro_rules! get_statistic { // Extract the min or max value calling `func` or `bytes_func` on the ParquetStatistics as appropriate macro_rules! 
get_min_max_values { ($self:expr, $column:expr, $func:ident, $bytes_func:ident) => {{ - let (column_index, field) = if let Some((v, f)) = $self.parquet_schema.column_with_name($column) { + let (column_index, field) = if let Some((v, f)) = $self.parquet_schema.column_with_name(&$column.name) { (v, f) } else { // Named column was not present @@ -532,11 +532,11 @@ macro_rules! get_min_max_values { } impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { - fn min_values(&self, column: &str) -> Option { + fn min_values(&self, column: &Column) -> Option { get_min_max_values!(self, column, min, min_bytes) } - fn max_values(&self, column: &str) -> Option { + fn max_values(&self, column: &Column) -> Option { get_min_max_values!(self, column, max, max_bytes) } @@ -593,7 +593,6 @@ fn read_files( loop { match batch_reader.next() { Some(Ok(batch)) => { - //println!("ParquetExec got new batch from {}", filename); total_rows += batch.num_rows(); send_result(&response_tx, Ok(batch))?; if limit.map(|l| total_rows >= l).unwrap_or(false) { diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index af0e60f2194c..a4c20a7f60eb 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -56,6 +56,121 @@ use expressions::col; use log::debug; use std::sync::Arc; +fn create_function_physical_name( + fun: &str, + distinct: bool, + args: &[Expr], + input_schema: &DFSchema, +) -> Result { + let names: Vec = args + .iter() + .map(|e| physical_name(e, input_schema)) + .collect::>()?; + + let distinct_str = match distinct { + true => "DISTINCT ", + false => "", + }; + Ok(format!("{}({}{})", fun, distinct_str, names.join(","))) +} + +fn physical_name(e: &Expr, input_schema: &DFSchema) -> Result { + match e { + Expr::Column(c) => Ok(c.name.clone()), + Expr::Alias(_, name) => Ok(name.clone()), + Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")), + Expr::Literal(value) => Ok(format!("{:?}", value)), + Expr::BinaryExpr { left, op, right } => { + let left = physical_name(left, input_schema)?; + let right = physical_name(right, input_schema)?; + Ok(format!("{} {:?} {}", left, op, right)) + } + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + let mut name = "CASE ".to_string(); + if let Some(e) = expr { + name += &format!("{:?} ", e); + } + for (w, t) in when_then_expr { + name += &format!("WHEN {:?} THEN {:?} ", w, t); + } + if let Some(e) = else_expr { + name += &format!("ELSE {:?} ", e); + } + name += "END"; + Ok(name) + } + Expr::Cast { expr, data_type } => { + let expr = physical_name(expr, input_schema)?; + Ok(format!("CAST({} AS {:?})", expr, data_type)) + } + Expr::TryCast { expr, data_type } => { + let expr = physical_name(expr, input_schema)?; + Ok(format!("TRY_CAST({} AS {:?})", expr, data_type)) + } + Expr::Not(expr) => { + let expr = physical_name(expr, input_schema)?; + Ok(format!("NOT {}", expr)) + } + Expr::Negative(expr) => { + let expr = physical_name(expr, input_schema)?; + Ok(format!("(- {})", expr)) + } + Expr::IsNull(expr) => { + let expr = physical_name(expr, input_schema)?; + Ok(format!("{} IS NULL", expr)) + } + Expr::IsNotNull(expr) => { + let expr = physical_name(expr, input_schema)?; + Ok(format!("{} IS NOT NULL", expr)) + } + Expr::ScalarFunction { fun, args, .. } => { + create_function_physical_name(&fun.to_string(), false, args, input_schema) + } + Expr::ScalarUDF { fun, args, .. 
} => { + create_function_physical_name(&fun.name, false, args, input_schema) + } + Expr::WindowFunction { fun, args, .. } => { + create_function_physical_name(&fun.to_string(), false, args, input_schema) + } + Expr::AggregateFunction { + fun, + distinct, + args, + .. + } => { + create_function_physical_name(&fun.to_string(), *distinct, args, input_schema) + } + Expr::AggregateUDF { fun, args } => { + let mut names = Vec::with_capacity(args.len()); + for e in args { + names.push(physical_name(e, input_schema)?); + } + Ok(format!("{}({})", fun.name, names.join(","))) + } + Expr::InList { + expr, + list, + negated, + } => { + let expr = physical_name(expr, input_schema)?; + let list = list.iter().map(|expr| physical_name(expr, input_schema)); + if *negated { + Ok(format!("{} NOT IN ({:?})", expr, list)) + } else { + Ok(format!("{} IN ({:?})", expr, list)) + } + } + other => Err(DataFusionError::NotImplemented(format!( + "Cannot derive physical field name for logical expression {:?}", + other + ))), + } +} + /// This trait exposes the ability to plan an [`ExecutionPlan`] out of a [`LogicalPlan`]. pub trait ExtensionPlanner { /// Create a physical plan for a [`UserDefinedLogicalNode`]. @@ -150,10 +265,8 @@ impl DefaultPhysicalPlanner { } let input_exec = self.create_initial_plan(input, ctx_state)?; - let input_schema = input_exec.schema(); - + let physical_input_schema = input_exec.schema(); let logical_input_schema = input.as_ref().schema(); - let physical_input_schema = input_exec.as_ref().schema(); let window_expr = window_expr .iter() @@ -170,7 +283,7 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(WindowAggExec::try_new( window_expr, input_exec.clone(), - input_schema, + physical_input_schema, )?)) } LogicalPlan::Aggregate { @@ -181,8 +294,7 @@ impl DefaultPhysicalPlanner { } => { // Initially need to perform the aggregate and then merge the partitions let input_exec = self.create_initial_plan(input, ctx_state)?; - let input_schema = input_exec.schema(); - let physical_input_schema = input_exec.as_ref().schema(); + let physical_input_schema = input_exec.schema(); let logical_input_schema = input.as_ref().schema(); let groups = group_expr @@ -191,10 +303,11 @@ impl DefaultPhysicalPlanner { tuple_err(( self.create_physical_expr( e, + logical_input_schema, &physical_input_schema, ctx_state, ), - e.name(logical_input_schema), + physical_name(e, logical_input_schema), )) }) .collect::>>()?; @@ -215,11 +328,13 @@ impl DefaultPhysicalPlanner { groups.clone(), aggregates.clone(), input_exec, - input_schema.clone(), + physical_input_schema.clone(), )?); - let final_group: Vec> = - (0..groups.len()).map(|i| col(&groups[i].1)).collect(); + // update group column indices based on partial aggregate plan evaluation + let final_group: Vec> = (0..groups.len()) + .map(|i| col(&groups[i].1, &initial_aggr.schema())) + .collect::>()?; // TODO: dictionary type not yet supported in Hash Repartition let contains_dict = groups @@ -261,31 +376,74 @@ impl DefaultPhysicalPlanner { .collect(), aggregates, initial_aggr, - input_schema, + physical_input_schema.clone(), )?)) } LogicalPlan::Projection { input, expr, .. } => { let input_exec = self.create_initial_plan(input, ctx_state)?; let input_schema = input.as_ref().schema(); - let runtime_expr = expr + + let physical_exprs = expr .iter() .map(|e| { + // For projections, SQL planner and logical plan builder may convert user + // provided expressions into logical Column expressions if their results + // are already provided from the input plans. 
Because we work with + // qualified columns in the logical plan, derived columns involving operators or + // functions will contain qualifiers as well. This will result in logical + // columns with names like `SUM(t1.c1)`, `t1.c1 + t1.c2`, etc. + // + // If we run these logical columns through the physical_name function, we will + // get physical names with column qualifiers, which violates DataFusion's + // field name semantics. To account for this, we need to derive the + // physical name from the physical input instead. + // + // This depends on the invariant that the logical schema field index MUST match + // the physical schema field index. + let physical_name = if let Expr::Column(col) = e { + match input_schema.index_of_column(col) { + Ok(idx) => { + // index physical field using logical field index + Ok(input_exec.schema().field(idx).name().to_string()) + } + // logical column is not a derived column, safe to pass along to + // physical_name + Err(_) => physical_name(e, input_schema), + } + } else { + physical_name(e, input_schema) + }; + tuple_err(( - self.create_physical_expr(e, &input_exec.schema(), ctx_state), - e.name(input_schema), + self.create_physical_expr( + e, + input_schema, + &input_exec.schema(), + ctx_state, + ), + physical_name, )) }) .collect::>>()?; - Ok(Arc::new(ProjectionExec::try_new(runtime_expr, input_exec)?)) + + Ok(Arc::new(ProjectionExec::try_new( + physical_exprs, + input_exec, + )?)) } LogicalPlan::Filter { input, predicate, .. } => { - let input = self.create_initial_plan(input, ctx_state)?; - let input_schema = input.as_ref().schema(); - let runtime_expr = - self.create_physical_expr(predicate, &input_schema, ctx_state)?; - Ok(Arc::new(FilterExec::try_new(runtime_expr, input)?)) + let physical_input = self.create_initial_plan(input, ctx_state)?; + let input_schema = physical_input.as_ref().schema(); + let input_dfschema = input.as_ref().schema(); + let runtime_expr = self.create_physical_expr( + predicate, + input_dfschema, + &input_schema, + ctx_state, + )?; + Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?)) } LogicalPlan::Union { inputs, .. } => { let physical_plans = inputs @@ -298,8 +456,9 @@ impl DefaultPhysicalPlanner { input, partitioning_scheme, } => { - let input = self.create_initial_plan(input, ctx_state)?; - let input_schema = input.schema(); + let physical_input = self.create_initial_plan(input, ctx_state)?; + let input_schema = physical_input.schema(); + let input_dfschema = input.as_ref().schema(); let physical_partitioning = match partitioning_scheme { LogicalPartitioning::RoundRobinBatch(n) => { Partitioning::RoundRobinBatch(*n) @@ -308,20 +467,26 @@ impl DefaultPhysicalPlanner { let runtime_expr = expr .iter() .map(|e| { - self.create_physical_expr(e, &input_schema, ctx_state) + self.create_physical_expr( + e, + input_dfschema, + &input_schema, + ctx_state, + ) }) .collect::>>()?; Partitioning::Hash(runtime_expr, *n) } }; Ok(Arc::new(RepartitionExec::try_new( - input, + physical_input, physical_partitioning, )?)) } LogicalPlan::Sort { expr, input, ..
} => { - let input = self.create_initial_plan(input, ctx_state)?; - let input_schema = input.as_ref().schema(); + let physical_input = self.create_initial_plan(input, ctx_state)?; + let input_schema = physical_input.as_ref().schema(); + let input_dfschema = input.as_ref().schema(); let sort_expr = expr .iter() @@ -332,6 +497,7 @@ impl DefaultPhysicalPlanner { nulls_first, } => self.create_physical_sort_expr( expr, + input_dfschema, &input_schema, SortOptions { descending: !*asc, @@ -345,7 +511,7 @@ impl DefaultPhysicalPlanner { }) .collect::>>()?; - Ok(Arc::new(SortExec::try_new(sort_expr, input)?)) + Ok(Arc::new(SortExec::try_new(sort_expr, physical_input)?)) } LogicalPlan::Join { left, @@ -354,8 +520,10 @@ impl DefaultPhysicalPlanner { join_type, .. } => { - let left = self.create_initial_plan(left, ctx_state)?; - let right = self.create_initial_plan(right, ctx_state)?; + let left_df_schema = left.schema(); + let physical_left = self.create_initial_plan(left, ctx_state)?; + let right_df_schema = right.schema(); + let physical_right = self.create_initial_plan(right, ctx_state)?; let physical_join_type = match join_type { JoinType::Inner => hash_utils::JoinType::Inner, JoinType::Left => hash_utils::JoinType::Left, @@ -364,30 +532,47 @@ impl DefaultPhysicalPlanner { JoinType::Semi => hash_utils::JoinType::Semi, JoinType::Anti => hash_utils::JoinType::Anti, }; + let join_on = keys + .iter() + .map(|(l, r)| { + Ok(( + Column::new(&l.name, left_df_schema.index_of_column(l)?), + Column::new(&r.name, right_df_schema.index_of_column(r)?), + )) + }) + .collect::>()?; + if ctx_state.config.concurrency > 1 && ctx_state.config.repartition_joins { - let left_expr = keys.iter().map(|x| col(&x.0)).collect(); - let right_expr = keys.iter().map(|x| col(&x.1)).collect(); + let (left_expr, right_expr) = join_on + .iter() + .map(|(l, r)| { + ( + Arc::new(l.clone()) as Arc, + Arc::new(r.clone()) as Arc, + ) + }) + .unzip(); // Use hash partition by default to parallelize hash joins Ok(Arc::new(HashJoinExec::try_new( Arc::new(RepartitionExec::try_new( - left, + physical_left, Partitioning::Hash(left_expr, ctx_state.config.concurrency), )?), Arc::new(RepartitionExec::try_new( - right, + physical_right, Partitioning::Hash(right_expr, ctx_state.config.concurrency), )?), - keys, + join_on, &physical_join_type, PartitionMode::Partitioned, )?)) } else { Ok(Arc::new(HashJoinExec::try_new( - left, - right, - keys, + physical_left, + physical_right, + join_on, &physical_join_type, PartitionMode::CollectLeft, )?)) @@ -476,10 +661,10 @@ impl DefaultPhysicalPlanner { "No installed planner was able to convert the custom node to an execution plan: {:?}", node )))?; - // Ensure the ExecutionPlan's schema matches the + // Ensure the ExecutionPlan's schema matches the // declared logical schema to catch and warn about // logic errors when creating user defined plans. - if plan.schema() != node.schema().as_ref().to_owned().into() { + if !node.schema().matches_arrow_schema(&plan.schema()) { Err(DataFusionError::Plan(format!( "Extension planner for {:?} created an ExecutionPlan with mismatched schema. \ LogicalPlan schema: {:?}, ExecutionPlan schema: {:?}", @@ -496,17 +681,20 @@ impl DefaultPhysicalPlanner { pub fn create_physical_expr( &self, e: &Expr, + input_dfschema: &DFSchema, input_schema: &Schema, ctx_state: &ExecutionContextState, ) -> Result> { match e { - Expr::Alias(expr, ..) => { - Ok(self.create_physical_expr(expr, input_schema, ctx_state)?) 
-           }
-           Expr::Column(name) => {
-               // check that name exists
-               input_schema.field_with_name(name)?;
-               Ok(Arc::new(Column::new(name)))
+           Expr::Alias(expr, ..) => Ok(self.create_physical_expr(
+               expr,
+               input_dfschema,
+               input_schema,
+               ctx_state,
+           )?),
+           Expr::Column(c) => {
+               let idx = input_dfschema.index_of_column(c)?;
+               Ok(Arc::new(Column::new(&c.name, idx)))
            }
            Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))),
            Expr::ScalarVariable(variable_names) => {
@@ -535,8 +723,18 @@ impl DefaultPhysicalPlanner {
                }
            }
            Expr::BinaryExpr { left, op, right } => {
-               let lhs = self.create_physical_expr(left, input_schema, ctx_state)?;
-               let rhs = self.create_physical_expr(right, input_schema, ctx_state)?;
+               let lhs = self.create_physical_expr(
+                   left,
+                   input_dfschema,
+                   input_schema,
+                   ctx_state,
+               )?;
+               let rhs = self.create_physical_expr(
+                   right,
+                   input_dfschema,
+                   input_schema,
+                   ctx_state,
+               )?;
                binary(lhs, *op, rhs, input_schema)
            }
            Expr::Case {
@@ -548,6 +746,7 @@ impl DefaultPhysicalPlanner {
                let expr: Option<Arc<dyn PhysicalExpr>> = if let Some(e) = expr {
                    Some(self.create_physical_expr(
                        e.as_ref(),
+                       input_dfschema,
                        input_schema,
                        ctx_state,
                    )?)
@@ -557,13 +756,23 @@ impl DefaultPhysicalPlanner {
                let when_expr = when_then_expr
                    .iter()
                    .map(|(w, _)| {
-                       self.create_physical_expr(w.as_ref(), input_schema, ctx_state)
+                       self.create_physical_expr(
+                           w.as_ref(),
+                           input_dfschema,
+                           input_schema,
+                           ctx_state,
+                       )
                    })
                    .collect::<Result<Vec<_>>>()?;
                let then_expr = when_then_expr
                    .iter()
                    .map(|(_, t)| {
-                       self.create_physical_expr(t.as_ref(), input_schema, ctx_state)
+                       self.create_physical_expr(
+                           t.as_ref(),
+                           input_dfschema,
+                           input_schema,
+                           ctx_state,
+                       )
                    })
                    .collect::<Result<Vec<_>>>()?;
                let when_then_expr: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> =
@@ -576,6 +785,7 @@ impl DefaultPhysicalPlanner {
                {
                    Some(self.create_physical_expr(
                        e.as_ref(),
+                       input_dfschema,
                        input_schema,
                        ctx_state,
                    )?)
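The `Expr::Column` arm above is the crux of this patch: a logical column now carries an optional relation qualifier, and the planner binds it to a field index once, via `DFSchema::index_of_column`, instead of re-resolving a bare name at execution time. Below is a minimal, self-contained sketch of that resolution rule; `DfField`, `DfSchema`, and the string error type are simplified stand-ins for illustration, not the actual DataFusion structs:

```rust
// Simplified model of qualified-column resolution (stand-in types, not the
// real DFSchema/Column): a logical column is (relation?, name); resolution
// yields the field index that the physical Column then carries around.

#[derive(Debug, Clone, PartialEq)]
struct Column {
    relation: Option<String>, // e.g. Some("t1")
    name: String,             // e.g. "c2"
}

struct DfField {
    qualifier: Option<String>,
    name: String,
}

struct DfSchema {
    fields: Vec<DfField>,
}

impl DfSchema {
    // Resolve a (possibly qualified) column to its field index. An
    // unqualified column matches on name alone; a qualified one must also
    // match the field's qualifier.
    fn index_of_column(&self, col: &Column) -> Result<usize, String> {
        self.fields
            .iter()
            .position(|f| {
                f.name == col.name
                    && (col.relation.is_none() || f.qualifier == col.relation)
            })
            .ok_or_else(|| format!("No field matching column {:?}", col))
    }
}

fn main() -> Result<(), String> {
    let schema = DfSchema {
        fields: vec![
            DfField { qualifier: Some("t1".into()), name: "c1".into() },
            DfField { qualifier: Some("t1".into()), name: "c2".into() },
        ],
    };
    // `t1.c2` resolves to index 1; the physical expression then keeps only
    // (name, index), i.e. the equivalent of Column::new("c2", 1) above.
    let col = Column { relation: Some("t1".into()), name: "c2".into() };
    assert_eq!(schema.index_of_column(&col)?, 1);
    Ok(())
}
```

Binding the index at planning time is what lets the physical `Column` stay a plain `(name, index)` pair even when two inputs contribute fields with the same unqualified name.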
@@ -589,35 +799,43 @@ impl DefaultPhysicalPlanner { )?)) } Expr::Cast { expr, data_type } => expressions::cast( - self.create_physical_expr(expr, input_schema, ctx_state)?, + self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, input_schema, data_type.clone(), ), Expr::TryCast { expr, data_type } => expressions::try_cast( - self.create_physical_expr(expr, input_schema, ctx_state)?, + self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, input_schema, data_type.clone(), ), Expr::Not(expr) => expressions::not( - self.create_physical_expr(expr, input_schema, ctx_state)?, + self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, input_schema, ), Expr::Negative(expr) => expressions::negative( - self.create_physical_expr(expr, input_schema, ctx_state)?, + self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, input_schema, ), Expr::IsNull(expr) => expressions::is_null(self.create_physical_expr( expr, + input_dfschema, input_schema, ctx_state, )?), Expr::IsNotNull(expr) => expressions::is_not_null( - self.create_physical_expr(expr, input_schema, ctx_state)?, + self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, ), Expr::ScalarFunction { fun, args } => { let physical_args = args .iter() - .map(|e| self.create_physical_expr(e, input_schema, ctx_state)) + .map(|e| { + self.create_physical_expr( + e, + input_dfschema, + input_schema, + ctx_state, + ) + }) .collect::>>()?; functions::create_physical_expr( fun, @@ -631,6 +849,7 @@ impl DefaultPhysicalPlanner { for e in args { physical_args.push(self.create_physical_expr( e, + input_dfschema, input_schema, ctx_state, )?); @@ -648,11 +867,24 @@ impl DefaultPhysicalPlanner { low, high, } => { - let value_expr = - self.create_physical_expr(expr, input_schema, ctx_state)?; - let low_expr = self.create_physical_expr(low, input_schema, ctx_state)?; - let high_expr = - self.create_physical_expr(high, input_schema, ctx_state)?; + let value_expr = self.create_physical_expr( + expr, + input_dfschema, + input_schema, + ctx_state, + )?; + let low_expr = self.create_physical_expr( + low, + input_dfschema, + input_schema, + ctx_state, + )?; + let high_expr = self.create_physical_expr( + high, + input_dfschema, + input_schema, + ctx_state, + )?; // rewrite the between into the two binary operators let binary_expr = binary( @@ -677,44 +909,54 @@ impl DefaultPhysicalPlanner { Ok(expressions::lit(ScalarValue::Boolean(None))) } _ => { - let value_expr = - self.create_physical_expr(expr, input_schema, ctx_state)?; + let value_expr = self.create_physical_expr( + expr, + input_dfschema, + input_schema, + ctx_state, + )?; let value_expr_data_type = value_expr.data_type(input_schema)?; - let list_exprs = - list.iter() - .map(|expr| match expr { - Expr::Literal(ScalarValue::Utf8(None)) => self - .create_physical_expr(expr, input_schema, ctx_state), - _ => { - let list_expr = self.create_physical_expr( - expr, + let list_exprs = list + .iter() + .map(|expr| match expr { + Expr::Literal(ScalarValue::Utf8(None)) => self + .create_physical_expr( + expr, + input_dfschema, + input_schema, + ctx_state, + ), + _ => { + let list_expr = self.create_physical_expr( + expr, + input_dfschema, + input_schema, + ctx_state, + )?; + let list_expr_data_type = + list_expr.data_type(input_schema)?; + + if list_expr_data_type == value_expr_data_type { + Ok(list_expr) + } else if can_cast_types( + &list_expr_data_type, + &value_expr_data_type, + ) { + expressions::cast( + list_expr, input_schema, 
- ctx_state, - )?; - let list_expr_data_type = - list_expr.data_type(input_schema)?; - - if list_expr_data_type == value_expr_data_type { - Ok(list_expr) - } else if can_cast_types( - &list_expr_data_type, - &value_expr_data_type, - ) { - expressions::cast( - list_expr, - input_schema, - value_expr.data_type(input_schema)?, - ) - } else { - Err(DataFusionError::Plan(format!( - "Unsupported CAST from {:?} to {:?}", - list_expr_data_type, value_expr_data_type - ))) - } + value_expr.data_type(input_schema)?, + ) + } else { + Err(DataFusionError::Plan(format!( + "Unsupported CAST from {:?} to {:?}", + list_expr_data_type, value_expr_data_type + ))) } - }) - .collect::>>()?; + } + }) + .collect::>>()?; expressions::in_list(value_expr, list_exprs, negated) } @@ -731,6 +973,7 @@ impl DefaultPhysicalPlanner { &self, e: &Expr, name: String, + logical_input_schema: &DFSchema, physical_input_schema: &Schema, ctx_state: &ExecutionContextState, ) -> Result> { @@ -745,13 +988,23 @@ impl DefaultPhysicalPlanner { let args = args .iter() .map(|e| { - self.create_physical_expr(e, physical_input_schema, ctx_state) + self.create_physical_expr( + e, + logical_input_schema, + physical_input_schema, + ctx_state, + ) }) .collect::>>()?; let partition_by = partition_by .iter() .map(|e| { - self.create_physical_expr(e, physical_input_schema, ctx_state) + self.create_physical_expr( + e, + logical_input_schema, + physical_input_schema, + ctx_state, + ) }) .collect::>>()?; let order_by = order_by @@ -763,6 +1016,7 @@ impl DefaultPhysicalPlanner { nulls_first, } => self.create_physical_sort_expr( expr, + logical_input_schema, physical_input_schema, SortOptions { descending: !*asc, @@ -809,9 +1063,15 @@ impl DefaultPhysicalPlanner { // unpack aliased logical expressions, e.g. "sum(col) over () as total" let (name, e) = match e { Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), - _ => (e.name(logical_input_schema)?, e), + _ => (physical_name(e, logical_input_schema)?, e), }; - self.create_window_expr_with_name(e, name, physical_input_schema, ctx_state) + self.create_window_expr_with_name( + e, + name, + logical_input_schema, + physical_input_schema, + ctx_state, + ) } /// Create an aggregate expression with a name from a logical expression @@ -819,6 +1079,7 @@ impl DefaultPhysicalPlanner { &self, e: &Expr, name: String, + logical_input_schema: &DFSchema, physical_input_schema: &Schema, ctx_state: &ExecutionContextState, ) -> Result> { @@ -832,7 +1093,12 @@ impl DefaultPhysicalPlanner { let args = args .iter() .map(|e| { - self.create_physical_expr(e, physical_input_schema, ctx_state) + self.create_physical_expr( + e, + logical_input_schema, + physical_input_schema, + ctx_state, + ) }) .collect::>>()?; aggregates::create_aggregate_expr( @@ -847,7 +1113,12 @@ impl DefaultPhysicalPlanner { let args = args .iter() .map(|e| { - self.create_physical_expr(e, physical_input_schema, ctx_state) + self.create_physical_expr( + e, + logical_input_schema, + physical_input_schema, + ctx_state, + ) }) .collect::>>()?; @@ -871,21 +1142,34 @@ impl DefaultPhysicalPlanner { // unpack aliased logical expressions, e.g. 
"sum(col) as total" let (name, e) = match e { Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), - _ => (e.name(logical_input_schema)?, e), + _ => (physical_name(e, logical_input_schema)?, e), }; - self.create_aggregate_expr_with_name(e, name, physical_input_schema, ctx_state) + + self.create_aggregate_expr_with_name( + e, + name, + logical_input_schema, + physical_input_schema, + ctx_state, + ) } /// Create a physical sort expression from a logical expression pub fn create_physical_sort_expr( &self, e: &Expr, + input_dfschema: &DFSchema, input_schema: &Schema, options: SortOptions, ctx_state: &ExecutionContextState, ) -> Result { Ok(PhysicalSortExpr { - expr: self.create_physical_expr(e, input_schema, ctx_state)?, + expr: self.create_physical_expr( + e, + input_dfschema, + input_schema, + ctx_state, + )?, options, }) } @@ -913,6 +1197,7 @@ mod tests { use arrow::datatypes::{DataType, Field, SchemaRef}; use async_trait::async_trait; use fmt::Debug; + use std::convert::TryFrom; use std::{any::Any, fmt}; fn make_ctx_state() -> ExecutionContextState { @@ -945,7 +1230,7 @@ mod tests { // verify that the plan correctly casts u8 to i64 // the cast here is implicit so has CastOptions with safe=true - let expected = "BinaryExpr { left: Column { name: \"c7\" }, op: Lt, right: TryCastExpr { expr: Literal { value: UInt8(5) }, cast_type: Int64 } }"; + let expected = "BinaryExpr { left: Column { name: \"c7\", index: 6 }, op: Lt, right: TryCastExpr { expr: Literal { value: UInt8(5) }, cast_type: Int64 } }"; assert!(format!("{:?}", plan).contains(expected)); Ok(()) @@ -954,12 +1239,17 @@ mod tests { #[test] fn test_create_not() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); + let dfschema = DFSchema::try_from(schema.clone())?; let planner = DefaultPhysicalPlanner::default(); - let expr = - planner.create_physical_expr(&col("a").not(), &schema, &make_ctx_state())?; - let expected = expressions::not(expressions::col("a"), &schema)?; + let expr = planner.create_physical_expr( + &col("a").not(), + &dfschema, + &schema, + &make_ctx_state(), + )?; + let expected = expressions::not(expressions::col("a", &schema)?, &schema)?; assert_eq!(format!("{:?}", expr), format!("{:?}", expected)); @@ -980,7 +1270,7 @@ mod tests { // c12 is f64, c7 is u8 -> cast c7 to f64 // the cast here is implicit so has CastOptions with safe=true - let expected = "predicate: BinaryExpr { left: TryCastExpr { expr: Column { name: \"c7\" }, cast_type: Float64 }, op: Lt, right: Column { name: \"c12\" } }"; + let expected = "predicate: BinaryExpr { left: TryCastExpr { expr: Column { name: \"c7\", index: 6 }, cast_type: Float64 }, op: Lt, right: Column { name: \"c12\", index: 11 } }"; assert!(format!("{:?}", plan).contains(expected)); Ok(()) } @@ -1105,8 +1395,7 @@ mod tests { .build()?; let execution_plan = plan(&logical_plan)?; // verify that the plan correctly adds cast from Int64(1) to Utf8 - let expected = "InListExpr { expr: Column { name: \"c1\" }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], negated: false }"; - println!("{:?}", execution_plan); + let expected = "InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], negated: false }"; assert!(format!("{:?}", execution_plan).contains(expected)); // expression: "a in (true, 'a')" 
diff --git a/datafusion/src/physical_plan/projection.rs b/datafusion/src/physical_plan/projection.rs index d4c0459c211b..5110e5b5a879 100644 --- a/datafusion/src/physical_plan/projection.rs +++ b/datafusion/src/physical_plan/projection.rs @@ -233,8 +233,10 @@ mod tests { )?; // pick column c1 and name it column c1 in the output schema - let projection = - ProjectionExec::try_new(vec![(col("c1"), "c1".to_string())], Arc::new(csv))?; + let projection = ProjectionExec::try_new( + vec![(col("c1", &schema)?, "c1".to_string())], + Arc::new(csv), + )?; let mut partition_count = 0; let mut row_count = 0; diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs index a7b17c4161b0..e67e4c2d4477 100644 --- a/datafusion/src/physical_plan/repartition.rs +++ b/datafusion/src/physical_plan/repartition.rs @@ -435,7 +435,7 @@ mod tests { use super::*; use crate::{ assert_batches_sorted_eq, - physical_plan::memory::MemoryExec, + physical_plan::{expressions::col, memory::MemoryExec}, test::exec::{BarrierExec, ErrorExec, MockExec}, }; use arrow::datatypes::{DataType, Field, Schema}; @@ -513,12 +513,7 @@ mod tests { let output_partitions = repartition( &schema, partitions, - Partitioning::Hash( - vec![Arc::new(crate::physical_plan::expressions::Column::new( - "c0", - ))], - 8, - ), + Partitioning::Hash(vec![col("c0", &schema)?], 8), ) .await?; @@ -761,6 +756,7 @@ mod tests { partitioning: Partitioning::Hash( vec![Arc::new(crate::physical_plan::expressions::Column::new( "my_awesome_field", + 0, ))], 2, ), diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs index 437519a7d2a2..365097822cc7 100644 --- a/datafusion/src/physical_plan/sort.rs +++ b/datafusion/src/physical_plan/sort.rs @@ -343,17 +343,17 @@ mod tests { vec![ // c1 string column PhysicalSortExpr { - expr: col("c1"), + expr: col("c1", &schema)?, options: SortOptions::default(), }, // c2 uin32 column PhysicalSortExpr { - expr: col("c2"), + expr: col("c2", &schema)?, options: SortOptions::default(), }, // c7 uin8 column PhysicalSortExpr { - expr: col("c7"), + expr: col("c7", &schema)?, options: SortOptions::default(), }, ], @@ -417,14 +417,14 @@ mod tests { let sort_exec = Arc::new(SortExec::try_new( vec![ PhysicalSortExpr { - expr: col("a"), + expr: col("a", &schema)?, options: SortOptions { descending: true, nulls_first: true, }, }, PhysicalSortExpr { - expr: col("b"), + expr: col("b", &schema)?, options: SortOptions { descending: false, nulls_first: false, diff --git a/datafusion/src/physical_plan/sort_preserving_merge.rs b/datafusion/src/physical_plan/sort_preserving_merge.rs index c39acc474d31..b8ca97cc5974 100644 --- a/datafusion/src/physical_plan/sort_preserving_merge.rs +++ b/datafusion/src/physical_plan/sort_preserving_merge.rs @@ -579,21 +579,18 @@ mod tests { let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); let schema = b1.schema(); + let sort = vec![ + PhysicalSortExpr { + expr: col("b", &schema).unwrap(), + options: Default::default(), + }, + PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: Default::default(), + }, + ]; let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); - let merge = Arc::new(SortPreservingMergeExec::new( - vec![ - PhysicalSortExpr { - expr: col("b"), - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("c"), - options: Default::default(), - }, - ], - Arc::new(exec), - 1024, - )); + let merge = Arc::new(SortPreservingMergeExec::new(sort, 
Arc::new(exec), 1024)); let collected = collect(merge).await.unwrap(); assert_eq!(collected.len(), 1); @@ -668,18 +665,18 @@ mod tests { let sort = vec![ PhysicalSortExpr { - expr: col("c1"), + expr: col("c1", &schema).unwrap(), options: SortOptions { descending: true, nulls_first: true, }, }, PhysicalSortExpr { - expr: col("c2"), + expr: col("c2", &schema).unwrap(), options: Default::default(), }, PhysicalSortExpr { - expr: col("c7"), + expr: col("c7", &schema).unwrap(), options: SortOptions::default(), }, ]; @@ -744,25 +741,26 @@ mod tests { #[tokio::test] async fn test_partition_sort_streaming_input() { + let schema = test::aggr_test_schema(); let sort = vec![ // uint8 PhysicalSortExpr { - expr: col("c7"), + expr: col("c7", &schema).unwrap(), options: Default::default(), }, // int16 PhysicalSortExpr { - expr: col("c4"), + expr: col("c4", &schema).unwrap(), options: Default::default(), }, // utf-8 PhysicalSortExpr { - expr: col("c1"), + expr: col("c1", &schema).unwrap(), options: SortOptions::default(), }, // utf-8 PhysicalSortExpr { - expr: col("c13"), + expr: col("c13", &schema).unwrap(), options: SortOptions::default(), }, ]; @@ -782,15 +780,17 @@ mod tests { #[tokio::test] async fn test_partition_sort_streaming_input_output() { + let schema = test::aggr_test_schema(); + let sort = vec![ // float64 PhysicalSortExpr { - expr: col("c12"), + expr: col("c12", &schema).unwrap(), options: Default::default(), }, // utf-8 PhysicalSortExpr { - expr: col("c13"), + expr: col("c13", &schema).unwrap(), options: Default::default(), }, ]; @@ -850,27 +850,24 @@ mod tests { let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); let schema = b1.schema(); - let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); - let merge = Arc::new(SortPreservingMergeExec::new( - vec![ - PhysicalSortExpr { - expr: col("b"), - options: SortOptions { - descending: false, - nulls_first: true, - }, + let sort = vec![ + PhysicalSortExpr { + expr: col("b", &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, }, - PhysicalSortExpr { - expr: col("c"), - options: SortOptions { - descending: false, - nulls_first: false, - }, + }, + PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, }, - ], - Arc::new(exec), - 1024, - )); + }, + ]; + let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); + let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 1024)); let collected = collect(merge).await.unwrap(); assert_eq!(collected.len(), 1); @@ -898,8 +895,9 @@ mod tests { #[tokio::test] async fn test_async() { + let schema = test::aggr_test_schema(); let sort = vec![PhysicalSortExpr { - expr: col("c7"), + expr: col("c7", &schema).unwrap(), options: SortOptions::default(), }]; diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs index fe87ecda872c..ffd8f20064f7 100644 --- a/datafusion/src/physical_plan/type_coercion.rs +++ b/datafusion/src/physical_plan/type_coercion.rs @@ -267,7 +267,9 @@ mod tests { let expressions = |t: Vec, schema| -> Result> { t.iter() .enumerate() - .map(|(i, t)| try_cast(col(&format!("c{}", i)), &schema, t.clone())) + .map(|(i, t)| { + try_cast(col(&format!("c{}", i), &schema)?, &schema, t.clone()) + }) .collect::>>() }; diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs index 466cc51b447d..a214ef17a9f8 100644 --- 
a/datafusion/src/physical_plan/windows.rs +++ b/datafusion/src/physical_plan/windows.rs @@ -369,7 +369,7 @@ impl WindowAggExec { input: Arc, input_schema: SchemaRef, ) -> Result { - let schema = create_schema(&input.schema(), &window_expr)?; + let schema = create_schema(&input_schema, &window_expr)?; let schema = Arc::new(schema); Ok(WindowAggExec { input, @@ -599,7 +599,7 @@ mod tests { vec![create_window_expr( &WindowFunction::AggregateFunction(AggregateFunction::Count), "count".to_owned(), - &[col("c3")], + &[col("c3", &schema)?], &[], &[], Some(WindowFrame::default()), @@ -632,7 +632,7 @@ mod tests { create_window_expr( &WindowFunction::AggregateFunction(AggregateFunction::Count), "count".to_owned(), - &[col("c3")], + &[col("c3", &schema)?], &[], &[], Some(WindowFrame::default()), @@ -641,7 +641,7 @@ mod tests { create_window_expr( &WindowFunction::AggregateFunction(AggregateFunction::Max), "max".to_owned(), - &[col("c3")], + &[col("c3", &schema)?], &[], &[], Some(WindowFrame::default()), @@ -650,7 +650,7 @@ mod tests { create_window_expr( &WindowFunction::AggregateFunction(AggregateFunction::Min), "min".to_owned(), - &[col("c3")], + &[col("c3", &schema)?], &[], &[], Some(WindowFrame::default()), diff --git a/datafusion/src/prelude.rs b/datafusion/src/prelude.rs index e1f1d7b76047..e7ad04e74d1a 100644 --- a/datafusion/src/prelude.rs +++ b/datafusion/src/prelude.rs @@ -32,6 +32,6 @@ pub use crate::logical_plan::{ count, create_udf, in_list, initcap, left, length, lit, lower, lpad, ltrim, max, md5, min, now, octet_length, random, regexp_replace, repeat, replace, reverse, right, rpad, rtrim, sha224, sha256, sha384, sha512, split_part, starts_with, strpos, substr, - sum, to_hex, translate, trim, upper, JoinType, Partitioning, + sum, to_hex, translate, trim, upper, Column, JoinType, Partitioning, }; pub use crate::physical_plan::csv::CsvReadOptions; diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 547e9afd38d9..7912241329a3 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -17,13 +17,18 @@ //! SQL Query Planner (produces logical plan from SQL AST) +use std::collections::HashSet; +use std::str::FromStr; +use std::sync::Arc; +use std::{convert::TryInto, vec}; + use crate::catalog::TableReference; use crate::datasource::TableProvider; use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits}; use crate::logical_plan::Expr::Alias; use crate::logical_plan::{ - and, lit, DFSchema, Expr, LogicalPlan, LogicalPlanBuilder, Operator, PlanType, - StringifiedPlan, ToDFSchema, + and, lit, union_with_alias, Column, DFSchema, Expr, LogicalPlan, LogicalPlanBuilder, + Operator, PlanType, StringifiedPlan, ToDFSchema, }; use crate::prelude::JoinType; use crate::scalar::ScalarValue; @@ -47,9 +52,6 @@ use sqlparser::ast::{ use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{OrderByExpr, Statement}; use sqlparser::parser::ParserError::ParserError; -use std::str::FromStr; -use std::sync::Arc; -use std::{convert::TryInto, vec}; use super::{ parser::DFParser, @@ -163,29 +165,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { (SetOperator::Union, true) => { let left_plan = self.set_expr_to_plan(left.as_ref(), None, ctes)?; let right_plan = self.set_expr_to_plan(right.as_ref(), None, ctes)?; - let inputs = vec![left_plan, right_plan] - .into_iter() - .flat_map(|p| match p { - LogicalPlan::Union { inputs, .. 
} => inputs,
                        x => vec![x],
                    })
                    .collect::<Vec<_>>();
-               if inputs.is_empty() {
-                   return Err(DataFusionError::Plan(format!(
-                       "Empty UNION: {}",
-                       set_expr
-                   )));
-               }
-               if !inputs.iter().all(|s| s.schema() == inputs[0].schema()) {
-                   return Err(DataFusionError::Plan(
-                       "UNION ALL schemas are expected to be the same".to_string(),
-                   ));
-               }
-               Ok(LogicalPlan::Union {
-                   schema: inputs[0].schema().clone(),
-                   inputs,
-                   alias,
-               })
+               union_with_alias(left_plan, right_plan, alias)
            }
            _ => Err(DataFusionError::NotImplemented(format!(
                "Only UNION ALL is supported, found {}",
@@ -382,7 +362,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
    ) -> Result<LogicalPlan> {
        match constraint {
            JoinConstraint::On(sql_expr) => {
-               let mut keys: Vec<(String, String)> = vec![];
+               let mut keys: Vec<(Column, Column)> = vec![];
                let join_schema = left.schema().join(right.schema())?;

                // parse ON expression
@@ -390,20 +370,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                // extract join keys
                extract_join_keys(&expr, &mut keys)?;

-               let left_keys: Vec<&str> =
-                   keys.iter().map(|pair| pair.0.as_str()).collect();
-               let right_keys: Vec<&str> =
-                   keys.iter().map(|pair| pair.1.as_str()).collect();
+               let (left_keys, right_keys): (Vec<Column>, Vec<Column>) =
+                   keys.into_iter().unzip();

                // return the logical plan representing the join
                LogicalPlanBuilder::from(left)
-                   .join(right, join_type, &left_keys, &right_keys)?
+                   .join(right, join_type, left_keys, right_keys)?
                    .build()
            }
            JoinConstraint::Using(idents) => {
-               let keys: Vec<&str> = idents.iter().map(|x| x.value.as_str()).collect();
+               let keys: Vec<Column> = idents
+                   .iter()
+                   .map(|x| Column::from_name(x.value.clone()))
+                   .collect();
                LogicalPlanBuilder::from(left)
-                   .join(right, join_type, &keys, &keys)?
+                   .join_using(right, join_type, keys)?
                    .build()
            }
            JoinConstraint::Natural => {
@@ -489,37 +470,38 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
            let mut possible_join_keys = vec![];
            extract_possible_join_keys(&filter_expr, &mut possible_join_keys)?;

-           let mut all_join_keys = vec![];
+           let mut all_join_keys = HashSet::new();
            let mut left = plans[0].clone();
            for right in plans.iter().skip(1) {
                let left_schema = left.schema();
                let right_schema = right.schema();
                let mut join_keys = vec![];
                for (l, r) in &possible_join_keys {
-                   if left_schema.field_with_unqualified_name(l).is_ok()
-                       && right_schema.field_with_unqualified_name(r).is_ok()
+                   if left_schema.field_from_qualified_column(l).is_ok()
+                       && right_schema.field_from_qualified_column(r).is_ok()
                    {
-                       join_keys.push((l.as_str(), r.as_str()));
+                       join_keys.push((l.clone(), r.clone()));
-                   } else if left_schema.field_with_unqualified_name(r).is_ok()
-                       && right_schema.field_with_unqualified_name(l).is_ok()
+                   } else if left_schema.field_from_qualified_column(r).is_ok()
+                       && right_schema.field_from_qualified_column(l).is_ok()
                    {
-                       join_keys.push((r.as_str(), l.as_str()));
+                       join_keys.push((r.clone(), l.clone()));
                    }
                }
                if join_keys.is_empty() {
                    left =
                        LogicalPlanBuilder::from(&left).cross_join(right)?.build()?;
                } else {
-                   let left_keys: Vec<_> =
-                       join_keys.iter().map(|(l, _)| *l).collect();
-                   let right_keys: Vec<_> =
-                       join_keys.iter().map(|(_, r)| *r).collect();
+                   let left_keys: Vec<Column> =
+                       join_keys.iter().map(|(l, _)| l.clone()).collect();
+                   let right_keys: Vec<Column> =
+                       join_keys.iter().map(|(_, r)| r.clone()).collect();
                    let builder = LogicalPlanBuilder::from(&left);
                    left = builder
-                       .join(right, JoinType::Inner, &left_keys, &right_keys)?
+                       .join(right, JoinType::Inner, left_keys, right_keys)?
.build()?; } - all_join_keys.extend_from_slice(&join_keys); + + all_join_keys.extend(join_keys); } // remove join expressions from filter @@ -548,12 +530,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // The SELECT expressions, with wildcards expanded. let select_exprs = self.prepare_select_exprs(&plan, &select.projection)?; + // having and group by clause may reference aliases defined in select projection + let projected_plan = self.project(&plan, select_exprs.clone())?; + let mut combined_schema = (**projected_plan.schema()).clone(); + combined_schema.merge(plan.schema()); + // Optionally the HAVING expression. let having_expr_opt = select .having .as_ref() .map::, _>(|having_expr| { - let having_expr = self.sql_expr_to_logical_expr(having_expr)?; + let having_expr = + self.sql_expr_to_logical_expr(having_expr, &combined_schema)?; // This step "dereferences" any aliases in the HAVING clause. // @@ -582,7 +570,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // The outer expressions we will search through for // aggregates. Aggregates may be sourced from the SELECT... let mut aggr_expr_haystack = select_exprs.clone(); - // ... or from the HAVING. if let Some(having_expr) = &having_expr_opt { aggr_expr_haystack.push(having_expr.clone()); @@ -596,7 +583,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .group_by .iter() .map(|e| { - let group_by_expr = self.sql_expr_to_logical_expr(e)?; + let group_by_expr = self.sql_expr_to_logical_expr(e, &combined_schema)?; let group_by_expr = resolve_aliases_to_exprs(&group_by_expr, &alias_map)?; let group_by_expr = resolve_positions_to_exprs(&group_by_expr, &select_exprs)?; @@ -816,16 +803,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let order_by_rex = order_by .iter() - .map(|e| self.order_by_to_sort_expr(e)) + .map(|e| self.order_by_to_sort_expr(e, plan.schema())) .collect::>>()?; LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() } /// convert sql OrderByExpr to Expr::Sort - fn order_by_to_sort_expr(&self, e: &OrderByExpr) -> Result { + fn order_by_to_sort_expr(&self, e: &OrderByExpr, schema: &DFSchema) -> Result { Ok(Expr::Sort { - expr: Box::new(self.sql_expr_to_logical_expr(&e.expr)?), + expr: Box::new(self.sql_expr_to_logical_expr(&e.expr, schema)?), // by default asc asc: e.asc.unwrap_or(true), // by default nulls first to be consistent with spark @@ -842,11 +829,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { find_column_exprs(exprs) .iter() .try_for_each(|col| match col { - Expr::Column(name) => { - schema.field_with_unqualified_name(name).map_err(|_| { + Expr::Column(col) => { + match &col.relation { + Some(r) => schema.field_with_qualified_name(r, &col.name), + None => schema.field_with_unqualified_name(&col.name), + } + .map_err(|_| { DataFusionError::Plan(format!( "Invalid identifier '{}' for schema {}", - name, + col, schema.to_string() )) })?; @@ -873,19 +864,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Generate a relational expression from a SQL expression pub fn sql_to_rex(&self, sql: &SQLExpr, schema: &DFSchema) -> Result { - let expr = self.sql_expr_to_logical_expr(sql)?; + let expr = self.sql_expr_to_logical_expr(sql, schema)?; self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?; Ok(expr) } - fn sql_fn_arg_to_logical_expr(&self, sql: &FunctionArg) -> Result { + fn sql_fn_arg_to_logical_expr( + &self, + sql: &FunctionArg, + schema: &DFSchema, + ) -> Result { match sql { - FunctionArg::Named { name: _, arg } => self.sql_expr_to_logical_expr(arg), - FunctionArg::Unnamed(value) => 
self.sql_expr_to_logical_expr(value), + FunctionArg::Named { name: _, arg } => { + self.sql_expr_to_logical_expr(arg, schema) + } + FunctionArg::Unnamed(value) => self.sql_expr_to_logical_expr(value, schema), } } - fn sql_expr_to_logical_expr(&self, sql: &SQLExpr) -> Result { + fn sql_expr_to_logical_expr(&self, sql: &SQLExpr, schema: &DFSchema) -> Result { match sql { SQLExpr::Value(Value::Number(n, _)) => match n.parse::() { Ok(n) => Ok(lit(n)), @@ -900,7 +897,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun: functions::BuiltinScalarFunction::DatePart, args: vec![ Expr::Literal(ScalarValue::Utf8(Some(format!("{}", field)))), - self.sql_expr_to_logical_expr(expr)?, + self.sql_expr_to_logical_expr(expr, schema)?, ], }), @@ -923,7 +920,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let var_names = vec![id.value.clone()]; Ok(Expr::ScalarVariable(var_names)) } else { - Ok(Expr::Column(id.value.to_string())) + Ok(Expr::Column( + schema + .field_with_unqualified_name(&id.value)? + .qualified_column(), + )) } } @@ -934,6 +935,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } if &var_names[0][0..1] == "@" { Ok(Expr::ScalarVariable(var_names)) + } else if var_names.len() == 2 { + // table.column identifier + let name = var_names.pop().unwrap(); + let relation = Some(var_names.pop().unwrap()); + Ok(Expr::Column(Column { relation, name })) } else { Err(DataFusionError::NotImplemented(format!( "Unsupported compound identifier '{:?}'", @@ -951,20 +957,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { else_result, } => { let expr = if let Some(e) = operand { - Some(Box::new(self.sql_expr_to_logical_expr(e)?)) + Some(Box::new(self.sql_expr_to_logical_expr(e, schema)?)) } else { None }; let when_expr = conditions .iter() - .map(|e| self.sql_expr_to_logical_expr(e)) + .map(|e| self.sql_expr_to_logical_expr(e, schema)) .collect::>>()?; let then_expr = results .iter() - .map(|e| self.sql_expr_to_logical_expr(e)) + .map(|e| self.sql_expr_to_logical_expr(e, schema)) .collect::>>()?; let else_expr = if let Some(e) = else_result { - Some(Box::new(self.sql_expr_to_logical_expr(e)?)) + Some(Box::new(self.sql_expr_to_logical_expr(e, schema)?)) } else { None }; @@ -984,7 +990,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ref expr, ref data_type, } => Ok(Expr::Cast { - expr: Box::new(self.sql_expr_to_logical_expr(expr)?), + expr: Box::new(self.sql_expr_to_logical_expr(expr, schema)?), data_type: convert_data_type(data_type)?, }), @@ -992,7 +998,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ref expr, ref data_type, } => Ok(Expr::TryCast { - expr: Box::new(self.sql_expr_to_logical_expr(expr)?), + expr: Box::new(self.sql_expr_to_logical_expr(expr, schema)?), data_type: convert_data_type(data_type)?, }), @@ -1004,19 +1010,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { data_type: convert_data_type(data_type)?, }), - SQLExpr::IsNull(ref expr) => { - Ok(Expr::IsNull(Box::new(self.sql_expr_to_logical_expr(expr)?))) - } + SQLExpr::IsNull(ref expr) => Ok(Expr::IsNull(Box::new( + self.sql_expr_to_logical_expr(expr, schema)?, + ))), SQLExpr::IsNotNull(ref expr) => Ok(Expr::IsNotNull(Box::new( - self.sql_expr_to_logical_expr(expr)?, + self.sql_expr_to_logical_expr(expr, schema)?, ))), SQLExpr::UnaryOp { ref op, ref expr } => match op { - UnaryOperator::Not => { - Ok(Expr::Not(Box::new(self.sql_expr_to_logical_expr(expr)?))) - } - UnaryOperator::Plus => Ok(self.sql_expr_to_logical_expr(expr)?), + UnaryOperator::Not => Ok(Expr::Not(Box::new( + self.sql_expr_to_logical_expr(expr, schema)?, + 
))), + UnaryOperator::Plus => Ok(self.sql_expr_to_logical_expr(expr, schema)?), UnaryOperator::Minus => { match expr.as_ref() { // optimization: if it's a number literal, we apply the negative operator @@ -1032,7 +1038,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { })?)), }, // not a literal, apply negative operator on expression - _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(expr)?))), + _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(expr, schema)?))), } } _ => Err(DataFusionError::NotImplemented(format!( @@ -1047,10 +1053,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ref low, ref high, } => Ok(Expr::Between { - expr: Box::new(self.sql_expr_to_logical_expr(expr)?), + expr: Box::new(self.sql_expr_to_logical_expr(expr, schema)?), negated: *negated, - low: Box::new(self.sql_expr_to_logical_expr(low)?), - high: Box::new(self.sql_expr_to_logical_expr(high)?), + low: Box::new(self.sql_expr_to_logical_expr(low, schema)?), + high: Box::new(self.sql_expr_to_logical_expr(high, schema)?), }), SQLExpr::InList { @@ -1060,11 +1066,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } => { let list_expr = list .iter() - .map(|e| self.sql_expr_to_logical_expr(e)) + .map(|e| self.sql_expr_to_logical_expr(e, schema)) .collect::>>()?; Ok(Expr::InList { - expr: Box::new(self.sql_expr_to_logical_expr(expr)?), + expr: Box::new(self.sql_expr_to_logical_expr(expr, schema)?), list: list_expr, negated: *negated, }) @@ -1098,9 +1104,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }?; Ok(Expr::BinaryExpr { - left: Box::new(self.sql_expr_to_logical_expr(left)?), + left: Box::new(self.sql_expr_to_logical_expr(left, schema)?), op: operator, - right: Box::new(self.sql_expr_to_logical_expr(right)?), + right: Box::new(self.sql_expr_to_logical_expr(right, schema)?), }) } @@ -1121,7 +1127,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // first, scalar built-in if let Ok(fun) = functions::BuiltinScalarFunction::from_str(&name) { - let args = self.function_args_to_expr(function)?; + let args = self.function_args_to_expr(function, schema)?; return Ok(Expr::ScalarFunction { fun, args }); }; @@ -1131,12 +1137,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let partition_by = window .partition_by .iter() - .map(|e| self.sql_expr_to_logical_expr(e)) + .map(|e| self.sql_expr_to_logical_expr(e, schema)) .collect::>>()?; let order_by = window .order_by .iter() - .map(|e| self.order_by_to_sort_expr(e)) + .map(|e| self.order_by_to_sort_expr(e, schema)) .collect::>>()?; let window_frame = window .window_frame @@ -1163,8 +1169,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun: window_functions::WindowFunction::AggregateFunction( aggregate_fun.clone(), ), - args: self - .aggregate_fn_to_expr(&aggregate_fun, function)?, + args: self.aggregate_fn_to_expr( + &aggregate_fun, + function, + schema, + )?, partition_by, order_by, window_frame, @@ -1177,7 +1186,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun: window_functions::WindowFunction::BuiltInWindowFunction( window_fun, ), - args: self.function_args_to_expr(function)?, + args:self.function_args_to_expr(function, schema)?, partition_by, order_by, window_frame, @@ -1188,7 +1197,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // next, aggregate built-ins if let Ok(fun) = aggregates::AggregateFunction::from_str(&name) { - let args = self.aggregate_fn_to_expr(&fun, function)?; + let args = self.aggregate_fn_to_expr(&fun, function, schema)?; return Ok(Expr::AggregateFunction { fun, distinct: function.distinct, @@ 
-1199,13 +1208,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // finally, user-defined functions (UDF) and UDAF match self.schema_provider.get_function_meta(&name) { Some(fm) => { - let args = self.function_args_to_expr(function)?; + let args = self.function_args_to_expr(function, schema)?; Ok(Expr::ScalarUDF { fun: fm, args }) } None => match self.schema_provider.get_aggregate_meta(&name) { Some(fm) => { - let args = self.function_args_to_expr(function)?; + let args = self.function_args_to_expr(function, schema)?; Ok(Expr::AggregateUDF { fun: fm, args }) } _ => Err(DataFusionError::Plan(format!( @@ -1216,7 +1225,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(e), + SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(e, schema), _ => Err(DataFusionError::NotImplemented(format!( "Unsupported ast node {:?} in sqltorel", @@ -1228,11 +1237,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn function_args_to_expr( &self, function: &sqlparser::ast::Function, + schema: &DFSchema, ) -> Result> { function .args .iter() - .map(|a| self.sql_fn_arg_to_logical_expr(a)) + .map(|a| self.sql_fn_arg_to_logical_expr(a, schema)) .collect::>>() } @@ -1240,6 +1250,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, fun: &aggregates::AggregateFunction, function: &sqlparser::ast::Function, + schema: &DFSchema, ) -> Result> { if *fun == aggregates::AggregateFunction::Count { function @@ -1250,11 +1261,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(lit(1_u8)) } FunctionArg::Unnamed(SQLExpr::Wildcard) => Ok(lit(1_u8)), - _ => self.sql_fn_arg_to_logical_expr(a), + _ => self.sql_fn_arg_to_logical_expr(a, schema), }) .collect::>>() } else { - self.function_args_to_expr(function) + self.function_args_to_expr(function, schema) } } @@ -1519,13 +1530,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Remove join expressions from a filter expression fn remove_join_expressions( expr: &Expr, - join_columns: &[(&str, &str)], + join_columns: &HashSet<(Column, Column)>, ) -> Result> { match expr { Expr::BinaryExpr { left, op, right } => match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { (Expr::Column(l), Expr::Column(r)) => { - if join_columns.contains(&(l, r)) || join_columns.contains(&(r, l)) { + if join_columns.contains(&(l.clone(), r.clone())) + || join_columns.contains(&(r.clone(), l.clone())) + { Ok(None) } else { Ok(Some(expr.clone())) @@ -1556,12 +1569,12 @@ fn remove_join_expressions( /// foo = bar /// foo = bar AND bar = baz AND ... 
/// -fn extract_join_keys(expr: &Expr, accum: &mut Vec<(String, String)>) -> Result<()> { +fn extract_join_keys(expr: &Expr, accum: &mut Vec<(Column, Column)>) -> Result<()> { match expr { Expr::BinaryExpr { left, op, right } => match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { (Expr::Column(l), Expr::Column(r)) => { - accum.push((l.to_owned(), r.to_owned())); + accum.push((l.clone(), r.clone())); Ok(()) } other => Err(DataFusionError::SQL(ParserError(format!( @@ -1588,13 +1601,13 @@ fn extract_join_keys(expr: &Expr, accum: &mut Vec<(String, String)>) -> Result<( /// Extract join keys from a WHERE clause fn extract_possible_join_keys( expr: &Expr, - accum: &mut Vec<(String, String)>, + accum: &mut Vec<(Column, Column)>, ) -> Result<()> { match expr { Expr::BinaryExpr { left, op, right } => match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { (Expr::Column(l), Expr::Column(r)) => { - accum.push((l.to_owned(), r.to_owned())); + accum.push((l.clone(), r.clone())); Ok(()) } _ => Ok(()), @@ -1635,9 +1648,6 @@ mod tests { use crate::{logical_plan::create_udf, sql::parser::DFParser}; use functions::ScalarFunctionImplementation; - const PERSON_COLUMN_NAMES: &str = - "id, first_name, last_name, age, state, salary, birth_date, 😀"; - #[test] fn select_no_relation() { quick_test( @@ -1651,13 +1661,10 @@ mod tests { fn select_column_does_not_exist() { let sql = "SELECT doesnotexist FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - format!( - r#"Plan("Invalid identifier 'doesnotexist' for schema {}")"#, - PERSON_COLUMN_NAMES - ), - format!("{:?}", err) - ); + assert!(matches!( + err, + DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'", + )); } #[test] @@ -1665,7 +1672,7 @@ mod tests { let sql = "SELECT age, age FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - r##"Plan("Projections require unique expression names but the expression \"#age\" at position 0 and \"#age\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##, + r##"Plan("Projections require unique expression names but the expression \"#person.age\" at position 0 and \"#person.age\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##, format!("{:?}", err) ); } @@ -1675,7 +1682,7 @@ mod tests { let sql = "SELECT *, age FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - r##"Plan("Projections require unique expression names but the expression \"#age\" at position 3 and \"#age\" at position 8 have the same name. Consider aliasing (\"AS\") one of them.")"##, + r##"Plan("Projections require unique expression names but the expression \"#person.age\" at position 3 and \"#person.age\" at position 8 have the same name. 
Consider aliasing (\"AS\") one of them.")"##, format!("{:?}", err) ); } @@ -1684,7 +1691,7 @@ mod tests { fn select_wildcard_with_repeated_column_but_is_aliased() { quick_test( "SELECT *, first_name AS fn from person", - "Projection: #id, #first_name, #last_name, #age, #state, #salary, #birth_date, #😀, #first_name AS fn\ + "Projection: #person.id, #person.first_name, #person.last_name, #person.age, #person.state, #person.salary, #person.birth_date, #person.😀, #person.first_name AS fn\ \n TableScan: person projection=None", ); } @@ -1702,8 +1709,8 @@ mod tests { fn select_simple_filter() { let sql = "SELECT id, first_name, last_name \ FROM person WHERE state = 'CO'"; - let expected = "Projection: #id, #first_name, #last_name\ - \n Filter: #state Eq Utf8(\"CO\")\ + let expected = "Projection: #person.id, #person.first_name, #person.last_name\ + \n Filter: #person.state Eq Utf8(\"CO\")\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1712,34 +1719,28 @@ mod tests { fn select_filter_column_does_not_exist() { let sql = "SELECT first_name FROM person WHERE doesnotexist = 'A'"; let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - format!( - r#"Plan("Invalid identifier 'doesnotexist' for schema {}")"#, - PERSON_COLUMN_NAMES - ), - format!("{:?}", err) - ); + assert!(matches!( + err, + DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'", + )); } #[test] fn select_filter_cannot_use_alias() { let sql = "SELECT first_name AS x FROM person WHERE x = 'A'"; let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - format!( - r#"Plan("Invalid identifier 'x' for schema {}")"#, - PERSON_COLUMN_NAMES - ), - format!("{:?}", err) - ); + assert!(matches!( + err, + DataFusionError::Plan(msg) if msg == "No field with unqualified name 'x'", + )); } #[test] fn select_neg_filter() { let sql = "SELECT id, first_name, last_name \ FROM person WHERE NOT state"; - let expected = "Projection: #id, #first_name, #last_name\ - \n Filter: NOT #state\ + let expected = "Projection: #person.id, #person.first_name, #person.last_name\ + \n Filter: NOT #person.state\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1748,8 +1749,8 @@ mod tests { fn select_compound_filter() { let sql = "SELECT id, first_name, last_name \ FROM person WHERE state = 'CO' AND age >= 21 AND age <= 65"; - let expected = "Projection: #id, #first_name, #last_name\ - \n Filter: #state Eq Utf8(\"CO\") And #age GtEq Int64(21) And #age LtEq Int64(65)\ + let expected = "Projection: #person.id, #person.first_name, #person.last_name\ + \n Filter: #person.state Eq Utf8(\"CO\") And #person.age GtEq Int64(21) And #person.age LtEq Int64(65)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1759,8 +1760,8 @@ mod tests { let sql = "SELECT state FROM person WHERE birth_date < CAST (158412331400600000 as timestamp)"; - let expected = "Projection: #state\ - \n Filter: #birth_date Lt CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\ + let expected = "Projection: #person.state\ + \n Filter: #person.birth_date Lt CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -1771,8 +1772,8 @@ mod tests { let sql = "SELECT state FROM person WHERE birth_date < CAST ('2020-01-01' as date)"; - let expected = "Projection: #state\ - \n Filter: #birth_date Lt CAST(Utf8(\"2020-01-01\") AS Date32)\ + let expected = "Projection: #person.state\ + 
\n Filter: #person.birth_date Lt CAST(Utf8(\"2020-01-01\") AS Date32)\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -1788,13 +1789,13 @@ mod tests { AND age >= 21 \ AND age < 65 \ AND age <= 65"; - let expected = "Projection: #age, #first_name, #last_name\ - \n Filter: #age Eq Int64(21) \ - And #age NotEq Int64(21) \ - And #age Gt Int64(21) \ - And #age GtEq Int64(21) \ - And #age Lt Int64(65) \ - And #age LtEq Int64(65)\ + let expected = "Projection: #person.age, #person.first_name, #person.last_name\ + \n Filter: #person.age Eq Int64(21) \ + And #person.age NotEq Int64(21) \ + And #person.age Gt Int64(21) \ + And #person.age GtEq Int64(21) \ + And #person.age Lt Int64(65) \ + And #person.age LtEq Int64(65)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1802,8 +1803,8 @@ mod tests { #[test] fn select_between() { let sql = "SELECT state FROM person WHERE age BETWEEN 21 AND 65"; - let expected = "Projection: #state\ - \n Filter: #age BETWEEN Int64(21) AND Int64(65)\ + let expected = "Projection: #person.state\ + \n Filter: #person.age BETWEEN Int64(21) AND Int64(65)\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -1812,8 +1813,8 @@ mod tests { #[test] fn select_between_negated() { let sql = "SELECT state FROM person WHERE age NOT BETWEEN 21 AND 65"; - let expected = "Projection: #state\ - \n Filter: #age NOT BETWEEN Int64(21) AND Int64(65)\ + let expected = "Projection: #person.state\ + \n Filter: #person.age NOT BETWEEN Int64(21) AND Int64(65)\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -1829,9 +1830,9 @@ mod tests { FROM person ) )"; - let expected = "Projection: #fn2, #last_name\ - \n Projection: #fn1 AS fn2, #last_name, #birth_date\ - \n Projection: #first_name AS fn1, #last_name, #birth_date, #age\ + let expected = "Projection: #fn2, #person.last_name\ + \n Projection: #fn1 AS fn2, #person.last_name, #person.birth_date\ + \n Projection: #person.first_name AS fn1, #person.last_name, #person.birth_date, #person.age\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1846,10 +1847,10 @@ mod tests { ) WHERE fn1 = 'X' AND age < 30"; - let expected = "Projection: #fn1, #age\ - \n Filter: #fn1 Eq Utf8(\"X\") And #age Lt Int64(30)\ - \n Projection: #first_name AS fn1, #age\ - \n Filter: #age Gt Int64(20)\ + let expected = "Projection: #fn1, #person.age\ + \n Filter: #fn1 Eq Utf8(\"X\") And #person.age Lt Int64(30)\ + \n Projection: #person.first_name AS fn1, #person.age\ + \n Filter: #person.age Gt Int64(20)\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -1860,8 +1861,8 @@ mod tests { let sql = "SELECT id, age FROM person HAVING age > 100 AND age < 200"; - let expected = "Projection: #id, #age\ - \n Filter: #age Gt Int64(100) And #age Lt Int64(200)\ + let expected = "Projection: #person.id, #person.age\ + \n Filter: #person.age Gt Int64(100) And #person.age Lt Int64(200)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1907,9 +1908,9 @@ mod tests { let sql = "SELECT MAX(age) FROM person HAVING MAX(age) < 30"; - let expected = "Projection: #MAX(age)\ - \n Filter: #MAX(age) Lt Int64(30)\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(#age)]]\ + let expected = "Projection: #MAX(person.age)\ + \n Filter: #MAX(person.age) Lt Int64(30)\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(#person.age)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1919,9 +1920,9 @@ mod tests { let sql = "SELECT MAX(age) FROM person HAVING 
MAX(first_name) > 'M'"; - let expected = "Projection: #MAX(age)\ - \n Filter: #MAX(first_name) Gt Utf8(\"M\")\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(#age), MAX(#first_name)]]\ + let expected = "Projection: #MAX(person.age)\ + \n Filter: #MAX(person.first_name) Gt Utf8(\"M\")\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(#person.age), MAX(#person.first_name)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1943,9 +1944,10 @@ mod tests { let sql = "SELECT MAX(age) as max_age FROM person HAVING max_age < 30"; - let expected = "Projection: #MAX(age) AS max_age\ - \n Filter: #MAX(age) Lt Int64(30)\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(#age)]]\ + // FIXME: add test for having in execution + let expected = "Projection: #MAX(person.age) AS max_age\ + \n Filter: #MAX(person.age) Lt Int64(30)\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(#person.age)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1955,9 +1957,9 @@ mod tests { let sql = "SELECT MAX(age) as max_age FROM person HAVING MAX(age) < 30"; - let expected = "Projection: #MAX(age) AS max_age\ - \n Filter: #MAX(age) Lt Int64(30)\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(#age)]]\ + let expected = "Projection: #MAX(person.age) AS max_age\ + \n Filter: #MAX(person.age) Lt Int64(30)\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(#person.age)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1968,9 +1970,9 @@ mod tests { FROM person GROUP BY first_name HAVING first_name = 'M'"; - let expected = "Projection: #first_name, #MAX(age)\ - \n Filter: #first_name Eq Utf8(\"M\")\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ + let expected = "Projection: #person.first_name, #MAX(person.age)\ + \n Filter: #person.first_name Eq Utf8(\"M\")\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1982,10 +1984,10 @@ mod tests { WHERE id > 5 GROUP BY first_name HAVING MAX(age) < 100"; - let expected = "Projection: #first_name, #MAX(age)\ - \n Filter: #MAX(age) Lt Int64(100)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ - \n Filter: #id Gt Int64(5)\ + let expected = "Projection: #person.first_name, #MAX(person.age)\ + \n Filter: #MAX(person.age) Lt Int64(100)\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\ + \n Filter: #person.id Gt Int64(5)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -1998,10 +2000,10 @@ mod tests { WHERE id > 5 AND age > 18 GROUP BY first_name HAVING MAX(age) < 100"; - let expected = "Projection: #first_name, #MAX(age)\ - \n Filter: #MAX(age) Lt Int64(100)\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ - \n Filter: #id Gt Int64(5) And #age Gt Int64(18)\ + let expected = "Projection: #person.first_name, #MAX(person.age)\ + \n Filter: #MAX(person.age) Lt Int64(100)\ + \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\ + \n Filter: #person.id Gt Int64(5) And #person.age Gt Int64(18)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -2012,9 +2014,9 @@ mod tests { FROM person GROUP BY first_name HAVING MAX(age) > 2 AND fn = 'M'"; - let expected = "Projection: #first_name AS fn, #MAX(age)\ - \n Filter: #MAX(age) Gt Int64(2) And #first_name Eq Utf8(\"M\")\ - \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\ + let expected = "Projection: #person.first_name AS fn, #MAX(person.age)\ + \n Filter: #MAX(person.age) Gt Int64(2) And 
#person.first_name Eq Utf8(\"M\")\
+        \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\
         \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

@@ -2026,9 +2028,9 @@ mod tests {
                    FROM person
                    GROUP BY first_name
                    HAVING MAX(age) > 2 AND max_age < 5 AND first_name = 'M' AND fn = 'N'";
-        let expected = "Projection: #first_name AS fn, #MAX(age) AS max_age\
-        \n Filter: #MAX(age) Gt Int64(2) And #MAX(age) Lt Int64(5) And #first_name Eq Utf8(\"M\") And #first_name Eq Utf8(\"N\")\
-        \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\
+        let expected = "Projection: #person.first_name AS fn, #MAX(person.age) AS max_age\
+        \n Filter: #MAX(person.age) Gt Int64(2) And #MAX(person.age) Lt Int64(5) And #person.first_name Eq Utf8(\"M\") And #person.first_name Eq Utf8(\"N\")\
+        \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\
         \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

@@ -2039,9 +2041,9 @@ mod tests {
                    FROM person
                    GROUP BY first_name
                    HAVING MAX(age) > 100";
-        let expected = "Projection: #first_name, #MAX(age)\
-        \n Filter: #MAX(age) Gt Int64(100)\
-        \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\
+        let expected = "Projection: #person.first_name, #MAX(person.age)\
+        \n Filter: #MAX(person.age) Gt Int64(100)\
+        \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\
         \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

@@ -2065,9 +2067,9 @@ mod tests {
                    FROM person
                    GROUP BY first_name
                    HAVING MAX(age) > 100 AND MAX(age) < 200";
-        let expected = "Projection: #first_name, #MAX(age)\
-        \n Filter: #MAX(age) Gt Int64(100) And #MAX(age) Lt Int64(200)\
-        \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\
+        let expected = "Projection: #person.first_name, #MAX(person.age)\
+        \n Filter: #MAX(person.age) Gt Int64(100) And #MAX(person.age) Lt Int64(200)\
+        \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\
         \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

@@ -2078,9 +2080,9 @@ mod tests {
                    FROM person
                    GROUP BY first_name
                    HAVING MAX(age) > 100 AND MIN(id) < 50";
-        let expected = "Projection: #first_name, #MAX(age)\
-        \n Filter: #MAX(age) Gt Int64(100) And #MIN(id) Lt Int64(50)\
-        \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age), MIN(#id)]]\
+        let expected = "Projection: #person.first_name, #MAX(person.age)\
+        \n Filter: #MAX(person.age) Gt Int64(100) And #MIN(person.id) Lt Int64(50)\
+        \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age), MIN(#person.id)]]\
         \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

@@ -2092,9 +2094,9 @@ mod tests {
                    FROM person
                    GROUP BY first_name
                    HAVING max_age > 100";
-        let expected = "Projection: #first_name, #MAX(age) AS max_age\
-        \n Filter: #MAX(age) Gt Int64(100)\
-        \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\
+        let expected = "Projection: #person.first_name, #MAX(person.age) AS max_age\
+        \n Filter: #MAX(person.age) Gt Int64(100)\
+        \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\
         \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

@@ -2107,9 +2109,9 @@ mod tests {
                    GROUP BY first_name
                    HAVING max_age_plus_one > 100";
         let expected =
-            "Projection: #first_name, #MAX(age) Plus Int64(1) AS max_age_plus_one\
-            \n Filter: #MAX(age) Plus Int64(1) Gt Int64(100)\
-            \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\
+            "Projection: #person.first_name, #MAX(person.age) Plus Int64(1) AS max_age_plus_one\
+            \n Filter: #MAX(person.age) Plus Int64(1) Gt Int64(100)\
+            \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\
             \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

@@ -2121,9 +2123,9 @@ mod tests {
                    FROM person
                    GROUP BY first_name
                    HAVING MAX(age) > 100 AND MIN(id - 2) < 50";
-        let expected = "Projection: #first_name, #MAX(age)\
-        \n Filter: #MAX(age) Gt Int64(100) And #MIN(id Minus Int64(2)) Lt Int64(50)\
-        \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age), MIN(#id Minus Int64(2))]]\
+        let expected = "Projection: #person.first_name, #MAX(person.age)\
+        \n Filter: #MAX(person.age) Gt Int64(100) And #MIN(person.id Minus Int64(2)) Lt Int64(50)\
+        \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age), MIN(#person.id Minus Int64(2))]]\
         \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

@@ -2134,9 +2136,9 @@ mod tests {
                    FROM person
                    GROUP BY first_name
                    HAVING MAX(age) > 100 AND COUNT(*) < 50";
-        let expected = "Projection: #first_name, #MAX(age)\
-        \n Filter: #MAX(age) Gt Int64(100) And #COUNT(UInt8(1)) Lt Int64(50)\
-        \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age), COUNT(UInt8(1))]]\
+        let expected = "Projection: #person.first_name, #MAX(person.age)\
+        \n Filter: #MAX(person.age) Gt Int64(100) And #COUNT(UInt8(1)) Lt Int64(50)\
+        \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age), COUNT(UInt8(1))]]\
         \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

@@ -2144,7 +2146,7 @@ mod tests {
     #[test]
     fn select_binary_expr() {
         let sql = "SELECT age + salary from person";
-        let expected = "Projection: #age Plus #salary\
+        let expected = "Projection: #person.age Plus #person.salary\
         \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

@@ -2152,7 +2154,7 @@ mod tests {
     #[test]
    fn select_binary_expr_nested() {
         let sql = "SELECT (age + salary)/2 from person";
-        let expected = "Projection: #age Plus #salary Divide Int64(2)\
+        let expected = "Projection: #person.age Plus #person.salary Divide Int64(2)\
         \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

@@ -2161,15 +2163,15 @@ mod tests {
     fn select_wildcard_with_groupby() {
         quick_test(
             r#"SELECT * FROM person GROUP BY id, first_name, last_name, age, state, salary, birth_date, "😀""#,
-            "Projection: #id, #first_name, #last_name, #age, #state, #salary, #birth_date, #😀\
-            \n Aggregate: groupBy=[[#id, #first_name, #last_name, #age, #state, #salary, #birth_date, #😀]], aggr=[[]]\
+            "Projection: #person.id, #person.first_name, #person.last_name, #person.age, #person.state, #person.salary, #person.birth_date, #person.😀\
+            \n Aggregate: groupBy=[[#person.id, #person.first_name, #person.last_name, #person.age, #person.state, #person.salary, #person.birth_date, #person.😀]], aggr=[[]]\
             \n TableScan: person projection=None",
         );
         quick_test(
             "SELECT * FROM (SELECT first_name, last_name FROM person) GROUP BY first_name, last_name",
-            "Projection: #first_name, #last_name\
-            \n Aggregate: groupBy=[[#first_name, #last_name]], aggr=[[]]\
-            \n Projection: #first_name, #last_name\
+            "Projection: #person.first_name, #person.last_name\
+            \n Aggregate: groupBy=[[#person.first_name, #person.last_name]], aggr=[[]]\
+            \n Projection: #person.first_name, #person.last_name\
             \n TableScan: person projection=None",
         );
     }

@@ -2178,8 +2180,8 @@ mod tests {
     fn select_simple_aggregate() {
         quick_test(
             "SELECT MIN(age) FROM person",
-            "Projection: #MIN(age)\
-            \n Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\
+            "Projection: #MIN(person.age)\
+            \n Aggregate: groupBy=[[]], aggr=[[MIN(#person.age)]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2188,8 +2190,8 @@ mod tests {
     fn test_sum_aggregate() {
         quick_test(
             "SELECT SUM(age) from person",
-            "Projection: #SUM(age)\
-            \n Aggregate: groupBy=[[]], aggr=[[SUM(#age)]]\
+            "Projection: #SUM(person.age)\
+            \n Aggregate: groupBy=[[]], aggr=[[SUM(#person.age)]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2198,13 +2200,10 @@ mod tests {
     fn select_simple_aggregate_column_does_not_exist() {
         let sql = "SELECT MIN(doesnotexist) FROM person";
         let err = logical_plan(sql).expect_err("query should have failed");
-        assert_eq!(
-            format!(
-                r#"Plan("Invalid identifier 'doesnotexist' for schema {}")"#,
-                PERSON_COLUMN_NAMES
-            ),
-            format!("{:?}", err)
-        );
+        assert!(matches!(
+            err,
+            DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'",
+        ));
     }
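Aside: besides the qualified names, the failure-path tests switch from comparing the full Debug string to matches! with a pattern guard, which pins down the error variant and message without depending on Debug formatting. The pattern in isolation, as a self-contained sketch with a stand-in error enum (not the crate's actual type):

#[derive(Debug)]
enum DataFusionError {
    Plan(String),
}

fn main() {
    // Stand-in for the planner error produced by the unqualified lookup above.
    let err = DataFusionError::Plan(
        "No field with unqualified name 'doesnotexist'".to_string(),
    );
    // Assert the variant and the exact message in one pattern.
    assert!(matches!(
        err,
        DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'"
    ));
}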

     #[test]
@@ -2212,7 +2211,7 @@ mod tests {
         let sql = "SELECT MIN(age), MIN(age) FROM person";
         let err = logical_plan(sql).expect_err("query should have failed");
         assert_eq!(
-            r##"Plan("Projections require unique expression names but the expression \"#MIN(age)\" at position 0 and \"#MIN(age)\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##,
+            r##"Plan("Projections require unique expression names but the expression \"MIN(#person.age)\" at position 0 and \"MIN(#person.age)\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##,
             format!("{:?}", err)
         );
     }

@@ -2221,8 +2220,8 @@ mod tests {
     fn select_simple_aggregate_repeated_aggregate_with_single_alias() {
         quick_test(
             "SELECT MIN(age), MIN(age) AS a FROM person",
-            "Projection: #MIN(age), #MIN(age) AS a\
-            \n Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\
+            "Projection: #MIN(person.age), #MIN(person.age) AS a\
+            \n Aggregate: groupBy=[[]], aggr=[[MIN(#person.age)]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2231,8 +2230,8 @@ mod tests {
     fn select_simple_aggregate_repeated_aggregate_with_unique_aliases() {
         quick_test(
             "SELECT MIN(age) AS a, MIN(age) AS b FROM person",
-            "Projection: #MIN(age) AS a, #MIN(age) AS b\
-            \n Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\
+            "Projection: #MIN(person.age) AS a, #MIN(person.age) AS b\
+            \n Aggregate: groupBy=[[]], aggr=[[MIN(#person.age)]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2242,7 +2241,7 @@ mod tests {
         let sql = "SELECT MIN(age) AS a, MIN(age) AS a FROM person";
         let err = logical_plan(sql).expect_err("query should have failed");
         assert_eq!(
-            r##"Plan("Projections require unique expression names but the expression \"#MIN(age) AS a\" at position 0 and \"#MIN(age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##,
+            r##"Plan("Projections require unique expression names but the expression \"MIN(#person.age) AS a\" at position 0 and \"MIN(#person.age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##,
             format!("{:?}", err)
         );
     }

@@ -2251,8 +2250,8 @@ mod tests {
     fn select_simple_aggregate_with_groupby() {
         quick_test(
             "SELECT state, MIN(age), MAX(age) FROM person GROUP BY state",
-            "Projection: #state, #MIN(age), #MAX(age)\
-            \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age), MAX(#age)]]\
+            "Projection: #person.state, #MIN(person.age), #MAX(person.age)\
+            \n Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age), MAX(#person.age)]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2261,8 +2260,8 @@ mod tests {
     fn select_simple_aggregate_with_groupby_with_aliases() {
         quick_test(
             "SELECT state AS a, MIN(age) AS b FROM person GROUP BY state",
-            "Projection: #state AS a, #MIN(age) AS b\
-            \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age)]]\
+            "Projection: #person.state AS a, #MIN(person.age) AS b\
+            \n Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age)]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2272,7 +2271,7 @@ mod tests {
         let sql = "SELECT state AS a, MIN(age) AS a FROM person GROUP BY state";
         let err = logical_plan(sql).expect_err("query should have failed");
         assert_eq!(
-            r##"Plan("Projections require unique expression names but the expression \"#state AS a\" at position 0 and \"#MIN(age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##,
+            r##"Plan("Projections require unique expression names but the expression \"#person.state AS a\" at position 0 and \"MIN(#person.age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##,
             format!("{:?}", err)
         );
     }

@@ -2281,8 +2280,8 @@ mod tests {
     fn select_simple_aggregate_with_groupby_column_unselected() {
         quick_test(
             "SELECT MIN(age), MAX(age) FROM person GROUP BY state",
-            "Projection: #MIN(age), #MAX(age)\
-            \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age), MAX(#age)]]\
+            "Projection: #MIN(person.age), #MAX(person.age)\
+            \n Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age), MAX(#person.age)]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2291,26 +2290,20 @@ mod tests {
     fn select_simple_aggregate_with_groupby_and_column_in_group_by_does_not_exist() {
         let sql = "SELECT SUM(age) FROM person GROUP BY doesnotexist";
         let err = logical_plan(sql).expect_err("query should have failed");
-        assert_eq!(
-            format!(
-                r#"Plan("Invalid identifier 'doesnotexist' for schema {}")"#,
-                PERSON_COLUMN_NAMES
-            ),
-            format!("{:?}", err)
-        );
+        assert!(matches!(
+            err,
+            DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'",
+        ));
     }

     #[test]
     fn select_simple_aggregate_with_groupby_and_column_in_aggregate_does_not_exist() {
         let sql = "SELECT SUM(doesnotexist) FROM person GROUP BY first_name";
         let err = logical_plan(sql).expect_err("query should have failed");
-        assert_eq!(
-            format!(
-                r#"Plan("Invalid identifier 'doesnotexist' for schema {}")"#,
-                PERSON_COLUMN_NAMES
-            ),
-            format!("{:?}", err)
-        );
+        assert!(matches!(
+            err,
+            DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'",
+        ));
     }

     #[test]
@@ -2327,18 +2320,18 @@ mod tests {
     fn select_unsupported_complex_interval() {
         let sql = "SELECT INTERVAL '1 year 1 day'";
         let err = logical_plan(sql).expect_err("query should have failed");
-        assert_eq!(
-            r#"NotImplemented("DF does not support intervals that have both a Year/Month part as well as Days/Hours/Mins/Seconds: \"1 year 1 day\". Hint: try breaking the interval into two parts, one with Year/Month and the other with Days/Hours/Mins/Seconds - e.g. (NOW() + INTERVAL '1 year') + INTERVAL '1 day'")"#,
-            format!("{:?}", err)
-        );
+        assert!(matches!(
+            err,
+            DataFusionError::NotImplemented(msg) if msg == "DF does not support intervals that have both a Year/Month part as well as Days/Hours/Mins/Seconds: \"1 year 1 day\". Hint: try breaking the interval into two parts, one with Year/Month and the other with Days/Hours/Mins/Seconds - e.g. (NOW() + INTERVAL '1 year') + INTERVAL '1 day'",
        ));
     }

     #[test]
     fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() {
         quick_test(
             "SELECT MAX(first_name) FROM person GROUP BY first_name",
-            "Projection: #MAX(first_name)\
-            \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#first_name)]]\
+            "Projection: #MAX(person.first_name)\
+            \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.first_name)]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2347,14 +2340,14 @@ mod tests {
     fn select_simple_aggregate_with_groupby_can_use_positions() {
         quick_test(
             "SELECT state, age AS b, COUNT(1) FROM person GROUP BY 1, 2",
-            "Projection: #state, #age AS b, #COUNT(UInt8(1))\
-            \n Aggregate: groupBy=[[#state, #age]], aggr=[[COUNT(UInt8(1))]]\
+            "Projection: #person.state, #person.age AS b, #COUNT(UInt8(1))\
+            \n Aggregate: groupBy=[[#person.state, #person.age]], aggr=[[COUNT(UInt8(1))]]\
             \n TableScan: person projection=None",
         );
         quick_test(
             "SELECT state, age AS b, COUNT(1) FROM person GROUP BY 2, 1",
-            "Projection: #state, #age AS b, #COUNT(UInt8(1))\
-            \n Aggregate: groupBy=[[#age, #state]], aggr=[[COUNT(UInt8(1))]]\
+            "Projection: #person.state, #person.age AS b, #COUNT(UInt8(1))\
+            \n Aggregate: groupBy=[[#person.age, #person.state]], aggr=[[COUNT(UInt8(1))]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2380,8 +2373,8 @@ mod tests {
     fn select_simple_aggregate_with_groupby_can_use_alias() {
         quick_test(
             "SELECT state AS a, MIN(age) AS b FROM person GROUP BY a",
-            "Projection: #state AS a, #MIN(age) AS b\
-            \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age)]]\
+            "Projection: #person.state AS a, #MIN(person.age) AS b\
+            \n Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age)]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2391,7 +2384,7 @@ mod tests {
         let sql = "SELECT state, MIN(age), MIN(age) FROM person GROUP BY state";
         let err = logical_plan(sql).expect_err("query should have failed");
         assert_eq!(
-            r##"Plan("Projections require unique expression names but the expression \"#MIN(age)\" at position 1 and \"#MIN(age)\" at position 2 have the same name. Consider aliasing (\"AS\") one of them.")"##,
+            r##"Plan("Projections require unique expression names but the expression \"MIN(#person.age)\" at position 1 and \"MIN(#person.age)\" at position 2 have the same name. Consider aliasing (\"AS\") one of them.")"##,
             format!("{:?}", err)
         );
     }

@@ -2400,8 +2393,8 @@ mod tests {
     fn select_simple_aggregate_with_groupby_aggregate_repeated_and_one_has_alias() {
         quick_test(
             "SELECT state, MIN(age), MIN(age) AS ma FROM person GROUP BY state",
-            "Projection: #state, #MIN(age), #MIN(age) AS ma\
-            \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age)]]\
+            "Projection: #person.state, #MIN(person.age), #MIN(person.age) AS ma\
+            \n Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age)]]\
             \n TableScan: person projection=None",
         )
     }

@@ -2409,8 +2402,8 @@ mod tests {
     fn select_simple_aggregate_with_groupby_non_column_expression_unselected() {
         quick_test(
             "SELECT MIN(first_name) FROM person GROUP BY age + 1",
-            "Projection: #MIN(first_name)\
-            \n Aggregate: groupBy=[[#age Plus Int64(1)]], aggr=[[MIN(#first_name)]]\
+            "Projection: #MIN(person.first_name)\
+            \n Aggregate: groupBy=[[#person.age Plus Int64(1)]], aggr=[[MIN(#person.first_name)]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2420,14 +2413,14 @@ mod tests {
     ) {
         quick_test(
             "SELECT age + 1, MIN(first_name) FROM person GROUP BY age + 1",
-            "Projection: #age Plus Int64(1), #MIN(first_name)\
-            \n Aggregate: groupBy=[[#age Plus Int64(1)]], aggr=[[MIN(#first_name)]]\
+            "Projection: #person.age Plus Int64(1), #MIN(person.first_name)\
+            \n Aggregate: groupBy=[[#person.age Plus Int64(1)]], aggr=[[MIN(#person.first_name)]]\
             \n TableScan: person projection=None",
         );
         quick_test(
             "SELECT MIN(first_name), age + 1 FROM person GROUP BY age + 1",
-            "Projection: #MIN(first_name), #age Plus Int64(1)\
-            \n Aggregate: groupBy=[[#age Plus Int64(1)]], aggr=[[MIN(#first_name)]]\
+            "Projection: #MIN(person.first_name), #person.age Plus Int64(1)\
+            \n Aggregate: groupBy=[[#person.age Plus Int64(1)]], aggr=[[MIN(#person.first_name)]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2437,8 +2430,8 @@ mod tests {
     {
         quick_test(
             "SELECT ((age + 1) / 2) * (age + 1), MIN(first_name) FROM person GROUP BY age + 1",
-            "Projection: #age Plus Int64(1) Divide Int64(2) Multiply #age Plus Int64(1), #MIN(first_name)\
-            \n Aggregate: groupBy=[[#age Plus Int64(1)]], aggr=[[MIN(#first_name)]]\
+            "Projection: #person.age Plus Int64(1) Divide Int64(2) Multiply #person.age Plus Int64(1), #MIN(person.first_name)\
+            \n Aggregate: groupBy=[[#person.age Plus Int64(1)]], aggr=[[MIN(#person.first_name)]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2471,8 +2464,8 @@ mod tests {
     fn select_simple_aggregate_nested_in_binary_expr_with_groupby() {
         quick_test(
             "SELECT state, MIN(age) < 10 FROM person GROUP BY state",
-            "Projection: #state, #MIN(age) Lt Int64(10)\
-            \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age)]]\
+            "Projection: #person.state, #MIN(person.age) Lt Int64(10)\
+            \n Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age)]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2481,8 +2474,8 @@ mod tests {
     fn select_simple_aggregate_and_nested_groupby_column() {
         quick_test(
             "SELECT age + 1, MAX(first_name) FROM person GROUP BY age",
-            "Projection: #age Plus Int64(1), #MAX(first_name)\
-            \n Aggregate: groupBy=[[#age]], aggr=[[MAX(#first_name)]]\
+            "Projection: #person.age Plus Int64(1), #MAX(person.first_name)\
+            \n Aggregate: groupBy=[[#person.age]], aggr=[[MAX(#person.first_name)]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2491,8 +2484,8 @@ mod tests {
     fn select_aggregate_compounded_with_groupby_column() {
         quick_test(
             "SELECT age + MIN(salary) FROM person GROUP BY age",
-            "Projection: #age Plus #MIN(salary)\
-            \n Aggregate: groupBy=[[#age]], aggr=[[MIN(#salary)]]\
+            "Projection: #person.age Plus #MIN(person.salary)\
+            \n Aggregate: groupBy=[[#person.age]], aggr=[[MIN(#person.salary)]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2501,8 +2494,8 @@ mod tests {
     fn select_aggregate_with_non_column_inner_expression_with_groupby() {
         quick_test(
             "SELECT state, MIN(age + 1) FROM person GROUP BY state",
-            "Projection: #state, #MIN(age Plus Int64(1))\
-            \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age Plus Int64(1))]]\
+            "Projection: #person.state, #MIN(person.age Plus Int64(1))\
+            \n Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age Plus Int64(1))]]\
             \n TableScan: person projection=None",
         );
     }

@@ -2511,7 +2504,7 @@ mod tests {
     fn test_wildcard() {
         quick_test(
             "SELECT * from person",
-            "Projection: #id, #first_name, #last_name, #age, #state, #salary, #birth_date, #😀\
+            "Projection: #person.id, #person.first_name, #person.last_name, #person.age, #person.state, #person.salary, #person.birth_date, #person.😀\
             \n TableScan: person projection=None",
         );
     }

@@ -2528,8 +2521,8 @@ mod tests {
     #[test]
     fn select_count_column() {
         let sql = "SELECT COUNT(id) FROM person";
-        let expected = "Projection: #COUNT(id)\
-        \n Aggregate: groupBy=[[]], aggr=[[COUNT(#id)]]\
+        let expected = "Projection: #COUNT(person.id)\
+        \n Aggregate: groupBy=[[]], aggr=[[COUNT(#person.id)]]\
         \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

@@ -2537,15 +2530,15 @@ mod tests {
     #[test]
     fn select_scalar_func() {
         let sql = "SELECT sqrt(age) FROM person";
-        let expected = "Projection: sqrt(#age)\
+        let expected = "Projection: sqrt(#person.age)\
         \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

     #[test]
     fn select_aliased_scalar_func() {
-        let sql = "SELECT sqrt(age) AS square_people FROM person";
-        let expected = "Projection: sqrt(#age) AS square_people\
+        let sql = "SELECT sqrt(person.age) AS square_people FROM person";
+        let expected = "Projection: sqrt(#person.age) AS square_people\
         \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

@@ -2554,8 +2547,8 @@ mod tests {
     fn select_where_nullif_division() {
         let sql = "SELECT c3/(c4+c5) \
                    FROM aggregate_test_100 WHERE c3/nullif(c4+c5, 0) > 0.1";
-        let expected = "Projection: #c3 Divide #c4 Plus #c5\
-        \n Filter: #c3 Divide nullif(#c4 Plus #c5, Int64(0)) Gt Float64(0.1)\
+        let expected = "Projection: #aggregate_test_100.c3 Divide #aggregate_test_100.c4 Plus #aggregate_test_100.c5\
+        \n Filter: #aggregate_test_100.c3 Divide nullif(#aggregate_test_100.c4 Plus #aggregate_test_100.c5, Int64(0)) Gt Float64(0.1)\
         \n TableScan: aggregate_test_100 projection=None";
         quick_test(sql, expected);
     }

@@ -2563,8 +2556,8 @@ mod tests {
     #[test]
     fn select_where_with_negative_operator() {
         let sql = "SELECT c3 FROM aggregate_test_100 WHERE c3 > -0.1 AND -c4 > 0";
-        let expected = "Projection: #c3\
-        \n Filter: #c3 Gt Float64(-0.1) And (- #c4) Gt Int64(0)\
+        let expected = "Projection: #aggregate_test_100.c3\
+        \n Filter: #aggregate_test_100.c3 Gt Float64(-0.1) And (- #aggregate_test_100.c4) Gt Int64(0)\
         \n TableScan: aggregate_test_100 projection=None";
         quick_test(sql, expected);
     }

@@ -2572,8 +2565,8 @@ mod tests {
     #[test]
     fn select_where_with_positive_operator() {
         let sql = "SELECT c3 FROM aggregate_test_100 WHERE c3 > +0.1 AND +c4 > 0";
-        let expected = "Projection: #c3\
-        \n Filter: #c3 Gt Float64(0.1) And #c4 Gt Int64(0)\
+        let expected = "Projection: #aggregate_test_100.c3\
+        \n Filter: #aggregate_test_100.c3 Gt Float64(0.1) And #aggregate_test_100.c4 Gt Int64(0)\
         \n TableScan: aggregate_test_100 projection=None";
         quick_test(sql, expected);
     }

@@ -2581,8 +2574,8 @@ mod tests {
     #[test]
     fn select_order_by() {
         let sql = "SELECT id FROM person ORDER BY id";
-        let expected = "Sort: #id ASC NULLS FIRST\
-        \n Projection: #id\
+        let expected = "Sort: #person.id ASC NULLS FIRST\
+        \n Projection: #person.id\
         \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

@@ -2590,8 +2583,8 @@ mod tests {
     #[test]
     fn select_order_by_desc() {
         let sql = "SELECT id FROM person ORDER BY id DESC";
-        let expected = "Sort: #id DESC NULLS FIRST\
-        \n Projection: #id\
+        let expected = "Sort: #person.id DESC NULLS FIRST\
+        \n Projection: #person.id\
         \n TableScan: person projection=None";
         quick_test(sql, expected);
     }

@@ -2600,15 +2593,15 @@ mod tests {
     fn select_order_by_nulls_last() {
         quick_test(
             "SELECT id FROM person ORDER BY id DESC NULLS LAST",
-            "Sort: #id DESC NULLS LAST\
-            \n Projection: #id\
+            "Sort: #person.id DESC NULLS LAST\
+            \n Projection: #person.id\
             \n TableScan: person projection=None",
         );

         quick_test(
             "SELECT id FROM person ORDER BY id NULLS LAST",
-            "Sort: #id ASC NULLS LAST\
-            \n Projection: #id\
+            "Sort: #person.id ASC NULLS LAST\
+            \n Projection: #person.id\
             \n TableScan: person projection=None",
         );
     }

@@ -2616,8 +2609,8 @@ mod tests {
     #[test]
     fn select_group_by() {
         let sql = "SELECT state FROM person GROUP BY state";
-        let expected = "Projection: #state\
-        \n Aggregate: groupBy=[[#state]], aggr=[[]]\
+        let expected = "Projection: #person.state\
+        \n Aggregate: groupBy=[[#person.state]], aggr=[[]]\
         \n TableScan: person projection=None";

         quick_test(sql, expected);
@@ -2626,8 +2619,8 @@ mod tests {
     #[test]
     fn select_group_by_columns_not_in_select() {
         let sql = "SELECT MAX(age) FROM person GROUP BY state";
-        let expected = "Projection: #MAX(age)\
-        \n Aggregate: groupBy=[[#state]], aggr=[[MAX(#age)]]\
+        let expected = "Projection: #MAX(person.age)\
+        \n Aggregate: groupBy=[[#person.state]], aggr=[[MAX(#person.age)]]\
         \n TableScan: person projection=None";

         quick_test(sql, expected);
@@ -2636,8 +2629,8 @@ mod tests {
     #[test]
     fn select_group_by_count_star() {
         let sql = "SELECT state, COUNT(*) FROM person GROUP BY state";
-        let expected = "Projection: #state, #COUNT(UInt8(1))\
-        \n Aggregate: groupBy=[[#state]], aggr=[[COUNT(UInt8(1))]]\
+        let expected = "Projection: #person.state, #COUNT(UInt8(1))\
+        \n Aggregate: groupBy=[[#person.state]], aggr=[[COUNT(UInt8(1))]]\
         \n TableScan: person projection=None";

         quick_test(sql, expected);
@@ -2647,8 +2640,8 @@ mod tests {
     fn select_group_by_needs_projection() {
         let sql = "SELECT COUNT(state), state FROM person GROUP BY state";
         let expected = "\
-        Projection: #COUNT(state), #state\
-        \n Aggregate: groupBy=[[#state]], aggr=[[COUNT(#state)]]\
+        Projection: #COUNT(person.state), #person.state\
+        \n Aggregate: groupBy=[[#person.state]], aggr=[[COUNT(#person.state)]]\
         \n TableScan: person projection=None";

         quick_test(sql, expected);
@@ -2657,8 +2650,8 @@ mod tests {
     #[test]
     fn select_7480_1() {
         let sql = "SELECT c1, MIN(c12) FROM aggregate_test_100 GROUP BY c1, c13";
-        let expected = "Projection: #c1, #MIN(c12)\
-        \n Aggregate: groupBy=[[#c1, #c13]], aggr=[[MIN(#c12)]]\
+        let expected = "Projection: #aggregate_test_100.c1, #MIN(aggregate_test_100.c12)\
+        \n Aggregate: groupBy=[[#aggregate_test_100.c1, #aggregate_test_100.c13]], aggr=[[MIN(#aggregate_test_100.c12)]]\
         \n TableScan: aggregate_test_100 projection=None";
         quick_test(sql, expected);
     }
customer_id"; - let expected = "Projection: #id, #order_id\ - \n Join: id = customer_id\ + let expected = "Projection: #person.id, #orders.order_id\ + \n Join: #person.id = #orders.customer_id\ \n TableScan: person projection=None\ \n TableScan: orders projection=None"; quick_test(sql, expected); } + #[test] + fn join_with_table_name() { + let sql = "SELECT id, order_id \ + FROM person \ + JOIN orders \ + ON person.id = orders.customer_id"; + let expected = "Projection: #person.id, #orders.order_id\ + \n Join: #person.id = #orders.customer_id\ + \n TableScan: person projection=None\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + #[test] + fn join_with_using() { + let sql = "SELECT person.first_name, id \ + FROM person \ + JOIN person as person2 \ + USING (id)"; + let expected = "Projection: #person.first_name, #person.id\ + \n Join: #person.id = #person2.id\ + \n TableScan: person projection=None\ + \n TableScan: person2 projection=None"; + quick_test(sql, expected); + } + #[test] fn equijoin_explicit_syntax_3_tables() { let sql = "SELECT id, order_id, l_description \ FROM person \ JOIN orders ON id = customer_id \ JOIN lineitem ON o_item_id = l_item_id"; - let expected = "Projection: #id, #order_id, #l_description\ - \n Join: o_item_id = l_item_id\ - \n Join: id = customer_id\ + let expected = + "Projection: #person.id, #orders.order_id, #lineitem.l_description\ + \n Join: #orders.o_item_id = #lineitem.l_item_id\ + \n Join: #person.id = #orders.customer_id\ \n TableScan: person projection=None\ \n TableScan: orders projection=None\ \n TableScan: lineitem projection=None"; @@ -2741,8 +2761,8 @@ mod tests { let sql = "SELECT order_id \ FROM orders \ WHERE delivered = false OR delivered = true"; - let expected = "Projection: #order_id\ - \n Filter: #delivered Eq Boolean(false) Or #delivered Eq Boolean(true)\ + let expected = "Projection: #orders.order_id\ + \n Filter: #orders.delivered Eq Boolean(false) Or #orders.delivered Eq Boolean(true)\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2751,9 +2771,9 @@ mod tests { fn union() { let sql = "SELECT order_id from orders UNION ALL SELECT order_id FROM orders"; let expected = "Union\ - \n Projection: #order_id\ + \n Projection: #orders.order_id\ \n TableScan: orders projection=None\ - \n Projection: #order_id\ + \n Projection: #orders.order_id\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2765,13 +2785,13 @@ mod tests { UNION ALL SELECT order_id FROM orders UNION ALL SELECT order_id FROM orders"; let expected = "Union\ - \n Projection: #order_id\ + \n Projection: #orders.order_id\ \n TableScan: orders projection=None\ - \n Projection: #order_id\ + \n Projection: #orders.order_id\ \n TableScan: orders projection=None\ - \n Projection: #order_id\ + \n Projection: #orders.order_id\ \n TableScan: orders projection=None\ - \n Projection: #order_id\ + \n Projection: #orders.order_id\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2790,8 +2810,8 @@ mod tests { fn empty_over() { let sql = "SELECT order_id, MAX(order_id) OVER () from orders"; let expected = "\ - Projection: #order_id, #MAX(order_id)\ - \n WindowAggr: windowExpr=[[MAX(#order_id)]]\ + Projection: #orders.order_id, #MAX(orders.order_id)\ + \n WindowAggr: windowExpr=[[MAX(#orders.order_id)]]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2800,8 +2820,8 @@ mod tests { fn empty_over_with_alias() { let sql = "SELECT order_id oid, MAX(order_id) OVER () max_oid 
from orders"; let expected = "\ - Projection: #order_id AS oid, #MAX(order_id) AS max_oid\ - \n WindowAggr: windowExpr=[[MAX(#order_id)]]\ + Projection: #orders.order_id AS oid, #MAX(orders.order_id) AS max_oid\ + \n WindowAggr: windowExpr=[[MAX(#orders.order_id)]]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2810,8 +2830,8 @@ mod tests { fn empty_over_plus() { let sql = "SELECT order_id, MAX(qty * 1.1) OVER () from orders"; let expected = "\ - Projection: #order_id, #MAX(qty Multiply Float64(1.1))\ - \n WindowAggr: windowExpr=[[MAX(#qty Multiply Float64(1.1))]]\ + Projection: #orders.order_id, #MAX(orders.qty Multiply Float64(1.1))\ + \n WindowAggr: windowExpr=[[MAX(#orders.qty Multiply Float64(1.1))]]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2821,8 +2841,8 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (), min(qty) over (), aVg(qty) OVER () from orders"; let expected = "\ - Projection: #order_id, #MAX(qty), #MIN(qty), #AVG(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty), MIN(#qty), AVG(#qty)]]\ + Projection: #orders.order_id, #MAX(orders.qty), #MIN(orders.qty), #AVG(orders.qty)\ + \n WindowAggr: windowExpr=[[MAX(#orders.qty), MIN(#orders.qty), AVG(#orders.qty)]]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2840,9 +2860,9 @@ mod tests { fn over_partition_by() { let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id) from orders"; let expected = "\ - Projection: #order_id, #MAX(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty)]]\ - \n Sort: #order_id ASC NULLS FIRST\ + Projection: #orders.order_id, #MAX(orders.qty)\ + \n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\ + \n Sort: #orders.order_id ASC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2863,11 +2883,11 @@ mod tests { fn over_order_by() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ - Projection: #order_id, #MAX(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty)]]\ - \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]]\ - \n Sort: #order_id DESC NULLS FIRST\ + Projection: #orders.order_id, #MAX(orders.qty), #MIN(orders.qty)\ + \n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\ + \n Sort: #orders.order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\ + \n Sort: #orders.order_id DESC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2876,11 +2896,11 @@ mod tests { fn over_order_by_with_window_frame_double_end() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id ROWS BETWEEN 3 PRECEDING and 3 FOLLOWING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ - Projection: #order_id, #MAX(qty) ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING, #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty) ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING]]\ - \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]]\ - \n Sort: #order_id DESC NULLS FIRST\ + Projection: #orders.order_id, #MAX(orders.qty) ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING, #MIN(orders.qty)\ + \n WindowAggr: windowExpr=[[MAX(#orders.qty) ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING]]\ + \n Sort: #orders.order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\ + \n Sort: #orders.order_id DESC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2889,11 +2909,11 @@ mod tests { fn 
over_order_by_with_window_frame_single_end() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id ROWS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ - Projection: #order_id, #MAX(qty) ROWS BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty) ROWS BETWEEN 3 PRECEDING AND CURRENT ROW]]\ - \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]]\ - \n Sort: #order_id DESC NULLS FIRST\ + Projection: #orders.order_id, #MAX(orders.qty) ROWS BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(orders.qty)\ + \n WindowAggr: windowExpr=[[MAX(#orders.qty) ROWS BETWEEN 3 PRECEDING AND CURRENT ROW]]\ + \n Sort: #orders.order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\ + \n Sort: #orders.order_id DESC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2934,11 +2954,11 @@ mod tests { fn over_order_by_with_window_frame_single_end_groups() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ - Projection: #order_id, #MAX(qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]]\ - \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]]\ - \n Sort: #order_id DESC NULLS FIRST\ + Projection: #orders.order_id, #MAX(orders.qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(orders.qty)\ + \n WindowAggr: windowExpr=[[MAX(#orders.qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]]\ + \n Sort: #orders.order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\ + \n Sort: #orders.order_id DESC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2959,11 +2979,11 @@ mod tests { fn over_order_by_two_sort_keys() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY (order_id + 1)) from orders"; let expected = "\ - Projection: #order_id, #MAX(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty)]]\ - \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]]\ - \n Sort: #order_id Plus Int64(1) ASC NULLS FIRST\ + Projection: #orders.order_id, #MAX(orders.qty), #MIN(orders.qty)\ + \n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\ + \n Sort: #orders.order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\ + \n Sort: #orders.order_id Plus Int64(1) ASC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2985,12 +3005,12 @@ mod tests { fn over_order_by_sort_keys_sorting() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY qty, order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders"; let expected = "\ - Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[SUM(#qty)]]\ - \n WindowAggr: windowExpr=[[MAX(#qty)]]\ - \n Sort: #qty ASC NULLS FIRST, #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]]\ - \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + Projection: #orders.order_id, #MAX(orders.qty), #SUM(orders.qty), #MIN(orders.qty)\ + \n WindowAggr: windowExpr=[[SUM(#orders.qty)]]\ + \n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\ + \n Sort: #orders.qty ASC NULLS FIRST, #orders.order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\ + \n Sort: #orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST\ \n 
TableScan: orders projection=None"; quick_test(sql, expected); } @@ -3012,12 +3032,12 @@ mod tests { fn over_order_by_sort_keys_sorting_prefix_compacting() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders"; let expected = "\ - Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[SUM(#qty)]]\ - \n WindowAggr: windowExpr=[[MAX(#qty)]]\ - \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]]\ - \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + Projection: #orders.order_id, #MAX(orders.qty), #SUM(orders.qty), #MIN(orders.qty)\ + \n WindowAggr: windowExpr=[[SUM(#orders.qty)]]\ + \n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\ + \n Sort: #orders.order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\ + \n Sort: #orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -3042,13 +3062,13 @@ mod tests { fn over_order_by_sort_keys_sorting_global_order_compacting() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY qty, order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders ORDER BY order_id"; let expected = "\ - Sort: #order_id ASC NULLS FIRST\ - \n Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[SUM(#qty)]]\ - \n WindowAggr: windowExpr=[[MAX(#qty)]]\ - \n Sort: #qty ASC NULLS FIRST, #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]]\ - \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + Sort: #orders.order_id ASC NULLS FIRST\ + \n Projection: #orders.order_id, #MAX(orders.qty), #SUM(orders.qty), #MIN(orders.qty)\ + \n WindowAggr: windowExpr=[[SUM(#orders.qty)]]\ + \n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\ + \n Sort: #orders.qty ASC NULLS FIRST, #orders.order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\ + \n Sort: #orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -3067,9 +3087,9 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id ORDER BY qty) from orders"; let expected = "\ - Projection: #order_id, #MAX(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty)]]\ - \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + Projection: #orders.order_id, #MAX(orders.qty)\ + \n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\ + \n Sort: #orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -3088,9 +3108,9 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id, qty ORDER BY qty) from orders"; let expected = "\ - Projection: #order_id, #MAX(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty)]]\ - \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + Projection: #orders.order_id, #MAX(orders.qty)\ + \n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\ + \n Sort: #orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -3112,11 +3132,11 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id, qty ORDER BY qty), MIN(qty) OVER (PARTITION BY qty ORDER BY order_id) from orders"; let expected = "\ - Projection: #order_id, #MAX(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty)]]\ - \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ - \n 
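The planner-test churn above is all one change: every column reference in an expected plan now renders with its originating table. The display convention behind the new #table.column form can be sketched as standalone code (hypothetical struct for illustration; the real Column type lives in datafusion::logical_plan and carries an optional relation, as the utils.rs changes below show):

use std::fmt;

// Sketch of a column reference that may carry a table qualifier.
struct Column {
    relation: Option<String>,
    name: String,
}

impl Column {
    // An unqualified reference, e.g. for a synthesized expression name.
    fn from_name(name: impl Into<String>) -> Self {
        Column { relation: None, name: name.into() }
    }

    // "person.age" when qualified, plain "age" otherwise.
    fn flat_name(&self) -> String {
        match &self.relation {
            Some(r) => format!("{}.{}", r, self.name),
            None => self.name.clone(),
        }
    }
}

impl fmt::Display for Column {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "#{}", self.flat_name())
    }
}

fn main() {
    let qualified = Column { relation: Some("person".into()), name: "age".into() };
    assert_eq!(qualified.to_string(), "#person.age");
    assert_eq!(Column::from_name("age").to_string(), "#age");
}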
diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs
index 82431c2314ab..7702748df44f 100644
--- a/datafusion/src/sql/utils.rs
+++ b/datafusion/src/sql/utils.rs
@@ -21,7 +21,7 @@ use crate::logical_plan::{DFSchema, Expr, LogicalPlan};
 use crate::scalar::ScalarValue;
 use crate::{
     error::{DataFusionError, Result},
-    logical_plan::{ExpressionVisitor, Recursion},
+    logical_plan::{Column, ExpressionVisitor, Recursion},
 };
 use std::collections::HashMap;

@@ -31,7 +31,7 @@ pub(crate) fn expand_wildcard(expr: &Expr, schema: &DFSchema) -> Vec<Expr> {
         Expr::Wildcard => schema
             .fields()
             .iter()
-            .map(|f| Expr::Column(f.name().to_string()))
+            .map(|f| Expr::Column(f.qualified_column()))
             .collect::<Vec<Expr>>(),
         _ => vec![expr.clone()],
     }
@@ -146,7 +146,7 @@ where
 pub(crate) fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
     match expr {
         Expr::Column(_) => Ok(expr.clone()),
-        _ => Ok(Expr::Column(expr.name(plan.schema())?)),
+        _ => Ok(Expr::Column(Column::from_name(expr.name(plan.schema())?))),
     }
 }

@@ -376,7 +376,7 @@ where
             asc: *asc,
             nulls_first: *nulls_first,
         }),
-        Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(_) => {
+        Expr::Column { .. } | Expr::Literal(_) | Expr::ScalarVariable(_) => {
             Ok(expr.clone())
         }
         Expr::Wildcard => Ok(Expr::Wildcard),
@@ -426,8 +426,8 @@ pub(crate) fn resolve_aliases_to_exprs(
     aliases: &HashMap<String, Expr>,
 ) -> Result<Expr> {
     clone_with_replacement(expr, &|nested_expr| match nested_expr {
-        Expr::Column(name) => {
-            if let Some(aliased_expr) = aliases.get(name) {
+        Expr::Column(c) if c.relation.is_none() => {
+            if let Some(aliased_expr) = aliases.get(&c.name) {
                 Ok(Some(aliased_expr.clone()))
             } else {
                 Ok(None)
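One behavioral consequence of the resolve_aliases_to_exprs change just above: with qualifiers in play, only a bare column reference may be rewritten to a SELECT-list alias; a qualified reference such as person.max_age must keep naming the input column. A self-contained sketch of that guard (simplified stand-in types, not the crate's API):

use std::collections::HashMap;

#[derive(Clone, Debug, PartialEq)]
struct Column {
    relation: Option<String>,
    name: String,
}

#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Column(Column),
    Literal(i64),
}

// Substitute an alias only when the reference carries no table qualifier.
fn resolve_alias(expr: &Expr, aliases: &HashMap<String, Expr>) -> Expr {
    match expr {
        Expr::Column(c) if c.relation.is_none() => aliases
            .get(&c.name)
            .cloned()
            .unwrap_or_else(|| expr.clone()),
        _ => expr.clone(),
    }
}

fn main() {
    let mut aliases = HashMap::new();
    aliases.insert("max_age".to_string(), Expr::Literal(42));

    let bare = Expr::Column(Column { relation: None, name: "max_age".into() });
    let qualified = Expr::Column(Column {
        relation: Some("person".into()),
        name: "max_age".into(),
    });

    // The bare name resolves through the alias map; the qualified one is untouched.
    assert_eq!(resolve_alias(&bare, &aliases), Expr::Literal(42));
    assert_eq!(resolve_alias(&qualified, &aliases), qualified);
}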
diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs
index 51dfe7f3a099..7ca7cc12d9ef 100644
--- a/datafusion/src/test/mod.rs
+++ b/datafusion/src/test/mod.rs
@@ -117,7 +117,7 @@ pub fn test_table_scan() -> Result<LogicalPlan> {
         Field::new("b", DataType::UInt32, false),
         Field::new("c", DataType::UInt32, false),
     ]);
-    LogicalPlanBuilder::scan_empty("test", &schema, None)?.build()
+    LogicalPlanBuilder::scan_empty(Some("test"), &schema, None)?.build()
 }

 pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) {
diff --git a/datafusion/tests/custom_sources.rs b/datafusion/tests/custom_sources.rs
index b39f47bba07b..75fbe8e8eede 100644
--- a/datafusion/tests/custom_sources.rs
+++ b/datafusion/tests/custom_sources.rs
@@ -30,7 +30,9 @@ use datafusion::{
 };

 use datafusion::execution::context::ExecutionContext;
-use datafusion::logical_plan::{col, Expr, LogicalPlan, LogicalPlanBuilder};
+use datafusion::logical_plan::{
+    col, Expr, LogicalPlan, LogicalPlanBuilder, UNNAMED_TABLE,
+};
 use datafusion::physical_plan::{
     ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream,
 };
@@ -196,8 +198,11 @@ async fn custom_source_dataframe() -> Result<()> {
         _ => panic!("expect optimized_plan to be projection"),
     }

-    let expected = "Projection: #c2\
-    \n TableScan: projection=Some([1])";
+    let expected = format!(
+        "Projection: #{}.c2\
+        \n TableScan: {} projection=Some([1])",
+        UNNAMED_TABLE, UNNAMED_TABLE
+    );
     assert_eq!(format!("{:?}", optimized_plan), expected);

     let physical_plan = ctx.create_physical_plan(&optimized_plan)?;
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index cfdb6f4bc9e4..c06a4bb1462e 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -1957,12 +1957,12 @@ async fn csv_explain() {
     register_aggregate_csv_by_sql(&mut ctx).await;
     let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10";
     let actual = execute(&mut ctx, sql).await;
-    let expected = vec![
-        vec![
-            "logical_plan",
-            "Projection: #c1\n Filter: #c2 Gt Int64(10)\n TableScan: aggregate_test_100 projection=None"
-        ]
-    ];
+    let expected = vec![vec![
+        "logical_plan",
+        "Projection: #aggregate_test_100.c1\
+        \n Filter: #aggregate_test_100.c2 Gt Int64(10)\
+        \n TableScan: aggregate_test_100 projection=None",
+    ]];
     assert_eq!(expected, actual);

     // Also, expect same result with lowercase explain
@@ -1990,8 +1990,8 @@ async fn csv_explain_plans() {
     // Verify schema
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
-        " Projection: #c1 [c1:Utf8]",
-        " Filter: #c2 Gt Int64(10) [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]",
+        " Projection: #aggregate_test_100.c1 [c1:Utf8]",
+        " Filter: #aggregate_test_100.c2 Gt Int64(10) [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]",
         " TableScan: aggregate_test_100 projection=None [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]",
     ];
     let formatted = plan.display_indent_schema().to_string();
@@ -2005,8 +2005,8 @@ async fn csv_explain_plans() {
     // Verify the text format of the plan
     let expected = vec![
         "Explain",
-        " Projection: #c1",
-        " Filter: #c2 Gt Int64(10)",
+        " Projection: #aggregate_test_100.c1",
+        " Filter: #aggregate_test_100.c2 Gt Int64(10)",
         " TableScan: aggregate_test_100 projection=None",
     ];
     let formatted = plan.display_indent().to_string();
@@ -2025,9 +2025,9 @@ async fn csv_explain_plans() {
         " {",
         " graph[label=\"LogicalPlan\"]",
         " 2[shape=box label=\"Explain\"]",
-        " 3[shape=box label=\"Projection: #c1\"]",
+        " 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
         " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
-        " 4[shape=box label=\"Filter: #c2 Gt Int64(10)\"]",
+        " 4[shape=box label=\"Filter: #aggregate_test_100.c2 Gt Int64(10)\"]",
         " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
         " 5[shape=box label=\"TableScan: aggregate_test_100 projection=None\"]",
         " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
@@ -2036,9 +2036,9 @@ async fn csv_explain_plans() {
         " {",
         " graph[label=\"Detailed LogicalPlan\"]",
         " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
-        " 8[shape=box label=\"Projection: #c1\\nSchema: [c1:Utf8]\"]",
+        " 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
         " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
-        " 9[shape=box label=\"Filter: #c2 Gt Int64(10)\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]",
+        " 9[shape=box label=\"Filter: #aggregate_test_100.c2 Gt Int64(10)\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]",
         " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
         " 10[shape=box label=\"TableScan: aggregate_test_100 projection=None\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]",
         " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
@@ -2065,8 +2065,8 @@ async fn csv_explain_plans() {
     // Verify schema
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
-        " Projection: #c1 [c1:Utf8]",
-        " Filter: #c2 Gt Int64(10) [c1:Utf8, c2:Int32]",
+        " Projection: #aggregate_test_100.c1 [c1:Utf8]",
+        " Filter: #aggregate_test_100.c2 Gt Int64(10) [c1:Utf8, c2:Int32]",
         " TableScan: aggregate_test_100 projection=Some([0, 1]) [c1:Utf8, c2:Int32]",
     ];
     let formatted = plan.display_indent_schema().to_string();
@@ -2080,8 +2080,8 @@ async fn csv_explain_plans() {
     // Verify the text format of the plan
     let expected = vec![
         "Explain",
-        " Projection: #c1",
-        " Filter: #c2 Gt Int64(10)",
+        " Projection: #aggregate_test_100.c1",
+        " Filter: #aggregate_test_100.c2 Gt Int64(10)",
         " TableScan: aggregate_test_100 projection=Some([0, 1])",
     ];
     let formatted = plan.display_indent().to_string();
@@ -2100,9 +2100,9 @@ async fn csv_explain_plans() {
         " {",
         " graph[label=\"LogicalPlan\"]",
         " 2[shape=box label=\"Explain\"]",
-        " 3[shape=box label=\"Projection: #c1\"]",
+        " 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
         " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
-        " 4[shape=box label=\"Filter: #c2 Gt Int64(10)\"]",
+        " 4[shape=box label=\"Filter: #aggregate_test_100.c2 Gt Int64(10)\"]",
         " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
         " 5[shape=box label=\"TableScan: aggregate_test_100 projection=Some([0, 1])\"]",
         " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
@@ -2111,9 +2111,9 @@ async fn csv_explain_plans() {
         " {",
         " graph[label=\"Detailed LogicalPlan\"]",
         " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
-        " 8[shape=box label=\"Projection: #c1\\nSchema: [c1:Utf8]\"]",
+        " 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
         " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
-        " 9[shape=box label=\"Filter: #c2 Gt Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
+        " 9[shape=box label=\"Filter: #aggregate_test_100.c2 Gt Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
         " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
         " 10[shape=box label=\"TableScan: aggregate_test_100 projection=Some([0, 1])\\nSchema: [c1:Utf8, c2:Int32]\"]",
         " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
@@ -2142,9 +2146,13 @@ async fn csv_explain_plans() {
     let actual = actual.into_iter().map(|r| r.join("\t")).collect::<String>();
     // Since the plan contains path that are environmentally dependant (e.g. full path of the test file), only verify important content
     assert!(actual.contains("logical_plan"), "Actual: '{}'", actual);
-    assert!(actual.contains("Projection: #c1"), "Actual: '{}'", actual);
     assert!(
-        actual.contains("Filter: #c2 Gt Int64(10)"),
+        actual.contains("Projection: #aggregate_test_100.c1"),
+        "Actual: '{}'",
+        actual
+    );
+    assert!(
+        actual.contains("Filter: #aggregate_test_100.c2 Gt Int64(10)"),
         "Actual: '{}'",
         actual
     );
@@ -2165,7 +2169,11 @@ async fn csv_explain_verbose() {
     // pain). Instead just check for a few key pieces.
     assert!(actual.contains("logical_plan"), "Actual: '{}'", actual);
     assert!(actual.contains("physical_plan"), "Actual: '{}'", actual);
-    assert!(actual.contains("#c2 Gt Int64(10)"), "Actual: '{}'", actual);
+    assert!(
+        actual.contains("#aggregate_test_100.c2 Gt Int64(10)"),
+        "Actual: '{}'",
+        actual
+    );
 }

 #[tokio::test]
@@ -2188,8 +2196,8 @@ async fn csv_explain_verbose_plans() {
     // Verify schema
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
-        " Projection: #c1 [c1:Utf8]",
-        " Filter: #c2 Gt Int64(10) [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]",
+        " Projection: #aggregate_test_100.c1 [c1:Utf8]",
+        " Filter: #aggregate_test_100.c2 Gt Int64(10) [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]",
         " TableScan: aggregate_test_100 projection=None [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]",
     ];
     let formatted = plan.display_indent_schema().to_string();
@@ -2203,8 +2211,8 @@ async fn csv_explain_verbose_plans() {
     // Verify the text format of the plan
     let expected = vec![
         "Explain",
-        " Projection: #c1",
-        " Filter: #c2 Gt Int64(10)",
+        " Projection: #aggregate_test_100.c1",
+        " Filter: #aggregate_test_100.c2 Gt Int64(10)",
         " TableScan: aggregate_test_100 projection=None",
     ];
     let formatted = plan.display_indent().to_string();
@@ -2223,9 +2231,9 @@ async fn csv_explain_verbose_plans() {
         " {",
         " graph[label=\"LogicalPlan\"]",
         " 2[shape=box label=\"Explain\"]",
-        " 3[shape=box label=\"Projection: #c1\"]",
+        " 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
         " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
-        " 4[shape=box label=\"Filter: #c2 Gt Int64(10)\"]",
+        " 4[shape=box label=\"Filter: #aggregate_test_100.c2 Gt Int64(10)\"]",
         " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
         " 5[shape=box label=\"TableScan: aggregate_test_100 projection=None\"]",
         " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
@@ -2234,9 +2242,9 @@ async fn csv_explain_verbose_plans() {
         " {",
         " graph[label=\"Detailed LogicalPlan\"]",
         " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
-        " 8[shape=box label=\"Projection: #c1\\nSchema: [c1:Utf8]\"]",
+        " 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
         " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
-        " 9[shape=box label=\"Filter: #c2 Gt Int64(10)\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]",
+        " 9[shape=box label=\"Filter: #aggregate_test_100.c2 Gt Int64(10)\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]",
         " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
         " 10[shape=box label=\"TableScan: aggregate_test_100 projection=None\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]",
         " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
@@ -2263,8 +2271,8 @@ async fn csv_explain_verbose_plans() {
     // Verify schema
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
-        " Projection: #c1 [c1:Utf8]",
-        " Filter: #c2 Gt Int64(10) [c1:Utf8, c2:Int32]",
+        " Projection: #aggregate_test_100.c1 [c1:Utf8]",
+        " Filter: #aggregate_test_100.c2 Gt Int64(10) [c1:Utf8, c2:Int32]",
         " TableScan: aggregate_test_100 projection=Some([0, 1]) [c1:Utf8, c2:Int32]",
     ];
     let formatted = plan.display_indent_schema().to_string();
@@ -2278,8 +2286,8 @@ async fn csv_explain_verbose_plans() {
     // Verify the text format of the plan
     let expected = vec![
         "Explain",
-        " Projection: #c1",
-        " Filter: #c2 Gt Int64(10)",
+        " Projection: #aggregate_test_100.c1",
+        " Filter: #aggregate_test_100.c2 Gt Int64(10)",
         " TableScan: aggregate_test_100 projection=Some([0, 1])",
     ];
     let formatted = plan.display_indent().to_string();
@@ -2298,9 +2306,9 @@ async fn csv_explain_verbose_plans() {
         " {",
         " graph[label=\"LogicalPlan\"]",
         " 2[shape=box label=\"Explain\"]",
-        " 3[shape=box label=\"Projection: #c1\"]",
+        " 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
         " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
-        " 4[shape=box label=\"Filter: #c2 Gt Int64(10)\"]",
+        " 4[shape=box label=\"Filter: #aggregate_test_100.c2 Gt Int64(10)\"]",
         " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
         " 5[shape=box label=\"TableScan: aggregate_test_100 projection=Some([0, 1])\"]",
         " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
@@ -2309,9 +2317,9 @@ async fn csv_explain_verbose_plans() {
         " {",
         " graph[label=\"Detailed LogicalPlan\"]",
         " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
-        " 8[shape=box label=\"Projection: #c1\\nSchema: [c1:Utf8]\"]",
+        " 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
         " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
-        " 9[shape=box label=\"Filter: #c2 Gt Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
+        " 9[shape=box label=\"Filter: #aggregate_test_100.c2 Gt Int64(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
         " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
         " 10[shape=box label=\"TableScan: aggregate_test_100 projection=Some([0, 1])\\nSchema: [c1:Utf8, c2:Int32]\"]",
         " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
@@ -2346,12 +2354,12 @@ async fn csv_explain_verbose_plans() {
     );
     assert!(actual.contains("physical_plan"), "Actual: '{}'", actual);
     assert!(
-        actual.contains("FilterExec: CAST(c2 AS Int64) > 10"),
+        actual.contains("FilterExec: CAST(c2@1 AS Int64) > 10"),
         "Actual: '{}'",
         actual
     );
     assert!(
-        actual.contains("ProjectionExec: expr=[c1]"),
+        actual.contains("ProjectionExec: expr=[c1@0 as c1]"),
         "Actual: '{}'",
         actual
     );
@@ -3793,15 +3801,15 @@ async fn test_physical_plan_display_indent() {
     let physical_plan = ctx.create_physical_plan(&plan).unwrap();
     let expected = vec![
         "GlobalLimitExec: limit=10",
-        " SortExec: [the_min DESC]",
+        " SortExec: [the_min@2 DESC]",
         " MergeExec",
-        " ProjectionExec: expr=[c1, MAX(c12), MIN(c12) as the_min]",
-        " HashAggregateExec: mode=FinalPartitioned, gby=[c1], aggr=[MAX(c12), MIN(c12)]",
+        " ProjectionExec: expr=[c1@0 as c1, MAX(aggregate_test_100.c12)@1 as MAX(c12), MIN(aggregate_test_100.c12)@2 as the_min]",
+        " HashAggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[MAX(c12), MIN(c12)]",
         " CoalesceBatchesExec: target_batch_size=4096",
-        " RepartitionExec: partitioning=Hash([Column { name: \"c1\" }], 3)",
-        " HashAggregateExec: mode=Partial, gby=[c1], aggr=[MAX(c12), MIN(c12)]",
+        " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 3)",
+        " HashAggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[MAX(c12), MIN(c12)]",
         " CoalesceBatchesExec: target_batch_size=4096",
-        " FilterExec: c12 < CAST(10 AS Float64)",
+        " FilterExec: c12@1 < CAST(10 AS Float64)",
         " RepartitionExec: partitioning=RoundRobinBatch(3)",
         " CsvExec: source=Path(ARROW_TEST_DATA/csv/aggregate_test_100.csv: [ARROW_TEST_DATA/csv/aggregate_test_100.csv]), has_header=true",
     ];
@@ -3840,17 +3848,17 @@ async fn test_physical_plan_display_indent_multi_children() {
     let physical_plan = ctx.create_physical_plan(&plan).unwrap();
     let expected = vec![
-        "ProjectionExec: expr=[c1]",
+        "ProjectionExec: expr=[c1@0 as c1]",
         " CoalesceBatchesExec: target_batch_size=4096",
-        " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(\"c1\", \"c2\")]",
+        " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"c1\", index: 0 }, Column { name: \"c2\", index: 0 })]",
         " CoalesceBatchesExec: target_batch_size=4096",
-        " RepartitionExec: partitioning=Hash([Column { name: \"c1\" }], 3)",
-        " ProjectionExec: expr=[c1]",
+        " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 3)",
+        " ProjectionExec: expr=[c1@0 as c1]",
         " RepartitionExec: partitioning=RoundRobinBatch(3)",
         " CsvExec: source=Path(ARROW_TEST_DATA/csv/aggregate_test_100.csv: [ARROW_TEST_DATA/csv/aggregate_test_100.csv]), has_header=true",
         " CoalesceBatchesExec: target_batch_size=4096",
-        " RepartitionExec: partitioning=Hash([Column { name: \"c2\" }], 3)",
-        " ProjectionExec: expr=[c1 as c2]",
+        " RepartitionExec: partitioning=Hash([Column { name: \"c2\", index: 0 }], 3)",
+        " ProjectionExec: expr=[c1@0 as c2]",
         " RepartitionExec: partitioning=RoundRobinBatch(3)",
         " CsvExec: source=Path(ARROW_TEST_DATA/csv/aggregate_test_100.csv: [ARROW_TEST_DATA/csv/aggregate_test_100.csv]), has_header=true",
     ];
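The physical-plan expectations in sql.rs change shape as well: a physical column reference now records its index into the input schema and prints as name@index (hence c1@0 as c1, and the Column { name: "c1", index: 0 } debug form in the repartition and join expectations above). A minimal sketch of that display convention (hypothetical struct mirroring the debug output seen in these tests):

use std::fmt;

// Sketch of a physical column reference: resolved by position at execution time.
#[derive(Debug)]
struct Column {
    name: String,
    index: usize,
}

impl fmt::Display for Column {
    // Renders as "c1@0", the name@index form used by the expected plans.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}@{}", self.name, self.index)
    }
}

fn main() {
    let c = Column { name: "c1".to_string(), index: 0 };
    assert_eq!(c.to_string(), "c1@0");
    println!("{:?}", c); // Column { name: "c1", index: 0 }
}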
diff --git a/datafusion/tests/user_defined_plan.rs b/datafusion/tests/user_defined_plan.rs
index 8914c05e8f88..22ebec8b9a99 100644
--- a/datafusion/tests/user_defined_plan.rs
+++ b/datafusion/tests/user_defined_plan.rs
@@ -164,7 +164,7 @@ async fn topk_plan() -> Result<()> {
     let expected = vec![
         "| logical_plan after topk | TopK: k=3 |",
-        "| | Projection: #customer_id, #revenue |",
+        "| | Projection: #sales.customer_id, #sales.revenue |",
         "| | TableScan: sales projection=Some([0, 1]) |",
     ].join("\n");

@@ -174,7 +174,18 @@ async fn topk_plan() -> Result<()> {
     // normalize newlines (output on windows uses \r\n)
     let actual_output = actual_output.replace("\r\n", "\n");

-    assert!(actual_output.contains(&expected) , "Expected output not present in actual output\nExpected:\n---------\n{}\nActual:\n--------\n{}", expected, actual_output);
+    assert!(
+        actual_output.contains(&expected),
+        "Expected output not present in actual output\
+        \nExpected:\
+        \n---------\
+        \n{}\
+        \nActual:\
+        \n--------\
+        \n{}",
+        expected,
+        actual_output
+    );
     Ok(())
 }
diff --git a/integration-tests/test_psql_parity.py b/integration-tests/test_psql_parity.py
index c4b5a7596ae9..92670bed0c4d 100644
--- a/integration-tests/test_psql_parity.py
+++ b/integration-tests/test_psql_parity.py
@@ -83,7 +83,7 @@ def test_parity(self):
             psql_output = pd.read_csv(io.BytesIO(generate_csv_from_psql(fname)))
             self.assertTrue(
                 np.allclose(datafusion_output, psql_output),
-                msg=f"data fusion output={datafusion_output}, psql_output={psql_output}",
+                msg=f"datafusion output=\n{datafusion_output}, psql_output=\n{psql_output}",
             )