From 6f76390b35830c02ffa1787e10e72dd28836cec6 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sat, 22 Jun 2024 09:33:34 +0200 Subject: [PATCH 01/12] Add input_nullable to UDAF args StateField/AccumulatorArgs This follows how it is done for input_type and only provides a single value, but it might need to be changed into a Vec in the future. This is needed when we are moving `array_agg` to a UDAF, where the nullability of one of the states will depend on the nullability of the input. --- datafusion/expr/src/function.rs | 6 ++++++ datafusion/functions-aggregate/src/first_last.rs | 1 + datafusion/physical-expr-common/src/aggregate/mod.rs | 7 +++++++ 3 files changed, 14 insertions(+) diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 73ab51494de6..6b26bc4eeb37 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -91,6 +91,9 @@ pub struct AccumulatorArgs<'a> { /// The input type of the aggregate function. pub input_type: &'a DataType, + /// If the input type is nullable. + pub input_nullable: bool, + /// The logical expression of arguments the aggregate function takes. pub input_exprs: &'a [Expr], } @@ -106,6 +109,9 @@ pub struct StateFieldsArgs<'a> { /// The input type of the aggregate function. pub input_type: &'a DataType, + /// If the input type is nullable. + pub input_nullable: bool, + /// The return type of the aggregate function. 
pub return_type: &'a DataType, diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 0e619bacef82..85891eaf1e33 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -440,6 +440,7 @@ impl AggregateUDFImpl for LastValue { let StateFieldsArgs { name, input_type, + input_nullable: _, return_type: _, ordering_fields, is_distinct: _, diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 0e245fd0a66a..9ebd6ff44665 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -96,6 +96,7 @@ pub fn create_aggregate_expr( ordering_fields, is_distinct, input_type: input_exprs_types[0].clone(), + input_nullable: input_phy_exprs[0].nullable(&schema)?, })) } @@ -271,6 +272,7 @@ pub struct AggregateFunctionExpr { ordering_fields: Vec, is_distinct: bool, input_type: DataType, + input_nullable: bool, } impl AggregateFunctionExpr { @@ -304,6 +306,7 @@ impl AggregateExpr for AggregateFunctionExpr { let args = StateFieldsArgs { name: &self.name, input_type: &self.input_type, + input_nullable: self.input_nullable, return_type: &self.data_type, ordering_fields: &self.ordering_fields, is_distinct: self.is_distinct, @@ -324,6 +327,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, + input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, }; @@ -339,6 +343,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, + input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, }; @@ -409,6 +414,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: 
self.is_distinct, input_type: &self.input_type, + input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, }; @@ -423,6 +429,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, + input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, }; From d1634bc0bf55cb6d428aced5a86f991eba6ae123 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Fri, 21 Jun 2024 09:50:36 +0200 Subject: [PATCH 02/12] Make ArrayAgg (not ordered or distinct) into a UDAF --- datafusion/core/src/dataframe/mod.rs | 6 +- datafusion/core/src/physical_planner.rs | 28 ++ datafusion/core/tests/dataframe/mod.rs | 4 +- datafusion/expr/src/expr_fn.rs | 1 + datafusion/functions-aggregate/Cargo.toml | 1 + .../functions-aggregate/src/array_agg.rs | 271 ++++++++++++++++++ datafusion/functions-aggregate/src/lib.rs | 8 +- .../physical-expr/src/aggregate/array_agg.rs | 185 ------------ .../physical-expr/src/aggregate/build_in.rs | 27 +- datafusion/physical-expr/src/aggregate/mod.rs | 1 - .../physical-expr/src/expressions/mod.rs | 1 - .../proto/src/physical_plan/to_proto.rs | 3 +- 12 files changed, 319 insertions(+), 217 deletions(-) create mode 100644 datafusion/functions-aggregate/src/array_agg.rs delete mode 100644 datafusion/physical-expr/src/aggregate/array_agg.rs diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index c55b7c752765..fb28b5c1ab47 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1696,10 +1696,10 @@ mod tests { use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, - ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition, + cast, create_udf, expr, lit, BuiltInWindowFunction, 
ScalarFunctionImplementation, + Volatility, WindowFrame, WindowFunctionDefinition, }; - use datafusion_functions_aggregate::expr_fn::count_distinct; + use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index efc83d8f6b5c..69ddbe4256f9 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -85,6 +85,7 @@ use datafusion_expr::{ DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; +use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::LexOrdering; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; @@ -1840,6 +1841,33 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = match func_def { + AggregateFunctionDefinition::BuiltIn( + datafusion_expr::AggregateFunction::ArrayAgg, + ) if !distinct && order_by.is_none() => { + let sort_exprs = order_by.clone().unwrap_or(vec![]); + let physical_sort_exprs = match order_by { + Some(exprs) => Some(create_physical_sort_exprs( + exprs, + logical_input_schema, + execution_props, + )?), + None => None, + }; + let ordering_reqs: Vec = + physical_sort_exprs.clone().unwrap_or(vec![]); + let agg_expr = udaf::create_aggregate_expr( + &array_agg_udaf(), + &physical_args, + args, + &sort_exprs, + &ordering_reqs, + physical_input_schema, + name, + ignore_nulls, + *distinct, + )?; + (agg_expr, filter, physical_sort_exprs) + } AggregateFunctionDefinition::BuiltIn(fun) => { let physical_sort_exprs = match order_by { Some(exprs) => Some(create_physical_sort_exprs( diff --git a/datafusion/core/tests/dataframe/mod.rs 
b/datafusion/core/tests/dataframe/mod.rs index 1b2a6770cf01..4e347a8f4e5a 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -54,11 +54,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, + cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::{avg, count, sum}; +use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 8b0213fd52fd..0f7b33e12461 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -171,6 +171,7 @@ pub fn max(expr: Expr) -> Expr { )) } +// TODO: remove /// Create an expression to represent the array_agg() aggregate function pub fn array_agg(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 26630a0352d5..3331701844b4 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -40,6 +40,7 @@ path = "src/lib.rs" [dependencies] ahash = { workspace = true } arrow = { workspace = true } +arrow-array = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs new file mode 100644 index 000000000000..27e3c11049f2 --- /dev/null +++ 
b/datafusion/functions-aggregate/src/array_agg.rs @@ -0,0 +1,271 @@ +// Licensed to the Apache Software Foundation (ASF) under on +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use arrow::array::ArrayRef; +use arrow::datatypes::DataType; +use arrow_array::Array; +use arrow_schema::Field; + +use datafusion_common::cast::as_list_array; +use datafusion_common::utils::array_into_list_array; +use datafusion_common::Result; +use datafusion_common::ScalarValue; +use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::expr::AggregateFunctionDefinition; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::simplify::SimplifyInfo; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::AggregateUDFImpl; +use datafusion_expr::Expr; +use datafusion_expr::{Accumulator, Signature, Volatility}; +use std::sync::Arc; + +make_udaf_expr_and_func!( + ArrayAgg, + array_agg, + expression, + "Computes the nth value", + array_agg_udaf +); + +#[derive(Debug)] +/// ARRAY_AGG aggregate expression +pub struct ArrayAgg { + signature: Signature, + alias: Vec, +} + +impl Default for ArrayAgg { + fn default() -> Self { + Self { + signature: Signature::any(1, 
Volatility::Immutable), + alias: vec!["array_agg".to_string()], + } + } +} + +impl AggregateUDFImpl for ArrayAgg { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "ARRAY_AGG" + } + + fn aliases(&self) -> &[String] { + &self.alias + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(DataType::List(Arc::new(Field::new( + "item", + arg_types[0].clone(), + true, + )))) + } + + fn state_fields( + &self, + args: datafusion_expr::function::StateFieldsArgs, + ) -> Result> { + Ok(vec![Field::new_list( + format_state_name(args.name, "array_agg"), + Field::new("item", args.input_type.clone(), true), + true, + )]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?)) + } + + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Identical + } + + fn simplify( + &self, + ) -> Option { + let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| { + if aggregate_function.order_by.is_some() || aggregate_function.distinct { + Ok(Expr::AggregateFunction(AggregateFunction { + func_def: AggregateFunctionDefinition::BuiltIn( + datafusion_expr::aggregate_function::AggregateFunction::ArrayAgg, + ), + args: aggregate_function.args, + distinct: aggregate_function.distinct, + filter: aggregate_function.filter, + order_by: aggregate_function.order_by, + null_treatment: aggregate_function.null_treatment, + })) + } else { + Ok(Expr::AggregateFunction(aggregate_function)) + } + }; + + Some(Box::new(simplify)) + } +} + +#[derive(Debug)] +pub struct ArrayAggAccumulator { + values: Vec, + datatype: DataType, +} + +impl ArrayAggAccumulator { + /// new array_agg accumulator based on given item data type + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + values: vec![], + datatype: datatype.clone(), + }) + } +} + +impl Accumulator for 
ArrayAggAccumulator { + // Append value like Int64Array(1,2,3) + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + assert!(values.len() == 1, "array_agg can only take 1 param!"); + let val = values[0].clone(); + self.values.push(val); + Ok(()) + } + + // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + assert!(states.len() == 1, "array_agg states must be singleton!"); + + let list_arr = as_list_array(&states[0])?; + for arr in list_arr.iter().flatten() { + self.values.push(arr); + } + Ok(()) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&mut self) -> Result { + // Transform Vec to ListArr + + let element_arrays: Vec<&dyn Array> = + self.values.iter().map(|a| a.as_ref()).collect(); + + if element_arrays.is_empty() { + let arr = ScalarValue::new_list(&[], &self.datatype); + return Ok(ScalarValue::List(arr)); + } + + let concated_array = arrow::compute::concat(&element_arrays)?; + let list_array = array_into_list_array(concated_array); + + Ok(ScalarValue::List(Arc::new(list_array))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + (std::mem::size_of::() * self.values.capacity()) + + self + .values + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum::() + + self.datatype.size() + - std::mem::size_of_val(&self.datatype) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_physical_expr_common::aggregate::create_aggregate_expr; + use datafusion_physical_expr_common::expressions::column::Column; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + + #[test] + fn test_array_agg_expr() -> Result<()> { + let data_types = vec![ + DataType::UInt32, + DataType::Int32, + 
DataType::Float32, + DataType::Float64, + DataType::Decimal128(10, 2), + DataType::Utf8, + ]; + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &array_agg_udaf(), + &input_phy_exprs[0..1], + &[], + &[], + &[], + &input_schema, + "c1", + false, + false, + )?; + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new_list("c1", Field::new("item", data_type.clone(), true), true,), + result_agg_phy_exprs.field().unwrap() + ); + + let result_distinct = create_aggregate_expr( + &array_agg_udaf(), + &input_phy_exprs[0..1], + &[], + &[], + &[], + &input_schema, + "c1", + false, + true, + )?; + assert_eq!("c1", result_distinct.name()); + assert_eq!( + Field::new_list("c1", Field::new("item", data_type.clone(), true), true,), + result_agg_phy_exprs.field().unwrap() + ); + } + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index a3808a08b007..1785abc168f0 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -59,6 +59,7 @@ pub mod macros; pub mod approx_distinct; pub mod correlation; +pub mod array_agg; pub mod count; pub mod covariance; pub mod first_last; @@ -94,6 +95,7 @@ pub mod expr_fn { pub use super::approx_percentile_cont::approx_percentile_cont; pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; pub use super::average::avg; + pub use super::array_agg::array_agg; pub use super::bit_and_or_xor::bit_and; pub use super::bit_and_or_xor::bit_or; pub use super::bit_and_or_xor::bit_xor; @@ -128,6 +130,7 @@ pub mod expr_fn { /// Returns all default aggregate functions pub fn all_default_aggregate_functions() -> Vec> { vec![ + array_agg::array_agg_udaf(), 
first_last::first_value_udaf(), first_last::last_value_udaf(), covariance::covar_samp_udaf(), @@ -191,8 +194,9 @@ mod tests { let mut names = HashSet::new(); for func in all_default_aggregate_functions() { // TODO: remove this - // These functions are in intermediate migration state, skip them - if func.name().to_lowercase() == "count" { + // These functions are in intermidiate migration state, skip them + let name_lower_case = func.name().to_lowercase(); + if name_lower_case == "count" || name_lower_case == "array_agg" { continue; } assert!( diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs deleted file mode 100644 index 0d5ed730e283..000000000000 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ /dev/null @@ -1,185 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
Defines physical expressions that can evaluated at runtime during query execution - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; -use datafusion_common::cast::as_list_array; -use datafusion_common::utils::array_into_list_array_nullable; -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; - -/// ARRAY_AGG aggregate expression -#[derive(Debug)] -pub struct ArrayAgg { - /// Column name - name: String, - /// The DataType for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, -} - -impl ArrayAgg { - /// Create a new ArrayAgg aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - input_data_type: data_type, - expr, - } - } -} - -impl AggregateExpr for ArrayAgg { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - true, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(ArrayAggAccumulator::try_new( - &self.input_data_type, - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for ArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.expr.eq(&x.expr) - }) - 
.unwrap_or(false) - } -} - -#[derive(Debug)] -pub(crate) struct ArrayAggAccumulator { - values: Vec, - datatype: DataType, -} - -impl ArrayAggAccumulator { - /// new array_agg accumulator based on given item data type - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - values: vec![], - datatype: datatype.clone(), - }) - } -} - -impl Accumulator for ArrayAggAccumulator { - // Append value like Int64Array(1,2,3) - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - assert!(values.len() == 1, "array_agg can only take 1 param!"); - - let val = Arc::clone(&values[0]); - if val.len() > 0 { - self.values.push(val); - } - Ok(()) - } - - // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - assert!(states.len() == 1, "array_agg states must be singleton!"); - - let list_arr = as_list_array(&states[0])?; - for arr in list_arr.iter().flatten() { - self.values.push(arr); - } - Ok(()) - } - - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn evaluate(&mut self) -> Result { - // Transform Vec to ListArr - let element_arrays: Vec<&dyn Array> = - self.values.iter().map(|a| a.as_ref()).collect(); - - if element_arrays.is_empty() { - return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); - } - - let concated_array = arrow::compute::concat(&element_arrays)?; - let list_array = array_into_list_array_nullable(concated_array); - - Ok(ScalarValue::List(Arc::new(list_array))) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .map(|arr| arr.get_array_memory_size()) - .sum::() - + self.datatype.size() - - std::mem::size_of_val(&self.datatype) - } -} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs 
b/datafusion/physical-expr/src/aggregate/build_in.rs index ef21b3d0f788..ef47eba0a499 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -64,7 +64,9 @@ pub fn create_aggregate_expr( let expr = Arc::clone(&input_phy_exprs[0]); if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new(expr, name, data_type)) + return internal_err!( + "ArrayAgg without ordering should be handled as UDAF" + ); } else { Arc::new(expressions::OrderSensitiveArrayAgg::new( expr, @@ -104,7 +106,7 @@ mod tests { use datafusion_common::plan_err; use datafusion_expr::{type_coercion, Signature}; - use crate::expressions::{try_cast, ArrayAgg, DistinctArrayAgg, Max, Min}; + use crate::expressions::{try_cast, DistinctArrayAgg, Max, Min}; use super::*; #[test] @@ -125,25 +127,6 @@ mod tests { let input_phy_exprs: Vec> = vec![Arc::new( expressions::Column::new_with_schema("c1", &input_schema).unwrap(), )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::ArrayAgg { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } let result_distinct = create_physical_agg_expr_for_test( &fun, @@ -161,7 +144,7 @@ mod tests { Field::new("item", data_type.clone(), true), true, ), - result_agg_phy_exprs.field().unwrap() + result_distinct.field().unwrap() ); } } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index b9d803900f53..c7baac120347 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -17,7 +17,6 @@ pub use datafusion_physical_expr_common::aggregate::AggregateExpr; -pub(crate) mod array_agg; pub(crate) 
mod array_agg_distinct; pub(crate) mod array_agg_ordered; #[macro_use] diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 7d8f12091f46..5131639e523a 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -34,7 +34,6 @@ mod try_cast; pub mod helpers { pub use crate::aggregate::min_max::{max, min}; } -pub use crate::aggregate::array_agg::ArrayAgg; pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; pub use crate::aggregate::build_in::create_aggregate_expr; diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 7ea2902cf3c0..44a7efc0230f 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,7 +23,7 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ArrayAgg, BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, DistinctArrayAgg, + BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, DistinctArrayAgg, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, WindowShift, @@ -262,6 +262,7 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let aggr_expr = expr.as_any(); let mut distinct = false; + // TODO: remove let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::ArrayAgg } else if aggr_expr.downcast_ref::().is_some() { From dc712f4cf28272569295c504756f81650bfbf17e Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Fri, 21 Jun 2024 10:00:52 +0200 Subject: [PATCH 03/12] Add 
roundtrip_expr_api test case --- datafusion/proto/tests/cases/roundtrip_logical_plan.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 0117502f400d..9f3144975565 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -66,7 +66,7 @@ use datafusion_expr::{ }; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::expr_fn::{ - avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, + array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, }; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_proto::bytes::{ @@ -702,6 +702,7 @@ async fn roundtrip_expr_api() -> Result<()> { string_agg(col("a").cast_to(&DataType::Utf8, &schema)?, lit("|")), bool_and(lit(true)), bool_or(lit(true)), + array_agg(lit(1)) ]; // ensure expressions created with the expr api can be round tripped From 41b9c8ab738080ed7afe5b5985dc9d989bd8840b Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sat, 22 Jun 2024 10:32:47 +0200 Subject: [PATCH 04/12] Address PR comments --- datafusion/core/tests/sql/aggregates.rs | 2 +- datafusion/functions-aggregate/Cargo.toml | 1 - .../functions-aggregate/src/array_agg.rs | 75 +------------------ .../tests/cases/roundtrip_logical_plan.rs | 2 +- 4 files changed, 6 insertions(+), 74 deletions(-) diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 1f4f9e77d5dc..c3139f6fcdfb 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { Schema::new(vec![Field::new_list( "ARRAY_AGG(DISTINCT aggregate_test_100.c2)", Field::new("item", DataType::UInt32, true), - true + false ),]) ); diff --git 
a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 3331701844b4..26630a0352d5 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -40,7 +40,6 @@ path = "src/lib.rs" [dependencies] ahash = { workspace = true } arrow = { workspace = true } -arrow-array = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 27e3c11049f2..a0cedf5817ff 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -1,4 +1,4 @@ -// Licensed to the Apache Software Foundation (ASF) under on +// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file @@ -17,9 +17,8 @@ //! 
Defines physical expressions that can evaluated at runtime during query execution -use arrow::array::ArrayRef; +use arrow::array::{Array, ArrayRef}; use arrow::datatypes::DataType; -use arrow_array::Array; use arrow_schema::Field; use datafusion_common::cast::as_list_array; @@ -40,7 +39,7 @@ make_udaf_expr_and_func!( ArrayAgg, array_agg, expression, - "Computes the nth value", + "input values, including nulls, concatenated into an array", array_agg_udaf ); @@ -92,7 +91,7 @@ impl AggregateUDFImpl for ArrayAgg { Ok(vec![Field::new_list( format_state_name(args.name, "array_agg"), Field::new("item", args.input_type.clone(), true), - true, + args.input_nullable, )]) } @@ -203,69 +202,3 @@ impl Accumulator for ArrayAggAccumulator { - std::mem::size_of_val(&self.datatype) } } - -#[cfg(test)] -mod tests { - use super::*; - - use std::sync::Arc; - - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::Result; - use datafusion_physical_expr_common::aggregate::create_aggregate_expr; - use datafusion_physical_expr_common::expressions::column::Column; - use datafusion_physical_expr_common::physical_expr::PhysicalExpr; - - #[test] - fn test_array_agg_expr() -> Result<()> { - let data_types = vec![ - DataType::UInt32, - DataType::Int32, - DataType::Float32, - DataType::Float64, - DataType::Decimal128(10, 2), - DataType::Utf8, - ]; - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_aggregate_expr( - &array_agg_udaf(), - &input_phy_exprs[0..1], - &[], - &[], - &[], - &input_schema, - "c1", - false, - false, - )?; - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new_list("c1", Field::new("item", data_type.clone(), true), true,), - result_agg_phy_exprs.field().unwrap() - ); - - let result_distinct = create_aggregate_expr( - 
&array_agg_udaf(), - &input_phy_exprs[0..1], - &[], - &[], - &[], - &input_schema, - "c1", - false, - true, - )?; - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new_list("c1", Field::new("item", data_type.clone(), true), true,), - result_agg_phy_exprs.field().unwrap() - ); - } - Ok(()) - } -} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 9f3144975565..9d3a60eb6115 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -702,7 +702,7 @@ async fn roundtrip_expr_api() -> Result<()> { string_agg(col("a").cast_to(&DataType::Utf8, &schema)?, lit("|")), bool_and(lit(true)), bool_or(lit(true)), - array_agg(lit(1)) + array_agg(lit(1)), ]; // ensure expressions created with the expr api can be round tripped From a7712e7a93003c1b368e22d1b764719fcf7ddf76 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sat, 22 Jun 2024 17:39:59 +0200 Subject: [PATCH 05/12] Propagate input nullability for aggregates --- datafusion/physical-expr-common/src/aggregate/mod.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 9ebd6ff44665..77ff7b5b1412 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -316,7 +316,11 @@ impl AggregateExpr for AggregateFunctionExpr { } fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) + Ok(Field::new( + &self.name, + self.data_type.clone(), + self.input_nullable, + )) } fn create_accumulator(&self) -> Result> { From d5d21b0198b0c81857b44a1a7629a2ecfee8664a Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sat, 22 Jun 2024 17:42:20 +0200 Subject: [PATCH 06/12] Remove from accumulator args --- datafusion/expr/src/function.rs | 3 --- 
datafusion/physical-expr-common/src/aggregate/mod.rs | 6 +----- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 6b26bc4eeb37..576209ef8e22 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -91,9 +91,6 @@ pub struct AccumulatorArgs<'a> { /// The input type of the aggregate function. pub input_type: &'a DataType, - /// If the input type is nullable. - pub input_nullable: bool, - /// The logical expression of arguments the aggregate function takes. pub input_exprs: &'a [Expr], } diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 77ff7b5b1412..94c71033d39d 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -96,7 +96,7 @@ pub fn create_aggregate_expr( ordering_fields, is_distinct, input_type: input_exprs_types[0].clone(), - input_nullable: input_phy_exprs[0].nullable(&schema)?, + input_nullable: input_phy_exprs[0].nullable(schema)?, })) } @@ -331,7 +331,6 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, }; @@ -347,7 +346,6 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, }; @@ -418,7 +416,6 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, }; @@ -433,7 +430,6 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: 
self.is_distinct, input_type: &self.input_type, - input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, }; From 7a1056ce6bdc96b61d4b53088220fbeaa418cfaa Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 13 Jul 2024 15:39:11 +0800 Subject: [PATCH 07/12] first draft Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 19 +++---- datafusion/expr/src/function.rs | 3 - .../functions-aggregate/src/array_agg.rs | 55 ++++--------------- .../functions-aggregate/src/first_last.rs | 1 - datafusion/functions-aggregate/src/lib.rs | 4 +- datafusion/functions-array/src/planner.rs | 12 ++-- .../physical-expr-common/src/aggregate/mod.rs | 9 +-- .../physical-expr/src/aggregate/build_in.rs | 2 +- .../src/aggregates/no_grouping.rs | 1 + .../proto/src/physical_plan/to_proto.rs | 11 ++-- 10 files changed, 36 insertions(+), 81 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 69ddbe4256f9..16ddcaf9e105 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -85,7 +85,6 @@ use datafusion_expr::{ DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; -use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::LexOrdering; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; @@ -1840,11 +1839,11 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; + // TODO: Remove this after array_agg are all udafs let (agg_expr, filter, order_by) = match func_def { - AggregateFunctionDefinition::BuiltIn( - datafusion_expr::AggregateFunction::ArrayAgg, - ) if !distinct && order_by.is_none() => { - let sort_exprs = order_by.clone().unwrap_or(vec![]); + 
AggregateFunctionDefinition::UDF(udf) + if udf.name() == "ARRAY_AGG" && (*distinct || order_by.is_some()) => + { let physical_sort_exprs = match order_by { Some(exprs) => Some(create_physical_sort_exprs( exprs, @@ -1855,16 +1854,15 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( }; let ordering_reqs: Vec = physical_sort_exprs.clone().unwrap_or(vec![]); - let agg_expr = udaf::create_aggregate_expr( - &array_agg_udaf(), + let fun = aggregates::AggregateFunction::ArrayAgg; + let agg_expr = aggregates::create_aggregate_expr( + &fun, + *distinct, &physical_args, - args, - &sort_exprs, &ordering_reqs, physical_input_schema, name, ignore_nulls, - *distinct, )?; (agg_expr, filter, physical_sort_exprs) } @@ -1916,6 +1914,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( (agg_expr, filter, physical_sort_exprs) } }; + Ok((agg_expr, filter, order_by)) } other => internal_err!("Invalid aggregate expression '{other:?}'"), diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 576209ef8e22..73ab51494de6 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -106,9 +106,6 @@ pub struct StateFieldsArgs<'a> { /// The input type of the aggregate function. pub input_type: &'a DataType, - /// If the input type is nullable. - pub input_nullable: bool, - /// The return type of the aggregate function. 
pub return_type: &'a DataType, diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index a0cedf5817ff..8a104fa0b861 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -22,16 +22,12 @@ use arrow::datatypes::DataType; use arrow_schema::Field; use datafusion_common::cast::as_list_array; -use datafusion_common::utils::array_into_list_array; +use datafusion_common::utils::array_into_list_array_nullable; use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_expr::expr::AggregateFunction; -use datafusion_expr::expr::AggregateFunctionDefinition; -use datafusion_expr::function::AccumulatorArgs; -use datafusion_expr::simplify::SimplifyInfo; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::AggregateUDFImpl; -use datafusion_expr::Expr; use datafusion_expr::{Accumulator, Signature, Volatility}; use std::sync::Arc; @@ -84,47 +80,17 @@ impl AggregateUDFImpl for ArrayAgg { )))) } - fn state_fields( - &self, - args: datafusion_expr::function::StateFieldsArgs, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new_list( format_state_name(args.name, "array_agg"), Field::new("item", args.input_type.clone(), true), - args.input_nullable, + true, )]) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?)) } - - fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { - datafusion_expr::ReversedUDAF::Identical - } - - fn simplify( - &self, - ) -> Option { - let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| { - if aggregate_function.order_by.is_some() || aggregate_function.distinct { - Ok(Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn( - 
datafusion_expr::aggregate_function::AggregateFunction::ArrayAgg, - ), - args: aggregate_function.args, - distinct: aggregate_function.distinct, - filter: aggregate_function.filter, - order_by: aggregate_function.order_by, - null_treatment: aggregate_function.null_treatment, - })) - } else { - Ok(Expr::AggregateFunction(aggregate_function)) - } - }; - - Some(Box::new(simplify)) - } } #[derive(Debug)] @@ -150,8 +116,11 @@ impl Accumulator for ArrayAggAccumulator { return Ok(()); } assert!(values.len() == 1, "array_agg can only take 1 param!"); - let val = values[0].clone(); - self.values.push(val); + + let val = Arc::clone(&values[0]); + if val.len() > 0 { + self.values.push(val); + } Ok(()) } @@ -175,17 +144,15 @@ impl Accumulator for ArrayAggAccumulator { fn evaluate(&mut self) -> Result { // Transform Vec to ListArr - let element_arrays: Vec<&dyn Array> = self.values.iter().map(|a| a.as_ref()).collect(); if element_arrays.is_empty() { - let arr = ScalarValue::new_list(&[], &self.datatype); - return Ok(ScalarValue::List(arr)); + return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); } let concated_array = arrow::compute::concat(&element_arrays)?; - let list_array = array_into_list_array(concated_array); + let list_array = array_into_list_array_nullable(concated_array); Ok(ScalarValue::List(Arc::new(list_array))) } diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 85891eaf1e33..0e619bacef82 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -440,7 +440,6 @@ impl AggregateUDFImpl for LastValue { let StateFieldsArgs { name, input_type, - input_nullable: _, return_type: _, ordering_fields, is_distinct: _, diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 1785abc168f0..b39b1955bb07 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ 
b/datafusion/functions-aggregate/src/lib.rs @@ -58,8 +58,8 @@ pub mod macros; pub mod approx_distinct; -pub mod correlation; pub mod array_agg; +pub mod correlation; pub mod count; pub mod covariance; pub mod first_last; @@ -94,8 +94,8 @@ pub mod expr_fn { pub use super::approx_median::approx_median; pub use super::approx_percentile_cont::approx_percentile_cont; pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; - pub use super::average::avg; pub use super::array_agg::array_agg; + pub use super::average::avg; pub use super::bit_and_or_xor::bit_and; pub use super::bit_and_or_xor::bit_or; pub use super::bit_and_or_xor::bit_xor; diff --git a/datafusion/functions-array/src/planner.rs b/datafusion/functions-array/src/planner.rs index cfbe99b4b7fd..dfb620f84f3a 100644 --- a/datafusion/functions-array/src/planner.rs +++ b/datafusion/functions-array/src/planner.rs @@ -19,8 +19,9 @@ use datafusion_common::{utils::list_ndims, DFSchema, Result}; use datafusion_expr::{ + expr::AggregateFunctionDefinition, planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, - sqlparser, AggregateFunction, Expr, ExprSchemable, GetFieldAccess, + sqlparser, Expr, ExprSchemable, GetFieldAccess, }; use datafusion_functions::expr_fn::get_field; use datafusion_functions_aggregate::nth_value::nth_value_udaf; @@ -153,8 +154,9 @@ impl ExprPlanner for FieldAccessPlanner { } fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - agg_func.func_def - == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( - AggregateFunction::ArrayAgg, - ) + if let AggregateFunctionDefinition::UDF(udf) = &agg_func.func_def { + return udf.name() == "ARRAY_AGG"; + } + + false } diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 94c71033d39d..0e245fd0a66a 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ 
b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -96,7 +96,6 @@ pub fn create_aggregate_expr( ordering_fields, is_distinct, input_type: input_exprs_types[0].clone(), - input_nullable: input_phy_exprs[0].nullable(schema)?, })) } @@ -272,7 +271,6 @@ pub struct AggregateFunctionExpr { ordering_fields: Vec, is_distinct: bool, input_type: DataType, - input_nullable: bool, } impl AggregateFunctionExpr { @@ -306,7 +304,6 @@ impl AggregateExpr for AggregateFunctionExpr { let args = StateFieldsArgs { name: &self.name, input_type: &self.input_type, - input_nullable: self.input_nullable, return_type: &self.data_type, ordering_fields: &self.ordering_fields, is_distinct: self.is_distinct, @@ -316,11 +313,7 @@ impl AggregateExpr for AggregateFunctionExpr { } fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.input_nullable, - )) + Ok(Field::new(&self.name, self.data_type.clone(), true)) } fn create_accumulator(&self) -> Result> { diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index ef47eba0a499..58f948b9edc6 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -30,7 +30,7 @@ use std::sync::Arc; use arrow::datatypes::Schema; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_expr::AggregateFunction; use crate::expressions::{self}; diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index f85164f7f1e2..99417e4ee3e9 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -218,6 +218,7 @@ fn aggregate_batch( Some(filter) => Cow::Owned(batch_filter(&batch, filter)?), None => Cow::Borrowed(&batch), }; + // 1.3 let values = &expr .iter() diff --git 
a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 44a7efc0230f..768419601fa5 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,10 +23,9 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, DistinctArrayAgg, - InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, - NthValue, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, - WindowShift, + BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, DistinctArrayAgg, InListExpr, + IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, Ntile, + OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -263,9 +262,7 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let mut distinct = false; // TODO: remove - let inner = if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { + let inner = if aggr_expr.downcast_ref::().is_some() { distinct = true; protobuf::AggregateFunction::ArrayAgg } else if aggr_expr.downcast_ref::().is_some() { From 5ddaddb10aae4cabb491073b4b7988c6a41e1571 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 13 Jul 2024 15:44:13 +0800 Subject: [PATCH 08/12] cleanup Signed-off-by: jayzhan211 --- datafusion/core/tests/sql/aggregates.rs | 2 +- datafusion/functions-aggregate/src/array_agg.rs | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/aggregates.rs 
b/datafusion/core/tests/sql/aggregates.rs index c3139f6fcdfb..1f4f9e77d5dc 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { Schema::new(vec![Field::new_list( "ARRAY_AGG(DISTINCT aggregate_test_100.c2)", Field::new("item", DataType::UInt32, true), - false + true ),]) ); diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 8a104fa0b861..e89c86fadb60 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -60,6 +60,7 @@ impl AggregateUDFImpl for ArrayAgg { self } + // TODO: change name to lowercase fn name(&self) -> &str { "ARRAY_AGG" } From 9d17c1c91fbc1ca7d6f196a833d0db8a4d9453ef Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 13 Jul 2024 16:16:47 +0800 Subject: [PATCH 09/12] fix test Signed-off-by: jayzhan211 --- datafusion/core/tests/dataframe/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 4e347a8f4e5a..9ede4300838a 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -1389,7 +1389,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { let expected = vec![ "Projection: shapes.shape_id [shape_id:UInt32]", " Unnest: lists[shape_id2] structs[] [shape_id:UInt32, shape_id2:UInt32;N]", - " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", + " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} 
});N]", " TableScan: shapes projection=[shape_id] [shape_id:UInt32]", ]; From ac640fb739744e63b52cb409b9464ec0143d9bc2 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 17 Jul 2024 14:15:42 +0800 Subject: [PATCH 10/12] distinct Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 3 +- .../functions-aggregate/src/array_agg.rs | 79 +++- .../src/aggregate/array_agg_distinct.rs | 433 ------------------ .../physical-expr/src/aggregate/build_in.rs | 57 +-- datafusion/physical-expr/src/aggregate/mod.rs | 1 - .../physical-expr/src/expressions/mod.rs | 1 - .../proto/src/physical_plan/to_proto.rs | 9 +- .../tests/cases/roundtrip_logical_plan.rs | 1 + 8 files changed, 87 insertions(+), 497 deletions(-) delete mode 100644 datafusion/physical-expr/src/aggregate/array_agg_distinct.rs diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 16ddcaf9e105..fd989cd991ed 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1842,8 +1842,9 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( // TODO: Remove this after array_agg are all udafs let (agg_expr, filter, order_by) = match func_def { AggregateFunctionDefinition::UDF(udf) - if udf.name() == "ARRAY_AGG" && (*distinct || order_by.is_some()) => + if udf.name() == "ARRAY_AGG" && order_by.is_some() => { + // not yet support UDAF, fallback to builtin let physical_sort_exprs = match order_by { Some(exprs) => Some(create_physical_sort_exprs( exprs, diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index e89c86fadb60..870beda3fbb7 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -17,7 +17,7 @@ //! 
Defines physical expressions that can evaluated at runtime during query execution -use arrow::array::{Array, ArrayRef}; +use arrow::array::{Array, ArrayRef, AsArray}; use arrow::datatypes::DataType; use arrow_schema::Field; @@ -29,6 +29,7 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::AggregateUDFImpl; use datafusion_expr::{Accumulator, Signature, Volatility}; +use std::collections::HashSet; use std::sync::Arc; make_udaf_expr_and_func!( @@ -82,6 +83,14 @@ impl AggregateUDFImpl for ArrayAgg { } fn state_fields(&self, args: StateFieldsArgs) -> Result> { + if args.is_distinct { + return Ok(vec![Field::new_list( + format_state_name(args.name, "distinct_array_agg"), + Field::new("item", args.input_type.clone(), true), + true, + )]); + } + Ok(vec![Field::new_list( format_state_name(args.name, "array_agg"), Field::new("item", args.input_type.clone(), true), @@ -90,6 +99,12 @@ impl AggregateUDFImpl for ArrayAgg { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return Ok(Box::new(DistinctArrayAggAccumulator::try_new( + acc_args.input_type, + )?)); + } + Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?)) } } @@ -170,3 +185,65 @@ impl Accumulator for ArrayAggAccumulator { - std::mem::size_of_val(&self.datatype) } } + +#[derive(Debug)] +struct DistinctArrayAggAccumulator { + values: HashSet, + datatype: DataType, +} + +impl DistinctArrayAggAccumulator { + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + values: HashSet::new(), + datatype: datatype.clone(), + }) + } +} + +impl Accumulator for DistinctArrayAggAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + assert_eq!(values.len(), 1, "batch input should only include 1 column!"); + + let array = &values[0]; + + for i in 0..array.len() { + let scalar = 
ScalarValue::try_from_array(&array, i)?; + self.values.insert(scalar); + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + + states[0] + .as_list::() + .iter() + .flatten() + .try_for_each(|val| self.update_batch(&[val])) + } + + fn evaluate(&mut self) -> Result { + let values: Vec = self.values.iter().cloned().collect(); + if values.is_empty() { + return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); + } + let arr = ScalarValue::new_list(&values, &self.datatype, true); + Ok(ScalarValue::List(arr)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.values) + - std::mem::size_of_val(&self.values) + + self.datatype.size() + - std::mem::size_of_val(&self.datatype) + } +} diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs deleted file mode 100644 index eca6e4ce4f65..000000000000 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ /dev/null @@ -1,433 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Implementations for DISTINCT expressions, e.g. 
`COUNT(DISTINCT c)` - -use std::any::Any; -use std::collections::HashSet; -use std::fmt::Debug; -use std::sync::Arc; - -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::cast::AsArray; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; - -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::Accumulator; - -/// Expression for a ARRAY_AGG(DISTINCT) aggregation. -#[derive(Debug)] -pub struct DistinctArrayAgg { - /// Column name - name: String, - /// The DataType for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, -} - -impl DistinctArrayAgg { - /// Create a new DistinctArrayAgg aggregate function - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - ) -> Self { - let name = name.into(); - Self { - name, - input_data_type, - expr, - } - } -} - -impl AggregateExpr for DistinctArrayAgg { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - true, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(DistinctArrayAggAccumulator::try_new( - &self.input_data_type, - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "distinct_array_agg"), - Field::new("item", self.input_data_type.clone(), true), - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for DistinctArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && 
self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -struct DistinctArrayAggAccumulator { - values: HashSet, - datatype: DataType, -} - -impl DistinctArrayAggAccumulator { - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - values: HashSet::new(), - datatype: datatype.clone(), - }) - } -} - -impl Accumulator for DistinctArrayAggAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - assert_eq!(values.len(), 1, "batch input should only include 1 column!"); - - let array = &values[0]; - - for i in 0..array.len() { - let scalar = ScalarValue::try_from_array(&array, i)?; - self.values.insert(scalar); - } - - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - states[0] - .as_list::() - .iter() - .flatten() - .try_for_each(|val| self.update_batch(&[val])) - } - - fn evaluate(&mut self) -> Result { - let values: Vec = self.values.iter().cloned().collect(); - if values.is_empty() { - return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); - } - let arr = ScalarValue::new_list(&values, &self.datatype, true); - Ok(ScalarValue::List(arr)) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.values) - - std::mem::size_of_val(&self.values) - + self.datatype.size() - - std::mem::size_of_val(&self.datatype) - } -} - -#[cfg(test)] -mod tests { - - use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use arrow::array::Int32Array; - use arrow::datatypes::Schema; - use arrow::record_batch::RecordBatch; - use arrow_array::types::Int32Type; - use arrow_array::Array; - use arrow_array::ListArray; - use arrow_buffer::OffsetBuffer; - use datafusion_common::internal_err; - - // arrow::compute::sort can't sort nested ListArray directly, so we compare the scalar values pair-wise. 
- fn compare_list_contents( - expected: Vec, - actual: ScalarValue, - ) -> Result<()> { - let array = actual.to_array()?; - let list_array = array.as_list::(); - let inner_array = list_array.value(0); - let mut actual_scalars = vec![]; - for index in 0..inner_array.len() { - let sv = ScalarValue::try_from_array(&inner_array, index)?; - actual_scalars.push(sv); - } - - if actual_scalars.len() != expected.len() { - return internal_err!( - "Expected and actual list lengths differ: expected={}, actual={}", - expected.len(), - actual_scalars.len() - ); - } - - let mut seen = vec![false; expected.len()]; - for v in expected { - let mut found = false; - for (i, sv) in actual_scalars.iter().enumerate() { - if sv == &v { - seen[i] = true; - found = true; - break; - } - } - if !found { - return internal_err!( - "Expected value {:?} not found in actual values {:?}", - v, - actual_scalars - ); - } - } - - Ok(()) - } - - fn check_distinct_array_agg( - input: ArrayRef, - expected: Vec, - datatype: DataType, - ) -> Result<()> { - let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![input])?; - - let agg = Arc::new(DistinctArrayAgg::new( - col("a", &schema)?, - "bla".to_string(), - datatype, - )); - let actual = aggregate(&batch, agg)?; - compare_list_contents(expected, actual) - } - - fn check_merge_distinct_array_agg( - input1: ArrayRef, - input2: ArrayRef, - expected: Vec, - datatype: DataType, - ) -> Result<()> { - let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); - let agg = Arc::new(DistinctArrayAgg::new( - col("a", &schema)?, - "bla".to_string(), - datatype, - )); - - let mut accum1 = agg.create_accumulator()?; - let mut accum2 = agg.create_accumulator()?; - - accum1.update_batch(&[input1])?; - accum2.update_batch(&[input2])?; - - let array = accum2.state()?[0].raw_data()?; - accum1.merge_batch(&[array])?; - - let actual = accum1.evaluate()?; - 
compare_list_contents(expected, actual) - } - - #[test] - fn distinct_array_agg_i32() -> Result<()> { - let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); - - let expected = vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ScalarValue::Int32(Some(7)), - ]; - - check_distinct_array_agg(col, expected, DataType::Int32) - } - - #[test] - fn merge_distinct_array_agg_i32() -> Result<()> { - let col1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); - let col2: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 7, 8, 4])); - - let expected = vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(3)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ScalarValue::Int32(Some(7)), - ScalarValue::Int32(Some(8)), - ]; - - check_merge_distinct_array_agg(col1, col2, expected, DataType::Int32) - } - - #[test] - fn distinct_array_agg_nested() -> Result<()> { - // [[1, 2, 3], [4, 5]] - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - ])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(4), - Some(5), - ])]); - let l1 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - // [[6], [7, 8]] - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(6)])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(7), - Some(8), - ])]); - let l2 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - // [[9]] - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(9)])]); - let l3 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - 
OffsetBuffer::from_lengths([1]), - Arc::new(a1), - None, - ); - - let l1 = ScalarValue::List(Arc::new(l1)); - let l2 = ScalarValue::List(Arc::new(l2)); - let l3 = ScalarValue::List(Arc::new(l3)); - - // Duplicate l1 and l3 in the input array and check that it is deduped in the output. - let array = ScalarValue::iter_to_array(vec![ - l1.clone(), - l2.clone(), - l3.clone(), - l3.clone(), - l1.clone(), - ]) - .unwrap(); - let expected = vec![l1, l2, l3]; - - check_distinct_array_agg( - array, - expected, - DataType::List(Arc::new(Field::new_list( - "item", - Field::new("item", DataType::Int32, true), - true, - ))), - ) - } - - #[test] - fn merge_distinct_array_agg_nested() -> Result<()> { - // [[1, 2], [3, 4]] - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - ])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(3), - Some(4), - ])]); - let l1 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(5)])]); - let l2 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([1]), - Arc::new(a1), - None, - ); - - // [[6, 7], [8]] - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(6), - Some(7), - ])]); - let a2 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(8)])]); - let l3 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - let l1 = ScalarValue::List(Arc::new(l1)); - let l2 = ScalarValue::List(Arc::new(l2)); - let l3 = ScalarValue::List(Arc::new(l3)); - - // Duplicate l1 in the input array and check that it is deduped in the output. 
- let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2.clone()]).unwrap(); - let input2 = ScalarValue::iter_to_array(vec![l1.clone(), l3.clone()]).unwrap(); - - let expected = vec![l1, l2, l3]; - - check_merge_distinct_array_agg(input1, input2, expected, DataType::Int32) - } -} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 58f948b9edc6..9c270561f37d 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -30,7 +30,7 @@ use std::sync::Arc; use arrow::datatypes::Schema; -use datafusion_common::{internal_err, not_impl_err, Result}; +use datafusion_common::{internal_err, Result}; use datafusion_expr::AggregateFunction; use crate::expressions::{self}; @@ -60,7 +60,7 @@ pub fn create_aggregate_expr( .collect::>>()?; let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { - (AggregateFunction::ArrayAgg, false) => { + (AggregateFunction::ArrayAgg, _) => { let expr = Arc::clone(&input_phy_exprs[0]); if ordering_req.is_empty() { @@ -77,15 +77,6 @@ pub fn create_aggregate_expr( )) } } - (AggregateFunction::ArrayAgg, true) => { - if !ordering_req.is_empty() { - return not_impl_err!( - "ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available" - ); - } - let expr = Arc::clone(&input_phy_exprs[0]); - Arc::new(expressions::DistinctArrayAgg::new(expr, name, data_type)) - } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( Arc::clone(&input_phy_exprs[0]), name, @@ -106,51 +97,9 @@ mod tests { use datafusion_common::plan_err; use datafusion_expr::{type_coercion, Signature}; - use crate::expressions::{try_cast, DistinctArrayAgg, Max, Min}; + use crate::expressions::{try_cast, Max, Min}; use super::*; - #[test] - fn test_approx_expr() -> Result<()> { - let funcs = vec![AggregateFunction::ArrayAgg]; - let data_types = vec![ - DataType::UInt32, - DataType::Int32, - 
DataType::Float32, - DataType::Float64, - DataType::Decimal128(10, 2), - DataType::Utf8, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - - let result_distinct = create_physical_agg_expr_for_test( - &fun, - true, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::ArrayAgg { - assert!(result_distinct.as_any().is::()); - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_distinct.field().unwrap() - ); - } - } - } - Ok(()) - } #[test] fn test_min_max_expr() -> Result<()> { diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index c7baac120347..749cf2be7297 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -17,7 +17,6 @@ pub use datafusion_physical_expr_common::aggregate::AggregateExpr; -pub(crate) mod array_agg_distinct; pub(crate) mod array_agg_ordered; #[macro_use] pub(crate) mod min_max; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 5131639e523a..e358fb8decac 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -34,7 +34,6 @@ mod try_cast; pub mod helpers { pub use crate::aggregate::min_max::{max, min}; } -pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; pub use crate::aggregate::build_in::create_aggregate_expr; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; diff --git a/datafusion/proto/src/physical_plan/to_proto.rs 
b/datafusion/proto/src/physical_plan/to_proto.rs index 768419601fa5..01ae2d9177d4 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,8 +23,8 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, DistinctArrayAgg, InListExpr, - IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, Ntile, + BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, InListExpr, IsNotNullExpr, + IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; @@ -262,10 +262,7 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let mut distinct = false; // TODO: remove - let inner = if aggr_expr.downcast_ref::().is_some() { - distinct = true; - protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { + let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::ArrayAgg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Min diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 9d3a60eb6115..acfd4122ea80 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -703,6 +703,7 @@ async fn roundtrip_expr_api() -> Result<()> { bool_and(lit(true)), bool_or(lit(true)), array_agg(lit(1)), + array_agg(lit(1)), ]; // ensure expressions created with the expr api can be round tripped From f67d8f761a62cb80d9a6fab5c861d976bd79001d Mon Sep 17 00:00:00 2001 From: jayzhan211 
Date: Wed, 17 Jul 2024 15:17:28 +0800 Subject: [PATCH 11/12] fix Signed-off-by: jayzhan211 --- .../physical-expr/src/expressions/mod.rs | 157 ------------------ .../proto/src/physical_plan/to_proto.rs | 6 +- .../tests/cases/roundtrip_logical_plan.rs | 1 - 3 files changed, 4 insertions(+), 160 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index e358fb8decac..fa80bc9873f0 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -61,160 +61,3 @@ pub use negative::{negative, NegativeExpr}; pub use no_op::NoOp; pub use not::{not, NotExpr}; pub use try_cast::{try_cast, TryCastExpr}; - -#[cfg(test)] -pub(crate) mod tests { - use std::sync::Arc; - - use crate::AggregateExpr; - - use arrow::record_batch::RecordBatch; - use datafusion_common::{Result, ScalarValue}; - - /// macro to perform an aggregation using [`datafusion_expr::Accumulator`] and verify the - /// result. - #[macro_export] - macro_rules! generic_test_op { - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { - generic_test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type()) - }; - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from($EXPECTED); - - assert_eq!(expected, actual); - - Ok(()) as Result<(), ::datafusion_common::DataFusionError> - }}; - } - - /// Same as [`generic_test_op`] but with support for providing a 4th argument, usually - /// a boolean to indicate if using the distinct version of the op. - #[macro_export] - macro_rules! 
generic_test_distinct_op { - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $DISTINCT:expr, $EXPECTED:expr) => { - generic_test_distinct_op!( - $ARRAY, - $DATATYPE, - $OP, - $DISTINCT, - $EXPECTED, - $EXPECTED.data_type() - ) - }; - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $DISTINCT:expr, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - $DISTINCT, - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from($EXPECTED); - - assert_eq!(expected, actual); - - Ok(()) as Result<(), ::datafusion_common::DataFusionError> - }}; - } - - /// macro to perform an aggregation using [`crate::GroupsAccumulator`] and verify the result. - /// - /// The difference between this and the above `generic_test_op` is that the former checks - /// the old slow-path [`datafusion_expr::Accumulator`] implementation, while this checks - /// the new [`crate::GroupsAccumulator`] implementation. - #[macro_export] - macro_rules! generic_test_op_new { - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { - generic_test_op_new!( - $ARRAY, - $DATATYPE, - $OP, - $EXPECTED, - $EXPECTED.data_type().clone() - ) - }; - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - )); - let actual = aggregate_new(&batch, agg)?; - assert_eq!($EXPECTED, &actual); - - Ok(()) as Result<(), ::datafusion_common::DataFusionError> - }}; - } - - /// macro to perform an aggregation with two inputs and verify the result. - #[macro_export] - macro_rules! 
generic_test_op2 { - ($ARRAY1:expr, $ARRAY2:expr, $DATATYPE1:expr, $DATATYPE2:expr, $OP:ident, $EXPECTED:expr) => { - generic_test_op2!( - $ARRAY1, - $ARRAY2, - $DATATYPE1, - $DATATYPE2, - $OP, - $EXPECTED, - $EXPECTED.data_type() - ) - }; - ($ARRAY1:expr, $ARRAY2:expr, $DATATYPE1:expr, $DATATYPE2:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![ - Field::new("a", $DATATYPE1, true), - Field::new("b", $DATATYPE2, true), - ]); - let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY1, $ARRAY2])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from($EXPECTED); - - assert_eq!(expected, actual); - - Ok(()) - }}; - } - - pub fn aggregate( - batch: &RecordBatch, - agg: Arc, - ) -> Result { - let mut accum = agg.create_accumulator()?; - let expr = agg.expressions(); - let values = expr - .iter() - .map(|e| { - e.evaluate(batch) - .and_then(|v| v.into_array(batch.num_rows())) - }) - .collect::>>()?; - accum.update_batch(&values)?; - accum.evaluate() - } -} diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 01ae2d9177d4..2921eb5c8f96 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -259,7 +259,6 @@ struct AggrFn { fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let aggr_expr = expr.as_any(); - let mut distinct = false; // TODO: remove let inner = if aggr_expr.downcast_ref::().is_some() { @@ -272,7 +271,10 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { return not_impl_err!("Aggregate function not supported: {expr:?}"); }; - Ok(AggrFn { inner, distinct }) + Ok(AggrFn { + inner, + distinct: false, + }) } pub fn serialize_physical_sort_exprs( diff --git 
a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index acfd4122ea80..9d3a60eb6115 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -703,7 +703,6 @@ async fn roundtrip_expr_api() -> Result<()> { bool_and(lit(true)), bool_or(lit(true)), array_agg(lit(1)), - array_agg(lit(1)), ]; // ensure expressions created with the expr api can be round tripped From a9b925bc4b7ccabd436dd142e990fab9c99ed037 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 19 Jul 2024 07:51:48 +0800 Subject: [PATCH 12/12] address comment Signed-off-by: jayzhan211 --- datafusion/expr/src/expr_fn.rs | 13 ---------- .../functions-aggregate/src/array_agg.rs | 26 ++++++++++++++----- .../proto/src/physical_plan/to_proto.rs | 2 +- .../tests/cases/roundtrip_logical_plan.rs | 1 + .../sqllogictest/test_files/aggregate.slt | 2 +- 5 files changed, 22 insertions(+), 22 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 0f7b33e12461..9187e8352205 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -171,19 +171,6 @@ pub fn max(expr: Expr) -> Expr { )) } -// TODO: remove -/// Create an expression to represent the array_agg() aggregate function -pub fn array_agg(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ArrayAgg, - vec![expr], - false, - None, - None, - None, - )) -} - /// Return a new expression with bitwise AND pub fn bitwise_and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new( diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 870beda3fbb7..9ad453d7a4b2 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -15,7 +15,7 @@ // specific language governing permissions and 
limitations // under the License. -//! Defines physical expressions that can evaluated at runtime during query execution +//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`] use arrow::array::{Array, ArrayRef, AsArray}; use arrow::datatypes::DataType; @@ -23,8 +23,8 @@ use arrow_schema::Field; use datafusion_common::cast::as_list_array; use datafusion_common::utils::array_into_list_array_nullable; -use datafusion_common::Result; use datafusion_common::ScalarValue; +use datafusion_common::{internal_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::AggregateUDFImpl; @@ -126,12 +126,15 @@ impl ArrayAggAccumulator { } impl Accumulator for ArrayAggAccumulator { - // Append value like Int64Array(1,2,3) fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // Append value like Int64Array(1,2,3) if values.is_empty() { return Ok(()); } - assert!(values.len() == 1, "array_agg can only take 1 param!"); + + if values.len() != 1 { + return internal_err!("expects single batch"); + } let val = Arc::clone(&values[0]); if val.len() > 0 { @@ -140,12 +143,15 @@ impl Accumulator for ArrayAggAccumulator { Ok(()) } - // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) if states.is_empty() { return Ok(()); } - assert!(states.len() == 1, "array_agg states must be singleton!"); + + if states.len() != 1 { + return internal_err!("expects single state"); + } let list_arr = as_list_array(&states[0])?; for arr in list_arr.iter().flatten() { @@ -207,7 +213,9 @@ impl Accumulator for DistinctArrayAggAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - assert_eq!(values.len(), 1, "batch input should only include 1 column!"); + if values.len() != 1 { + return internal_err!("expects single batch"); + } let array 
= &values[0]; @@ -224,6 +232,10 @@ impl Accumulator for DistinctArrayAggAccumulator { return Ok(()); } + if states.len() != 1 { + return internal_err!("expects single state"); + } + states[0] .as_list::() .iter() diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 2921eb5c8f96..e9a90fce2663 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -260,7 +260,7 @@ struct AggrFn { fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let aggr_expr = expr.as_any(); - // TODO: remove + // TODO: remove OrderSensitiveArrayAgg let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::ArrayAgg } else if aggr_expr.downcast_ref::().is_some() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 9d3a60eb6115..11945f39589a 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -703,6 +703,7 @@ async fn roundtrip_expr_api() -> Result<()> { bool_and(lit(true)), bool_or(lit(true)), array_agg(lit(1)), + array_agg(lit(1)).distinct().build().unwrap(), ]; // ensure expressions created with the expr api can be round tripped diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a0140b1c5292..1976951b8ce6 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -183,7 +183,7 @@ CREATE TABLE array_agg_distinct_list_table AS VALUES ; # Apply array_sort to have deterministic result, higher dimension nested array also works but not for array sort, -# so they are covered in `datafusion/physical-expr/src/aggregate/array_agg_distinct.rs` +# so they are covered in `datafusion/functions-aggregate/src/array_agg.rs` query ?? 
select array_sort(c1), array_sort(c2) from ( select array_agg(distinct column1) as c1, array_agg(distinct column2) as c2 from array_agg_distinct_list_table