diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 4049622b83dc..b8a8792d4321 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1201,7 +1201,7 @@ impl TryInto for &Expr { Expr::Wildcard => Ok(protobuf::LogicalExprNode { expr_type: Some(protobuf::logical_expr_node::ExprType::Wildcard(true)), }), - Expr::TryCast { .. } => unimplemented!(), + _ => unimplemented!(), } } } diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index 64cc0a1349a2..ea4c1a8cddd1 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -224,6 +224,7 @@ pub mod physical_plan; pub mod prelude; pub mod scalar; pub mod sql; +mod utils; pub mod variable; // re-export dependencies from arrow-rs to minimise version maintenance for crate users diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 622b7a4ec4ae..4c905538bcc0 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -25,6 +25,7 @@ use crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, window_functions, }; +use crate::utils::get_field; use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; use arrow::{compute::can_cast_types, datatypes::DataType}; @@ -188,6 +189,13 @@ pub enum Expr { IsNull(Box), /// arithmetic negation of an expression, the operand must be of a signed numeric data type Negative(Box), + /// Returns the field of a [`StructArray`] by name + GetField { + /// the expression to take the field from + expr: Box, + /// The name of the field to take + name: String, + }, /// Whether an expression is between a given range. 
Between { /// The value to compare @@ -378,6 +386,10 @@ impl Expr { Expr::Wildcard => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), + Expr::GetField { ref expr, name } => { + let data_type = expr.get_type(schema)?; + get_field(&data_type, name).map(|x| x.data_type().clone()) + } } } @@ -435,6 +447,10 @@ impl Expr { Expr::Wildcard => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), + Expr::GetField { ref expr, name } => { + let data_type = expr.get_type(input_schema)?; + get_field(&data_type, name).map(|x| x.is_nullable()) + } } } @@ -575,6 +591,14 @@ impl Expr { Expr::IsNotNull(Box::new(self)) } + /// Returns the values of the field `name` from an expression returning a `Struct` + pub fn get_field>(self, name: I) -> Expr { + Expr::GetField { + expr: Box::new(self), + name: name.into(), + } + } + /// Create a sort expression from an existing expression. /// /// ``` @@ -710,6 +734,7 @@ impl Expr { .try_fold(visitor, |visitor, arg| arg.accept(visitor)) } Expr::Wildcard => Ok(visitor), + Expr::GetField { ref expr, .. } => expr.accept(visitor), }?; visitor.post_visit(self) @@ -867,6 +892,10 @@ impl Expr { negated, }, Expr::Wildcard => Expr::Wildcard, + Expr::GetField { expr, name } => Expr::GetField { + expr: rewrite_boxed(expr, rewriter)?, + name, + }, }; // now rewrite this expression itself @@ -1508,6 +1537,7 @@ impl fmt::Debug for Expr { } } Expr::Wildcard => write!(f, "*"), + Expr::GetField { ref expr, name } => write!(f, "({:?}).{}", expr, name), } } } @@ -1584,6 +1614,10 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { let expr = create_name(expr, input_schema)?; Ok(format!("{} IS NOT NULL", expr)) } + Expr::GetField { expr, name } => { + let expr = create_name(expr, input_schema)?; + Ok(format!("{}.{}", expr, name)) + } Expr::ScalarFunction { fun, args, .. 
/// The `Debug` rendering of a field access is `(<input expr>).<field name>`.
#[test]
fn display_get_field() {
    let field_access = col("col1").get_field("name");
    let rendered = format!("{:?}", field_access);
    assert_eq!(rendered, "(#col1).name");
}
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! get field of a struct array + +use std::{any::Any, sync::Arc}; + +use arrow::{ + array::StructArray, + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; + +use crate::{ + error::DataFusionError, + error::Result, + physical_plan::{ColumnarValue, PhysicalExpr}, + utils::get_field as get_data_type_field, +}; + +/// expression to get a field of a struct array. +#[derive(Debug)] +pub struct GetFieldExpr { + arg: Arc, + name: String, +} + +impl GetFieldExpr { + /// Create new get field expression + pub fn new(arg: Arc, name: String) -> Self { + Self { arg, name } + } + + /// Get the input expression + pub fn arg(&self) -> &Arc { + &self.arg + } +} + +impl std::fmt::Display for GetFieldExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "({}).{}", self.arg, self.name) + } +} + +impl PhysicalExpr for GetFieldExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> Result { + let data_type = self.arg.data_type(input_schema)?; + get_data_type_field(&data_type, &self.name).map(|f| f.data_type().clone()) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + let data_type = self.arg.data_type(input_schema)?; + get_data_type_field(&data_type, &self.name).map(|f| f.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let arg = self.arg.evaluate(batch)?; + match arg { + 
ColumnarValue::Array(array) => Ok(ColumnarValue::Array( + array + .as_any() + .downcast_ref::() + .unwrap() + .column_by_name(&self.name) + .unwrap() + .clone(), + )), + ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented( + "field is not yet implemented for scalar values".to_string(), + )), + } + } +} + +/// Create an `.field` expression +pub fn get_field( + arg: Arc, + name: String, +) -> Result> { + Ok(Arc::new(GetFieldExpr::new(arg, name))) +} diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 0b32dca0467d..89abb103532c 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -33,6 +33,7 @@ mod cast; mod coercion; mod column; mod count; +mod get_field; mod in_list; mod is_not_null; mod is_null; @@ -54,6 +55,7 @@ pub use cast::{ }; pub use column::{col, Column}; pub use count::Count; +pub use get_field::{get_field, GetFieldExpr}; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index d59004243533..7fbf8e40b17d 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -128,6 +128,10 @@ fn physical_name(e: &Expr, input_schema: &DFSchema) -> Result { let expr = physical_name(expr, input_schema)?; Ok(format!("{} IS NOT NULL", expr)) } + Expr::GetField { expr, name } => { + let expr = physical_name(expr, input_schema)?; + Ok(format!("{}.{}", expr, name)) + } Expr::ScalarFunction { fun, args, .. 
} => { create_function_physical_name(&fun.to_string(), false, args, input_schema) } @@ -871,6 +875,10 @@ impl DefaultPhysicalPlanner { Expr::IsNotNull(expr) => expressions::is_not_null( self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, ), + Expr::GetField { expr, name } => expressions::get_field( + self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, + name.clone(), + ), Expr::ScalarFunction { fun, args } => { let physical_args = args .iter() diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 17181230c26c..e99fb8f4444b 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -79,6 +79,22 @@ pub struct SqlToRel<'a, S: ContextProvider> { schema_provider: &'a S, } +fn plan_compound(mut identifiers: Vec) -> Expr { + if &identifiers[0][0..1] == "@" { + Expr::ScalarVariable(identifiers) + } else if identifiers.len() == 2 { + // "table.column" + let name = identifiers.pop().unwrap(); + let relation = Some(identifiers.pop().unwrap()); + Expr::Column(Column { relation, name }) + } else { + // "table.column.field..." 
+ let name = identifiers.pop().unwrap(); + let expr = Box::new(plan_compound(identifiers)); + Expr::GetField { expr, name } + } +} + impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Create a new query planner pub fn new(schema_provider: &'a S) -> Self { @@ -916,23 +932,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLExpr::CompoundIdentifier(ids) => { - let mut var_names = vec![]; - for id in ids { - var_names.push(id.value.clone()); - } - if &var_names[0][0..1] == "@" { - Ok(Expr::ScalarVariable(var_names)) - } else if var_names.len() == 2 { - // table.column identifier - let name = var_names.pop().unwrap(); - let relation = Some(var_names.pop().unwrap()); - Ok(Expr::Column(Column { relation, name })) - } else { - Err(DataFusionError::NotImplemented(format!( - "Unsupported compound identifier '{:?}'", - var_names, - ))) - } + let var_names = ids.iter().map(|x| x.value.clone()).collect::>(); + Ok(plan_compound(var_names)) } SQLExpr::Wildcard => Ok(Expr::Wildcard), diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index 080f84ef10ed..dfb2a8f6015a 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -380,6 +380,10 @@ where Ok(expr.clone()) } Expr::Wildcard => Ok(Expr::Wildcard), + Expr::GetField { expr, name } => Ok(Expr::GetField { + expr: Box::new(clone_with_replacement(expr.as_ref(), replacement_fn)?), + name: name.clone(), + }), }, } } diff --git a/datafusion/src/utils.rs b/datafusion/src/utils.rs new file mode 100644 index 000000000000..587cb18e5856 --- /dev/null +++ b/datafusion/src/utils.rs @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field}; + +use crate::error::{DataFusionError, Result}; + +/// Returns the first field named `name` from the fields of a [`DataType::Struct`]. +/// # Error +/// Errors iff +/// * the `data_type` is not a Struct or, +/// * there is no field named `name` +pub fn get_field<'a>(data_type: &'a DataType, name: &str) -> Result<&'a Field> { + if let DataType::Struct(fields) = data_type { + let maybe_field = fields.iter().find(|x| x.name() == name); + if let Some(field) = maybe_field { + Ok(field) + } else { + Err(DataFusionError::Plan(format!( + "The `Struct` has no field named \"{}\"", + name + ))) + } + } else { + Err(DataFusionError::Plan( + "The expression to get a field is only valid for `Struct`".to_string(), + )) + } +} diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index c06a4bb1462e..84f7c9fefa10 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -24,14 +24,8 @@ use chrono::Duration; extern crate arrow; extern crate datafusion; -use arrow::{array::*, datatypes::TimeUnit}; -use arrow::{datatypes::Int32Type, datatypes::Int64Type, record_batch::RecordBatch}; use arrow::{ - datatypes::{ - ArrowNativeType, ArrowPrimitiveType, ArrowTimestampType, DataType, Field, Schema, - SchemaRef, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, - }, + array::*, datatypes::*, record_batch::RecordBatch, util::display::array_value_to_string, }; @@ -2860,6 +2854,31 @@ async fn query_is_not_null() -> Result<()> { Ok(()) } +#[tokio::test] +async fn 
query_get_field() -> Result<()> { + let inner_field = Field::new("inner", DataType::Float64, true); + let field = Field::new("c1", DataType::Struct(vec![inner_field.clone()]), true); + let schema = Arc::new(Schema::new(vec![field])); + + let array = Arc::new(Float64Array::from(vec![Some(1.1), None])) as ArrayRef; + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(StructArray::from(vec![(inner_field, array)]))], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(table))?; + let sql = "SELECT test.c1.inner FROM test"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["1.1"], vec!["NULL"]]; + + assert_eq!(expected, actual); + Ok(()) +} + #[tokio::test] async fn query_count_distinct() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)]));