Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -741,9 +741,54 @@ impl DefaultPhysicalPlanner {
})
.collect::<Result<Vec<_>>>()?;

let (aggregates, filters, _order_bys): (Vec<_>, Vec<_>, Vec<_>) =
let (mut aggregates, filters, _order_bys): (Vec<_>, Vec<_>, Vec<_>) =
multiunzip(agg_filter);

let mut async_exprs = Vec::new();
let num_input_columns = physical_input_schema.fields().len();

for agg_func in &mut aggregates {
match self.try_plan_async_exprs(
num_input_columns,
PlannedExprResult::Expr(agg_func.expressions()),
physical_input_schema.as_ref(),
)? {
PlanAsyncExpr::Async(
async_map,
PlannedExprResult::Expr(physical_exprs),
) => {
async_exprs.extend(async_map.async_exprs);

if let Some(new_agg_func) = agg_func.with_new_expressions(
physical_exprs,
agg_func
.order_bys()
.iter()
.cloned()
.map(|x| x.expr)
.collect(),
) {
*agg_func = Arc::new(new_agg_func);
} else {
return internal_err!("Failed to plan async expression");
}
}
PlanAsyncExpr::Sync(PlannedExprResult::Expr(_)) => {
// Do nothing
}
_ => {
return internal_err!(
"Unexpected result from try_plan_async_exprs"
)
}
}
}
let input_exec = if !async_exprs.is_empty() {
Arc::new(AsyncFuncExec::try_new(async_exprs, input_exec)?)
} else {
input_exec
};

let initial_aggr = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
groups.clone(),
Expand Down Expand Up @@ -2272,11 +2317,13 @@ impl DefaultPhysicalPlanner {
}
}

/// Physical expressions passed into, and handed back from, async-expression
/// planning (`try_plan_async_exprs`), in one of two shapes.
#[derive(Debug)]
enum PlannedExprResult {
/// Physical expressions, each paired with a `String`
/// (presumably the output column name — confirm against the callers).
ExprWithName(Vec<(Arc<dyn PhysicalExpr>, String)>),
/// Bare physical expressions with no accompanying string.
Expr(Vec<Arc<dyn PhysicalExpr>>),
}

#[derive(Debug)]
enum PlanAsyncExpr {
Sync(PlannedExprResult),
Async(AsyncMapper, PlannedExprResult),
Expand Down
61 changes: 58 additions & 3 deletions datafusion/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::any::Any;
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
Expand All @@ -31,8 +32,13 @@ use arrow::record_batch::RecordBatch;
use datafusion::catalog::{
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, Session,
};
use datafusion::common::DataFusionError;
use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarUDF, Volatility};
use datafusion::common::{not_impl_err, DataFusionError, Result};
use datafusion::functions::math::abs;
use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
use datafusion::logical_expr::{
create_udf, ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
Signature, Volatility,
};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::*;
use datafusion::{
Expand Down Expand Up @@ -133,6 +139,10 @@ impl TestContext {
info!("Registering table with union column");
register_union_table(test_ctx.session_ctx())
}
"async_udf.slt" => {
info!("Registering dummy async udf");
register_async_abs_udf(test_ctx.session_ctx())
}
_ => {
info!("Using default SessionContext");
}
Expand Down Expand Up @@ -235,7 +245,7 @@ pub async fn register_temp_table(ctx: &SessionContext) {

#[async_trait]
impl TableProvider for TestTable {
fn as_any(&self) -> &dyn std::any::Any {
fn as_any(&self) -> &dyn Any {
self
}

Expand Down Expand Up @@ -458,3 +468,48 @@ fn register_union_table(ctx: &SessionContext) {

ctx.register_batch("union_table", batch).unwrap();
}

/// Registers a scalar UDF named `async_abs` on `ctx`.
///
/// The UDF wraps the built-in `abs` function: its signature and return type
/// are taken from `abs`, and the actual computation is delegated to `abs`
/// from the async invocation entry point. Calling it through the synchronous
/// `invoke_with_args` entry point yields a "not implemented" error.
fn register_async_abs_udf(ctx: &SessionContext) {
    /// Async wrapper around the built-in `abs` scalar UDF.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct AsyncAbs {
        /// The wrapped `abs` implementation that does the actual work.
        inner_abs: Arc<ScalarUDF>,
    }

    impl AsyncAbs {
        /// Creates a wrapper around the built-in `abs` function.
        fn new() -> Self {
            AsyncAbs { inner_abs: abs() }
        }
    }

    impl ScalarUDFImpl for AsyncAbs {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "async_abs"
        }

        fn signature(&self) -> &Signature {
            self.inner_abs.signature()
        }

        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
            self.inner_abs.return_type(arg_types)
        }

        // The synchronous entry point is deliberately unsupported; evaluation
        // goes through `invoke_async_with_args` below.
        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            not_impl_err!("{} can only be called from async contexts", self.name())
        }
    }

    #[async_trait]
    impl AsyncScalarUDFImpl for AsyncAbs {
        /// Evaluates the function by delegating to the wrapped `abs`.
        async fn invoke_async_with_args(
            &self,
            args: ScalarFunctionArgs,
        ) -> Result<ColumnarValue> {
            self.inner_abs.invoke_with_args(args)
        }
    }

    let async_abs = AsyncAbs::new();
    let udf = AsyncScalarUDF::new(Arc::new(async_abs));
    ctx.register_udf(udf.into_scalar_udf());
}
107 changes: 107 additions & 0 deletions datafusion/sqllogictest/test_files/async_udf.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

