diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 15d325288b07..aead7495830b 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -741,9 +741,54 @@ impl DefaultPhysicalPlanner { }) .collect::>>()?; - let (aggregates, filters, _order_bys): (Vec<_>, Vec<_>, Vec<_>) = + let (mut aggregates, filters, _order_bys): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(agg_filter); + let mut async_exprs = Vec::new(); + let num_input_columns = physical_input_schema.fields().len(); + + for agg_func in &mut aggregates { + match self.try_plan_async_exprs( + num_input_columns, + PlannedExprResult::Expr(agg_func.expressions()), + physical_input_schema.as_ref(), + )? { + PlanAsyncExpr::Async( + async_map, + PlannedExprResult::Expr(physical_exprs), + ) => { + async_exprs.extend(async_map.async_exprs); + + if let Some(new_agg_func) = agg_func.with_new_expressions( + physical_exprs, + agg_func + .order_bys() + .iter() + .cloned() + .map(|x| x.expr) + .collect(), + ) { + *agg_func = Arc::new(new_agg_func); + } else { + return internal_err!("Failed to plan async expression"); + } + } + PlanAsyncExpr::Sync(PlannedExprResult::Expr(_)) => { + // Do nothing + } + _ => { + return internal_err!( + "Unexpected result from try_plan_async_exprs" + ) + } + } + } + let input_exec = if !async_exprs.is_empty() { + Arc::new(AsyncFuncExec::try_new(async_exprs, input_exec)?) + } else { + input_exec + }; + let initial_aggr = Arc::new(AggregateExec::try_new( AggregateMode::Partial, groups.clone(), @@ -2272,11 +2317,13 @@ impl DefaultPhysicalPlanner { } } +#[derive(Debug)] enum PlannedExprResult { ExprWithName(Vec<(Arc, String)>), Expr(Vec>), } +#[derive(Debug)] enum PlanAsyncExpr { Sync(PlannedExprResult), Async(AsyncMapper, PlannedExprResult), diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 143e3ef1a89b..b499401e5589 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; use std::collections::HashMap; use std::fs::File; use std::io::Write; @@ -31,8 +32,13 @@ use arrow::record_batch::RecordBatch; use datafusion::catalog::{ CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, Session, }; -use datafusion::common::DataFusionError; -use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarUDF, Volatility}; +use datafusion::common::{not_impl_err, DataFusionError, Result}; +use datafusion::functions::math::abs; +use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; +use datafusion::logical_expr::{ + create_udf, ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, +}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::*; use datafusion::{ @@ -133,6 +139,10 @@ impl TestContext { info!("Registering table with union column"); register_union_table(test_ctx.session_ctx()) } + "async_udf.slt" => { + info!("Registering dummy async udf"); + register_async_abs_udf(test_ctx.session_ctx()) + } _ => { info!("Using default SessionContext"); } @@ -235,7 +245,7 @@ pub async fn register_temp_table(ctx: &SessionContext) { #[async_trait] impl TableProvider for TestTable { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -458,3 +468,48 @@ fn register_union_table(ctx: &SessionContext) { ctx.register_batch("union_table", batch).unwrap(); } + +fn register_async_abs_udf(ctx: &SessionContext) { + #[derive(Debug, PartialEq, Eq, Hash)] + struct AsyncAbs { + inner_abs: Arc, + } + impl AsyncAbs { + fn new() -> Self { + AsyncAbs { inner_abs: abs() } + } + } + impl ScalarUDFImpl for AsyncAbs { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "async_abs" + } + + fn signature(&self) -> &Signature { + self.inner_abs.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + self.inner_abs.return_type(arg_types) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + not_impl_err!("{} can only be called from async contexts", self.name()) + } + } + #[async_trait] + impl AsyncScalarUDFImpl for AsyncAbs { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + return self.inner_abs.invoke_with_args(args); + } + } + let async_abs = AsyncAbs::new(); + let udf = AsyncScalarUDF::new(Arc::new(async_abs)); + ctx.register_udf(udf.into_scalar_udf()); +} diff --git a/datafusion/sqllogictest/test_files/async_udf.slt b/datafusion/sqllogictest/test_files/async_udf.slt new file mode 100644 index 000000000000..c61d02cfecfd --- /dev/null +++ b/datafusion/sqllogictest/test_files/async_udf.slt @@ -0,0 +1,107 @@ + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +create table data(x int) as values (-10), (2); + +# Async udf can be used in aggregation +query I +select min(async_abs(x)) from data; +---- +2 + +query TT +explain select min(async_abs(x)) from data; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[min(async_abs(data.x))]] +02)--TableScan: data projection=[x] +physical_plan +01)AggregateExec: mode=Final, gby=[], aggr=[min(async_abs(data.x))] +02)--CoalescePartitionsExec +03)----AggregateExec: mode=Partial, gby=[], aggr=[min(async_abs(data.x))] +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------DataSourceExec: partitions=1, partition_sizes=[1] + +# Async udf can be used in aggregation with group by +query I rowsort +select min(async_abs(x)) from data group by async_abs(x); +---- +10 +2 + +query TT +explain select min(async_abs(x)) from data group by async_abs(x); +---- +logical_plan +01)Projection: min(async_abs(data.x)) +02)--Aggregate: groupBy=[[__common_expr_1 AS async_abs(data.x)]], aggr=[[min(__common_expr_1 AS async_abs(data.x))]] +03)----Projection: async_abs(data.x) AS __common_expr_1 +04)------TableScan: data projection=[x] +physical_plan +01)ProjectionExec: expr=[min(async_abs(data.x))@1 as min(async_abs(data.x))] +02)--AggregateExec: mode=FinalPartitioned, gby=[async_abs(data.x)@0 as async_abs(data.x)], aggr=[min(async_abs(data.x))] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------RepartitionExec: partitioning=Hash([async_abs(data.x)@0], 4), input_partitions=4 +05)--------AggregateExec: mode=Partial, gby=[__common_expr_1@0 as async_abs(data.x)], aggr=[min(async_abs(data.x))] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------ProjectionExec: expr=[__async_fn_0@1 as __common_expr_1] +08)--------------AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))] +09)----------------CoalesceBatchesExec: target_batch_size=8192 +10)------------------DataSourceExec: partitions=1, partition_sizes=[1] + +# Async udf can be used in filter +query I +select * from data where async_abs(x) < 5; +---- +2 + +query TT +explain select * from data where async_abs(x) < 5; +---- +logical_plan +01)Filter: async_abs(data.x) < Int32(5) +02)--TableScan: data projection=[x] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: __async_fn_0@1 < 5, projection=[x@0] +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +# Async udf can be used in projection +query I rowsort +select async_abs(x) from data; +---- +10 +2 + +query TT +explain select async_abs(x) from data; +---- +logical_plan +01)Projection: async_abs(data.x) +02)--TableScan: data projection=[x] +physical_plan +01)ProjectionExec: expr=[__async_fn_0@1 as async_abs(data.x)] +02)--AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------DataSourceExec: partitions=1, partition_sizes=[1]