statement ok
create table data(x int) as values (-10), (2);

# Async udf can be used in aggregation
query I
select min(async_abs(x)) from data;
----
2

query TT
explain select min(async_abs(x)) from data;
----
logical_plan
01)Aggregate: groupBy=[[]], aggr=[[min(async_abs(data.x))]]
02)--TableScan: data projection=[x]
physical_plan
01)AggregateExec: mode=Final, gby=[], aggr=[min(async_abs(data.x))]
02)--CoalescePartitionsExec
03)----AggregateExec: mode=Partial, gby=[], aggr=[min(async_abs(data.x))]
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
05)--------AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))]
06)----------CoalesceBatchesExec: target_batch_size=8192
07)------------DataSourceExec: partitions=1, partition_sizes=[1]

# Async udf can be used in aggregation with group by
query I rowsort
select min(async_abs(x)) from data group by async_abs(x);
----
10
2

query TT
explain select min(async_abs(x)) from data group by async_abs(x);
----
logical_plan
01)Projection: min(async_abs(data.x))
02)--Aggregate: groupBy=[[__common_expr_1 AS async_abs(data.x)]], aggr=[[min(__common_expr_1 AS async_abs(data.x))]]
03)----Projection: async_abs(data.x) AS __common_expr_1
04)------TableScan: data projection=[x]
physical_plan
01)ProjectionExec: expr=[min(async_abs(data.x))@1 as min(async_abs(data.x))]
02)--AggregateExec: mode=FinalPartitioned, gby=[async_abs(data.x)@0 as async_abs(data.x)], aggr=[min(async_abs(data.x))]
03)----CoalesceBatchesExec: target_batch_size=8192
04)------RepartitionExec: partitioning=Hash([async_abs(data.x)@0], 4), input_partitions=4
05)--------AggregateExec: mode=Partial, gby=[__common_expr_1@0 as async_abs(data.x)], aggr=[min(async_abs(data.x))]
06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
07)------------ProjectionExec: expr=[__async_fn_0@1 as __common_expr_1]
08)--------------AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))]
09)----------------CoalesceBatchesExec: target_batch_size=8192
10)------------------DataSourceExec: partitions=1, partition_sizes=[1]

# Async udf can be used in filter
query I
select * from data where async_abs(x) < 5;
----
2

query TT
explain select * from data where async_abs(x) < 5;
----
logical_plan
01)Filter: async_abs(data.x) < Int32(5)
02)--TableScan: data projection=[x]
physical_plan
01)CoalesceBatchesExec: target_batch_size=8192
02)--FilterExec: __async_fn_0@1 < 5, projection=[x@0]
03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
04)------AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))]
05)--------CoalesceBatchesExec: target_batch_size=8192
06)----------DataSourceExec: partitions=1, partition_sizes=[1]

# Async udf can be used in projection
query I rowsort
select async_abs(x) from data;
----
10
2

query TT
explain select async_abs(x) from data;
----
logical_plan
01)Projection: async_abs(data.x)
02)--TableScan: data projection=[x]
physical_plan
01)ProjectionExec: expr=[__async_fn_0@1 as async_abs(data.x)]
02)--AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))]
03)----CoalesceBatchesExec: target_batch_size=8192
04)------DataSourceExec: partitions=1, partition_sizes=[1]