From 65c99afcd97b3e2cb625e83e29279755ef77d238 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 23 Feb 2025 16:38:14 +0800 Subject: [PATCH 01/20] introduce async udf for projection --- Cargo.lock | 1 + datafusion/core/src/physical_planner.rs | 33 +- datafusion/expr/Cargo.toml | 1 + datafusion/expr/src/async_udf.rs | 136 ++++++++ datafusion/expr/src/lib.rs | 1 + datafusion/expr/src/udf.rs | 8 + .../src/async_scalar_function.rs | 226 ++++++++++++++ datafusion/physical-expr/src/lib.rs | 1 + datafusion/physical-plan/src/async_func.rs | 291 ++++++++++++++++++ datafusion/physical-plan/src/lib.rs | 1 + 10 files changed, 695 insertions(+), 4 deletions(-) create mode 100644 datafusion/expr/src/async_udf.rs create mode 100644 datafusion/physical-expr/src/async_scalar_function.rs create mode 100644 datafusion/physical-plan/src/async_func.rs diff --git a/Cargo.lock b/Cargo.lock index 7147548b4b2e..2ce381711cd6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2215,6 +2215,7 @@ name = "datafusion-expr" version = "48.0.0" dependencies = [ "arrow", + "async-trait", "chrono", "ctor", "datafusion-common", diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index c1121d59bb3f..74d550c194f1 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -96,6 +96,8 @@ use datafusion_physical_plan::unnest::ListUnnest; use sqlparser::ast::NullTreatment; use async_trait::async_trait; +use datafusion_datasource::file_groups::FileGroup; +use datafusion_physical_plan::async_func::{AsyncFuncExec, AsyncMapper}; use futures::{StreamExt, TryStreamExt}; use itertools::{multiunzip, Itertools}; use log::{debug, trace}; @@ -2044,10 +2046,33 @@ impl DefaultPhysicalPlanner { }) .collect::>>()?; - Ok(Arc::new(ProjectionExec::try_new( - physical_exprs, - input_exec, - )?)) + let num_input_columns = input_exec.schema().fields().len(); + let mut async_map = AsyncMapper::new(num_input_columns); + physical_exprs + .iter() + .try_for_each(|(expr, _column_name)| async_map.find_references(expr))?; + + // If there are no async expressions, we can create a ProjectionExec + if async_map.is_empty() { + return Ok(Arc::new(ProjectionExec::try_new( + physical_exprs, + input_exec, + )?)); + } + + // rewrite the projection's expressions in terms of the columns with the result of async evaluation + let new_exprs = physical_exprs + .iter() + .map(|(expr, column_name)| { + let new_expr = + Arc::clone(expr).transform_up(|e| Ok(async_map.map_expr(e)))?; + Ok((new_expr.data, column_name.to_string())) + }) + .collect::>()?; + + let async_exec = AsyncFuncExec::try_new(async_map.async_exprs, input_exec)?; + let new_proj_exec = ProjectionExec::try_new(new_exprs, Arc::new(async_exec))?; + Ok(Arc::new(new_proj_exec)) } } diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index d77c59ff64e1..766b412bbe96 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -54,6 +54,7 @@ paste = "^1.0" recursive = { workspace = true, optional = true } serde_json = { workspace = true } sqlparser = { workspace = true } +async-trait = "0.1.86" [dev-dependencies] ctor = { workspace = true } diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs new file mode 100644 index 000000000000..13f95cd1d82f --- /dev/null +++ b/datafusion/expr/src/async_udf.rs @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, SchemaRef}; +use async_trait::async_trait; +use datafusion_common::config::ConfigOptions; +use datafusion_common::error::Result; +use datafusion_common::internal_err; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::signature::Signature; +use std::any::Any; +use std::fmt::{Debug, Display}; +use std::sync::Arc; + +/// A scalar UDF that can invoke using async methods +/// +/// Note this is less efficient than the ScalarUDFImpl, but it can be used +/// to register remote functions in the context. +/// +/// The name is chosen to mirror ScalarUDFImpl +#[async_trait] +pub trait AsyncScalarUDFImpl: Debug + Send + Sync { + /// the function cast as any + fn as_any(&self) -> &dyn Any; + + /// The name of the function + fn name(&self) -> &str; + + /// The signature of the function + fn signature(&self) -> &Signature; + + /// The return type of the function + fn return_type(&self, _arg_types: &[DataType]) -> Result; + + /// The ideal batch size for this function. + fn ideal_batch_size(&self) -> Option { + None + } + + /// Invoke the function asynchronously with the async arguments + async fn invoke_async_with_args( + &self, + args: AsyncScalarFunctionArgs, + option: &ConfigOptions, + ) -> Result; +} + +/// A scalar UDF that must be invoked using async methods +/// +/// Note this is not meant to be used directly, but is meant to be an implementation detail +/// for AsyncUDFImpl. +/// +/// This is used to register remote functions in the context. The function +/// should not be invoked by DataFusion. It's only used to generate the logical +/// plan and unparsed them to SQL. +#[derive(Debug)] +pub struct AsyncScalarUDF { + inner: Arc, +} + +impl AsyncScalarUDF { + pub fn new(inner: Arc) -> Self { + Self { inner } + } + + /// The ideal batch size for this function + pub fn ideal_batch_size(&self) -> Option { + self.inner.ideal_batch_size() + } + + /// Turn this AsyncUDF into a ScalarUDF, suitable for + /// registering in the context + pub fn into_scalar_udf(self) -> Arc { + Arc::new(ScalarUDF::new_from_impl(self)) + } + + /// Invoke the function asynchronously with the async arguments + pub async fn invoke_async_with_args( + &self, + args: AsyncScalarFunctionArgs, + option: &ConfigOptions, + ) -> Result { + self.inner.invoke_async_with_args(args, option).await + } +} + +impl ScalarUDFImpl for AsyncScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.inner.name() + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + self.inner.return_type(_arg_types) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("async functions should not be called directly") + } +} + +impl Display for AsyncScalarUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "AsyncScalarUDF: {}", self.inner.name()) + } +} + +#[derive(Debug)] +pub struct AsyncScalarFunctionArgs { + pub args: Vec, + pub number_rows: usize, + pub schema: SchemaRef, +} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 1f44f755b214..0c822bbb337b 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -63,6 +63,7 @@ pub mod simplify; pub mod sort_properties { pub use datafusion_expr_common::sort_properties::*; } +pub mod async_udf; pub mod statistics { pub use datafusion_expr_common::statistics::*; } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 7139816c10e7..84ae8eecd555 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,6 +17,7 @@ //! [`ScalarUDF`]: Scalar User Defined Functions +use crate::async_udf::AsyncScalarUDF; use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; @@ -280,6 +281,13 @@ impl ScalarUDF { pub fn documentation(&self) -> Option<&Documentation> { self.inner.documentation() } + + /// Return true if this function is an async function + pub fn as_async(&self) -> Option<&AsyncScalarUDF> { + self.inner() + .as_any() + .downcast_ref::() + } } impl From for ScalarUDF diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs new file mode 100644 index 000000000000..fad40f359db4 --- /dev/null +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use crate::ScalarFunctionExpr; +use arrow::array::{make_array, MutableArrayData, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::Result; +use datafusion_common::{internal_err, not_impl_err}; +use datafusion_expr::async_udf::{AsyncScalarFunctionArgs, AsyncScalarUDF}; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::fmt::Display; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +/// Wrapper for a Async function that can be used in a DataFusion query +#[derive(Debug, Clone, Eq)] +pub struct AsyncFuncExpr { + /// The name of the output column this function will generate + pub name: String, + /// The actual function (always `ScalarFunctionExpr`) + pub func: Arc, +} + +impl Display for AsyncFuncExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "async_expr(name={}, expr={})", self.name, self.func) + } +} + +impl PartialEq for AsyncFuncExpr { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.func == Arc::clone(&other.func) + } +} + +impl Hash for AsyncFuncExpr { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.func.as_ref().hash(state); + } +} + +impl AsyncFuncExpr { + /// create a new AsyncFuncExpr + pub fn try_new(name: impl Into, func: Arc) -> Result { + + let Some(_) = func.as_any().downcast_ref::() else { + return internal_err!( + "unexpected function type, expected ScalarFunctionExpr, got: {:?}", + func + ); + }; + + + Ok(Self { + name: name.into(), + func, + }) + } + + /// return the name of the output column + pub fn name(&self) -> &str { + &self.name + } + + /// Return the output field generated by evaluating this function + pub fn field(&self, input_schema: &Schema) -> Result { + Ok(Field::new( + &self.name, + self.func.data_type(input_schema)?, + self.func.nullable(input_schema)?, + )) + } + + /// Return the ideal batch size for this function + pub fn ideal_batch_size(&self) -> Result> { + if let Some(expr) = self.func.as_any().downcast_ref::() { + if let Some(udf) = + expr.fun().inner().as_any().downcast_ref::() + { + return Ok(udf.ideal_batch_size()); + } + } + not_impl_err!("Can't get ideal_batch_size from {:?}", self.func) + } + + /// This (async) function is called for each record batch to evaluate the LLM expressions + /// + /// The output is the output of evaluating the async expression and the input record batch + pub async fn invoke_with_args( + &self, + batch: &RecordBatch, + option: &ConfigOptions, + ) -> Result { + let Some(llm_function) = self.func.as_any().downcast_ref::() + else { + return internal_err!( + "unexpected function type, expected ScalarFunctionExpr, got: {:?}", + self.func + ); + }; + + let Some(async_udf) = llm_function + .fun() + .inner() + .as_any() + .downcast_ref::() + else { + return not_impl_err!( + "Don't know how to evaluate async function: {:?}", + llm_function + ); + }; + + let mut result_batches = vec![]; + if let Some(ideal_batch_size) = self.ideal_batch_size()? { + let mut remainder = batch.clone(); + while remainder.num_rows() > 0 { + let size = if ideal_batch_size > remainder.num_rows() { + remainder.num_rows() + } else { + ideal_batch_size + }; + + let current_batch = remainder.slice(0, size); // get next 10 rows + remainder = remainder.slice(size, remainder.num_rows() - size); + let args = llm_function + .args() + .iter() + .map(|e| e.evaluate(¤t_batch)) + .collect::>>()?; + result_batches.push( + async_udf + .invoke_async_with_args( + AsyncScalarFunctionArgs { + args: args.to_vec(), + number_rows: current_batch.num_rows(), + schema: current_batch.schema(), + }, + option, + ) + .await?, + ); + } + } else { + let args = llm_function + .args() + .iter() + .map(|e| e.evaluate(batch)) + .collect::>>()?; + + result_batches.push( + async_udf + .invoke_async_with_args( + AsyncScalarFunctionArgs { + args: args.to_vec(), + number_rows: batch.num_rows(), + schema: batch.schema(), + }, + option, + ) + .await?, + ); + } + + let datas = result_batches + .iter() + .map(|b| b.to_data()) + .collect::>(); + let total_len = datas.iter().map(|d| d.len()).sum(); + let mut mutable = MutableArrayData::new(datas.iter().collect(), false, total_len); + datas.iter().enumerate().for_each(|(i, data)| { + mutable.extend(i, 0, data.len()); + }); + let array_ref = make_array(mutable.freeze()); + Ok(ColumnarValue::Array(array_ref)) + } +} + +impl PhysicalExpr for AsyncFuncExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + self.func.data_type(_input_schema) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + self.func.nullable(_input_schema) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + // TODO: implement this + not_impl_err!("AsyncFuncExpr.evaluate") + } + + fn children(&self) -> Vec<&Arc> { + self.func.children() + } + + fn with_new_children(self: Arc, children: Vec>) -> Result> { + let new_func = Arc::clone(&self.func).with_new_children(children)?; + Ok(Arc::new(AsyncFuncExpr { + name: self.name.clone(), + func: new_func, + })) + } +} diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 6741f94c9545..be60e26cc2d2 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -30,6 +30,7 @@ pub mod analysis; pub mod binary_map { pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; } +pub mod async_scalar_function; pub mod equivalence; pub mod expressions; pub mod intervals; diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs new file mode 100644 index 000000000000..6363d4b3ee16 --- /dev/null +++ b/datafusion/physical-plan/src/async_func.rs @@ -0,0 +1,291 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use crate::stream::RecordBatchStreamAdapter; +use crate::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, +}; +use arrow::array::RecordBatch; +use arrow_schema::{Fields, Schema, SchemaRef}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::Result; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; +use datafusion_physical_expr::equivalence::ProjectionMapping; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::ScalarFunctionExpr; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use futures::stream::StreamExt; +use log::trace; +use std::any::Any; +use std::sync::Arc; + +/// This structure evaluates a set of async expressions on a record +/// batch producing a new record batch +/// +/// This is similar to a ProjectionExec except that the functions can be async +/// +/// The schema of the output of the AsyncFuncExec is: +/// Input columns followed by one column for each async expression +#[derive(Debug)] +pub struct AsyncFuncExec { + /// The async expressions to evaluate + async_exprs: Vec>, + input: Arc, + /// Cache holding plan properties like equivalences, output partitioning etc. + cache: PlanProperties, + metrics: ExecutionPlanMetricsSet, +} + +impl AsyncFuncExec { + pub fn try_new( + async_exprs: Vec>, + input: Arc, + ) -> Result { + + let async_fields = async_exprs.iter().map(|async_expr| { + async_expr.field(input.schema().as_ref()) + }).collect::>>()?; + + // compute the output schema: input schema then async expressions + let fields: Fields = + input + .schema() + .fields() + .iter() + .cloned() + .chain(async_fields.into_iter().map(Arc::new)) + .collect(); + + let schema = Arc::new(Schema::new(fields)); + let tuples = async_exprs + .iter() + .map(|expr| (Arc::clone(&expr.func), expr.name().to_string())) + .collect::>(); + let async_expr_mapping = ProjectionMapping::try_new(&tuples, &input.schema())?; + let cache = + AsyncFuncExec::compute_properties(&input, schema, &async_expr_mapping)?; + Ok(Self { + input, + async_exprs, + cache, + metrics: ExecutionPlanMetricsSet::new(), + }) + } + + /// This function creates the cache object that stores the plan properties + /// such as schema, equivalence properties, ordering, partitioning, etc. + fn compute_properties( + input: &Arc, + schema: SchemaRef, + async_expr_mapping: &ProjectionMapping, + ) -> Result { + Ok(PlanProperties::new( + input + .equivalence_properties() + .project(async_expr_mapping, schema), + input.output_partitioning().clone(), + input.pipeline_behavior(), + input.boundedness(), + )) + } +} + +impl DisplayAs for AsyncFuncExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let expr: Vec = self + .async_exprs + .iter() + .map(|async_expr| async_expr.to_string()) + .collect(); + + write!(f, "AsyncFuncExec: async_expr=[{}]", expr.join(", ")) + } + } + } +} + +impl ExecutionPlan for AsyncFuncExec { + fn name(&self) -> &str { + "async_func" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(Arc::new(AsyncFuncExec::try_new( + self.async_exprs.clone(), + Arc::clone(&self.input), + )?)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + trace!( + "Start AsyncFuncExpr::execute for partition {} of context session_id {} and task_id {:?}", + partition, + context.session_id(), + context.task_id() + ); + // TODO figure out how to record metrics + + // first execute the input stream + let input_stream = self.input.execute(partition, Arc::clone(&context))?; + + // now, for each record batch, evaluate the async expressions and add the columns to the result + let async_exprs_captured = Arc::new(self.async_exprs.clone()); + let schema_captured = self.schema(); + let config_option_ref = Arc::new(context.session_config().options().clone()); + + let stream_with_async_functions = input_stream.then(move |batch| { + // need to clone *again* to capture the async_exprs and schema in the + // stream and satisfy lifetime requirements. + let async_exprs_captured = Arc::clone(&async_exprs_captured); + let schema_captured = Arc::clone(&schema_captured); + let config_option = Arc::clone(&config_option_ref); + + async move { + let batch = batch?; + // append the result of evaluating the async expressions to the output + let mut output_arrays = batch.columns().to_vec(); + for async_expr in async_exprs_captured.iter() { + let output = + async_expr.invoke_with_args(&batch, &config_option).await?; + output_arrays.push(output.to_array(batch.num_rows())?); + } + let batch = RecordBatch::try_new(schema_captured, output_arrays)?; + Ok(batch) + } + }); + + // Adapt the stream with the output schema + let adapter = + RecordBatchStreamAdapter::new(self.schema(), stream_with_async_functions); + Ok(Box::pin(adapter)) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} + +const ASYNC_FN_PREFIX: &str = "__async_fn_"; + +/// Maps async_expressions to new columns +/// +/// The output of the async functions are appended, in order, to the end of the input schema +#[derive(Debug)] +pub struct AsyncMapper { + /// the number of columns in the input plan + /// used to generate the output column names. + /// the first async expr is `__async_fn_0`, the second is `__async_fn_1`, etc + num_input_columns: usize, + /// the expressions to map + pub async_exprs: Vec>, +} + +impl AsyncMapper { + pub fn new(num_input_columns: usize) -> Self { + Self { + num_input_columns, + async_exprs: Vec::new(), + } + } + + pub fn is_empty(&self) -> bool { + self.async_exprs.is_empty() + } + + pub fn next_column_name(&self) -> String { + format!("{}{}", ASYNC_FN_PREFIX, self.async_exprs.len()) + } + + /// Finds any references to async functions in the expression and adds them to the map + pub fn find_references( + &mut self, + physical_expr: &Arc, + ) -> Result<()> { + // recursively look for references to async functions + physical_expr.apply(|expr| { + if let Some(scalar_func_expr) = expr.as_any().downcast_ref::() { + if scalar_func_expr.fun().as_async().is_some() { + let next_name = self.next_column_name(); + self.async_exprs + .push(Arc::new(AsyncFuncExpr::try_new(next_name, Arc::clone(expr))?)); + } + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(()) + } + + /// If the expression matches any of the async functions, return the new column + pub fn map_expr( + &self, + expr: Arc, + ) -> Transformed> { + // find the first matching async function if any + let Some(idx) = + self.async_exprs + .iter() + .enumerate() + .find_map( + |(idx, async_expr)| { + if async_expr.func == Arc::clone(&expr) { + Some(idx) + } else { + None + } + }, + ) + else { + return Transformed::no(expr); + }; + // rewrite in terms of the output column + Transformed::yes(self.output_column(idx)) + } + + /// return the output column for the async function at index idx + pub fn output_column(&self, idx: usize) -> Arc { + let async_expr = &self.async_exprs[idx]; + let output_idx = self.num_input_columns + idx; + Arc::new(Column::new(async_expr.name(), output_idx)) + } +} diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 22ae859e8c5b..a9cca2cbc2ac 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -60,6 +60,7 @@ mod visitor; pub mod aggregates; pub mod analyze; pub mod coalesce; +pub mod async_func; pub mod coalesce_batches; pub mod coalesce_partitions; pub mod common; From 0902e8ecfc90efc2925474b97d1f1567eb29dccd Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 23 Feb 2025 17:56:14 +0800 Subject: [PATCH 02/20] refactor for filter --- datafusion/core/src/physical_planner.rs | 122 +++++++++++++++--- datafusion/expr/src/udf.rs | 4 +- .../src/async_scalar_function.rs | 11 +- datafusion/physical-plan/src/async_func.rs | 49 +++---- 4 files changed, 133 insertions(+), 53 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 74d550c194f1..37896b28c391 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -779,12 +779,39 @@ impl DefaultPhysicalPlanner { let runtime_expr = self.create_physical_expr(predicate, input_dfschema, session_state)?; + + let filter = match self.try_plan_async_exprs( + input.schema().fields().len(), + PlannedExprResult::Expr(vec![runtime_expr]), + )? { + PlanAsyncExpr::Sync(PlannedExprResult::Expr(runtime_expr)) => { + FilterExec::try_new(Arc::clone(&runtime_expr[0]), physical_input)? + } + PlanAsyncExpr::Async( + async_map, + PlannedExprResult::Expr(runtime_expr), + ) => { + let async_exec = AsyncFuncExec::try_new( + async_map.async_exprs, + physical_input, + )?; + FilterExec::try_new( + Arc::clone(&runtime_expr[0]), + Arc::new(async_exec), + )? + } + _ => { + return internal_err!( + "Unexpected result from try_plan_async_exprs" + ) + } + }; + let selectivity = session_state .config() .options() .optimizer .default_filter_selectivity; - let filter = FilterExec::try_new(runtime_expr, physical_input)?; Arc::new(filter.with_default_selectivity(selectivity)?) } LogicalPlan::Repartition(Repartition { @@ -2047,35 +2074,88 @@ impl DefaultPhysicalPlanner { .collect::>>()?; let num_input_columns = input_exec.schema().fields().len(); + + match self.try_plan_async_exprs( + num_input_columns, + PlannedExprResult::ExprWithName(physical_exprs), + )? { + PlanAsyncExpr::Sync(PlannedExprResult::ExprWithName(physical_exprs)) => Ok( + Arc::new(ProjectionExec::try_new(physical_exprs, input_exec)?), + ), + PlanAsyncExpr::Async( + async_map, + PlannedExprResult::ExprWithName(physical_exprs), + ) => { + let async_exec = + AsyncFuncExec::try_new(async_map.async_exprs, input_exec)?; + let new_proj_exec = + ProjectionExec::try_new(physical_exprs, Arc::new(async_exec))?; + Ok(Arc::new(new_proj_exec)) + } + _ => internal_err!("Unexpected PlanAsyncExpressions variant"), + } + } + + fn try_plan_async_exprs( + &self, + num_input_columns: usize, + physical_expr: PlannedExprResult, + ) -> Result { let mut async_map = AsyncMapper::new(num_input_columns); - physical_exprs - .iter() - .try_for_each(|(expr, _column_name)| async_map.find_references(expr))?; + match &physical_expr { + PlannedExprResult::ExprWithName(exprs) => { + exprs + .iter() + .try_for_each(|(expr, _)| async_map.find_references(expr))?; + } + PlannedExprResult::Expr(exprs) => { + exprs + .iter() + .try_for_each(|expr| async_map.find_references(expr))?; + } + } - // If there are no async expressions, we can create a ProjectionExec if async_map.is_empty() { - return Ok(Arc::new(ProjectionExec::try_new( - physical_exprs, - input_exec, - )?)); + return Ok(PlanAsyncExpr::Sync(physical_expr)); } + let new_exprs = match physical_expr { + PlannedExprResult::ExprWithName(exprs) => PlannedExprResult::ExprWithName( + exprs + .iter() + .map(|(expr, column_name)| { + let new_expr = Arc::clone(expr) + .transform_up(|e| Ok(async_map.map_expr(e)))?; + Ok((new_expr.data, column_name.to_string())) + }) + .collect::>()?, + ), + PlannedExprResult::Expr(exprs) => PlannedExprResult::Expr( + exprs + .iter() + .map(|expr| { + let new_expr = Arc::clone(expr) + .transform_up(|e| Ok(async_map.map_expr(e)))?; + Ok(new_expr.data) + }) + .collect::>()?, + ), + }; // rewrite the projection's expressions in terms of the columns with the result of async evaluation - let new_exprs = physical_exprs - .iter() - .map(|(expr, column_name)| { - let new_expr = - Arc::clone(expr).transform_up(|e| Ok(async_map.map_expr(e)))?; - Ok((new_expr.data, column_name.to_string())) - }) - .collect::>()?; - - let async_exec = AsyncFuncExec::try_new(async_map.async_exprs, input_exec)?; - let new_proj_exec = ProjectionExec::try_new(new_exprs, Arc::new(async_exec))?; - Ok(Arc::new(new_proj_exec)) + Ok(PlanAsyncExpr::Async(async_map, new_exprs)) } } +enum PlannedExprResult { + ExprWithName(Vec<(Arc, String)>), + Expr(Vec>), +} + +enum PlanAsyncExpr { + Sync(PlannedExprResult), + Async(AsyncMapper, PlannedExprResult), +} + fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { match value { (Ok(e), Ok(e1)) => Ok((e, e1)), diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 84ae8eecd555..8b6ffba04ff6 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -284,9 +284,7 @@ impl ScalarUDF { /// Return true if this function is an async function pub fn as_async(&self) -> Option<&AsyncScalarUDF> { - self.inner() - .as_any() - .downcast_ref::() + self.inner().as_any().downcast_ref::() } } diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs index fad40f359db4..609799c4f9bf 100644 --- a/datafusion/physical-expr/src/async_scalar_function.rs +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use crate::ScalarFunctionExpr; use arrow::array::{make_array, MutableArrayData, RecordBatch}; use arrow::datatypes::{DataType, Field, Schema}; @@ -25,6 +24,7 @@ use datafusion_common::{internal_err, not_impl_err}; use datafusion_expr::async_udf::{AsyncScalarFunctionArgs, AsyncScalarUDF}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::any::Any; use std::fmt::Display; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -60,15 +60,13 @@ impl Hash for AsyncFuncExpr { impl AsyncFuncExpr { /// create a new AsyncFuncExpr pub fn try_new(name: impl Into, func: Arc) -> Result { - - let Some(_) = func.as_any().downcast_ref::() else { + let Some(_) = func.as_any().downcast_ref::() else { return internal_err!( "unexpected function type, expected ScalarFunctionExpr, got: {:?}", func ); }; - Ok(Self { name: name.into(), func, @@ -216,7 +214,10 @@ impl PhysicalExpr for AsyncFuncExpr { self.func.children() } - fn with_new_children(self: Arc, children: Vec>) -> Result> { + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { let new_func = Arc::clone(&self.func).with_new_children(children)?; Ok(Arc::new(AsyncFuncExpr { name: self.name.clone(), diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index 6363d4b3ee16..c70cc4b161c9 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -57,20 +57,19 @@ impl AsyncFuncExec { async_exprs: Vec>, input: Arc, ) -> Result { - - let async_fields = async_exprs.iter().map(|async_expr| { - async_expr.field(input.schema().as_ref()) - }).collect::>>()?; + let async_fields = async_exprs + .iter() + .map(|async_expr| async_expr.field(input.schema().as_ref())) + .collect::>>()?; // compute the output schema: input schema then async expressions - let fields: Fields = - input - .schema() - .fields() - .iter() - .cloned() - .chain(async_fields.into_iter().map(Arc::new)) - .collect(); + let fields: Fields = input + .schema() + .fields() + .iter() + .cloned() + .chain(async_fields.into_iter().map(Arc::new)) + .collect(); let schema = Arc::new(Schema::new(fields)); let tuples = async_exprs @@ -244,11 +243,15 @@ impl AsyncMapper { ) -> Result<()> { // recursively look for references to async functions physical_expr.apply(|expr| { - if let Some(scalar_func_expr) = expr.as_any().downcast_ref::() { + if let Some(scalar_func_expr) = + expr.as_any().downcast_ref::() + { if scalar_func_expr.fun().as_async().is_some() { let next_name = self.next_column_name(); - self.async_exprs - .push(Arc::new(AsyncFuncExpr::try_new(next_name, Arc::clone(expr))?)); + self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new( + next_name, + Arc::clone(expr), + )?)); } } Ok(TreeNodeRecursion::Continue) @@ -266,15 +269,13 @@ impl AsyncMapper { self.async_exprs .iter() .enumerate() - .find_map( - |(idx, async_expr)| { - if async_expr.func == Arc::clone(&expr) { - Some(idx) - } else { - None - } - }, - ) + .find_map(|(idx, async_expr)| { + if async_expr.func == Arc::clone(&expr) { + Some(idx) + } else { + None + } + }) else { return Transformed::no(expr); }; From ca00e721aa0050b010932ff87ce6f67b523dcf90 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 23 Feb 2025 20:21:53 +0800 Subject: [PATCH 03/20] coalesce_batches for AsyncFuncExec --- datafusion/physical-optimizer/src/coalesce_batches.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/physical-optimizer/src/coalesce_batches.rs b/datafusion/physical-optimizer/src/coalesce_batches.rs index 5cf2c877c61a..5e7fe94e693d 100644 --- a/datafusion/physical-optimizer/src/coalesce_batches.rs +++ b/datafusion/physical-optimizer/src/coalesce_batches.rs @@ -31,6 +31,7 @@ use datafusion_physical_plan::{ }; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_physical_plan::async_func::AsyncFuncExec; /// Optimizer rule that introduces CoalesceBatchesExec to avoid overhead with small batches that /// are produced by highly selective filters @@ -62,6 +63,7 @@ impl PhysicalOptimizerRule for CoalesceBatches { // See https://github.com/apache/datafusion/issues/139 let wrap_in_coalesce = plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() + || plan_any.downcast_ref::().is_some() // Don't need to add CoalesceBatchesExec after a round robin RepartitionExec || plan_any .downcast_ref::() From e06e4bb2e25d06ee1db75dd03a80752d3fcf3c0b Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 23 Feb 2025 20:35:57 +0800 Subject: [PATCH 04/20] project filter to exclude the filter expression --- datafusion/core/src/physical_planner.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 37896b28c391..70ad782609e4 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -799,6 +799,11 @@ impl DefaultPhysicalPlanner { Arc::clone(&runtime_expr[0]), Arc::new(async_exec), )? + // project the output columns excluding the async functions + // The async functions are always appended to the end of the schema. + .with_projection(Some( + (0..input.schema().fields().len()).collect(), + ))? } _ => { return internal_err!( From f09337ff5311cc3e38e0eaff198d220c88d0c073 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 23 Feb 2025 21:02:39 +0800 Subject: [PATCH 05/20] coalesce the input batch of AsyncFuncExec --- .../src/coalesce_async_exec_input.rs | 71 +++++++++++++++++++ .../src/coalesce_batches.rs | 2 - datafusion/physical-optimizer/src/lib.rs | 1 + .../physical-optimizer/src/optimizer.rs | 2 + datafusion/physical-plan/src/async_func.rs | 9 ++- 5 files changed, 80 insertions(+), 5 deletions(-) create mode 100644 datafusion/physical-optimizer/src/coalesce_async_exec_input.rs diff --git a/datafusion/physical-optimizer/src/coalesce_async_exec_input.rs b/datafusion/physical-optimizer/src/coalesce_async_exec_input.rs new file mode 100644 index 000000000000..0b46c68f2dae --- /dev/null +++ b/datafusion/physical-optimizer/src/coalesce_async_exec_input.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::PhysicalOptimizerRule; +use datafusion_common::config::ConfigOptions; +use datafusion_common::internal_err; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_physical_plan::async_func::AsyncFuncExec; +use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_physical_plan::ExecutionPlan; +use std::sync::Arc; + +/// Optimizer rule that introduces CoalesceAsyncExec to reduce the number of async executions. +#[derive(Default, Debug)] +pub struct CoalesceAsyncExecInput {} + +impl CoalesceAsyncExecInput { + #[allow(missing_docs)] + pub fn new() -> Self { + Self::default() + } +} + +impl PhysicalOptimizerRule for CoalesceAsyncExecInput { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> datafusion_common::Result> { + let target_batch_size = config.execution.batch_size; + plan.transform(|plan| { + if let Some(async_exec) = plan.as_any().downcast_ref::() { + if async_exec.children().len() != 1 { + return internal_err!( + "Expected AsyncFuncExec to have exactly one child" + ); + } + let child = Arc::clone(async_exec.children()[0]); + let coalesce_exec = + Arc::new(CoalesceBatchesExec::new(child, target_batch_size)); + let coalesce_async_exec = plan.with_new_children(vec![coalesce_exec])?; + Ok(Transformed::yes(coalesce_async_exec)) + } else { + Ok(Transformed::no(plan)) + } + }) + .data() + } + + fn name(&self) -> &str { + "coalesce_async_exec_input" + } + + fn schema_check(&self) -> bool { + true + } +} diff --git a/datafusion/physical-optimizer/src/coalesce_batches.rs b/datafusion/physical-optimizer/src/coalesce_batches.rs index 5e7fe94e693d..5cf2c877c61a 100644 --- a/datafusion/physical-optimizer/src/coalesce_batches.rs +++ b/datafusion/physical-optimizer/src/coalesce_batches.rs @@ -31,7 +31,6 @@ use datafusion_physical_plan::{ }; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_physical_plan::async_func::AsyncFuncExec; /// Optimizer rule that introduces CoalesceBatchesExec to avoid overhead with small batches that /// are produced by highly selective filters @@ -63,7 +62,6 @@ impl PhysicalOptimizerRule for CoalesceBatches { // See https://github.com/apache/datafusion/issues/139 let wrap_in_coalesce = plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() - || plan_any.downcast_ref::().is_some() // Don't need to add CoalesceBatchesExec after a round robin RepartitionExec || plan_any .downcast_ref::() diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index f8d7b3b74614..51037d18cb9d 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -25,6 +25,7 @@ #![deny(clippy::clone_on_ref_ptr)] pub mod aggregate_statistics; +pub mod coalesce_async_exec_input; pub mod coalesce_batches; pub mod combine_partial_final_agg; pub mod enforce_distribution; diff --git a/datafusion/physical-optimizer/src/optimizer.rs b/datafusion/physical-optimizer/src/optimizer.rs index 38ec92b7d116..f9ad521b4f7c 100644 --- a/datafusion/physical-optimizer/src/optimizer.rs +++ b/datafusion/physical-optimizer/src/optimizer.rs @@ -36,6 +36,7 @@ use crate::sanity_checker::SanityCheckPlan; use crate::topk_aggregation::TopKAggregation; use crate::update_aggr_exprs::OptimizeAggregateOrder; +use crate::coalesce_async_exec_input::CoalesceAsyncExecInput; use datafusion_common::config::ConfigOptions; use datafusion_common::Result; use datafusion_physical_plan::ExecutionPlan; @@ -121,6 +122,7 @@ impl PhysicalOptimizer { // The CoalesceBatches rule will not influence the distribution and ordering of the // whole plan tree. Therefore, to avoid influencing other rules, it should run last. Arc::new(CoalesceBatches::new()), + Arc::new(CoalesceAsyncExecInput::new()), // Remove the ancillary output requirement operator since we are done with the planning // phase. Arc::new(OutputRequirements::new_remove_mode()), diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index c70cc4b161c9..e6e59f495dcd 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -23,7 +23,7 @@ use crate::{ use arrow::array::RecordBatch; use arrow_schema::{Fields, Schema, SchemaRef}; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; -use datafusion_common::Result; +use datafusion_common::{internal_err, Result}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; use datafusion_physical_expr::equivalence::ProjectionMapping; @@ -144,11 +144,14 @@ impl ExecutionPlan for AsyncFuncExec { fn with_new_children( self: Arc, - _children: Vec>, + children: Vec>, ) -> Result> { + if children.len() != 1 { + return internal_err!("AsyncFuncExec wrong number of children"); + } Ok(Arc::new(AsyncFuncExec::try_new( self.async_exprs.clone(), - Arc::clone(&self.input), + Arc::clone(&children[0]), )?)) } From fe12d72c7cdf85177654e19b8ee359ea42ca4df2 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 23 Feb 2025 21:06:10 +0800 Subject: [PATCH 06/20] simple example --- datafusion-examples/examples/async_udf.rs | 255 ++++++++++++++++++++++ 1 file changed, 255 insertions(+) create mode 100644 datafusion-examples/examples/async_udf.rs diff --git a/datafusion-examples/examples/async_udf.rs b/datafusion-examples/examples/async_udf.rs new file mode 100644 index 000000000000..3e5cda2f1961 --- /dev/null +++ b/datafusion-examples/examples/async_udf.rs @@ -0,0 +1,255 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayIter, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray}; +use arrow::compute::kernels::cmp::eq; +use arrow_schema::{DataType, Field, Schema}; +use async_trait::async_trait; +use datafusion::common::error::Result; +use datafusion::common::internal_err; +use datafusion::common::types::{logical_int64, logical_string}; +use datafusion::common::utils::take_function_args; +use datafusion::config::ConfigOptions; +use datafusion::execution::{FunctionRegistry, SessionStateBuilder}; +use datafusion::logical_expr::async_udf::{ + AsyncScalarFunctionArgs, AsyncScalarUDF, AsyncScalarUDFImpl, +}; +use datafusion::logical_expr::{ + ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility, +}; +use datafusion::logical_expr_common::signature::Coercion; +use datafusion::physical_expr_common::datum::apply_cmp; +use datafusion::prelude::SessionContext; +use log::trace; +use std::any::Any; +use std::sync::Arc; + +#[tokio::main] +async fn main() -> Result<()> { + let mut state = SessionStateBuilder::new().build(); + + let async_upper = AsyncUpper::new(); + let udf = AsyncScalarUDF::new(Arc::new(async_upper)); + state.register_udf(udf.into_scalar_udf())?; + let async_equal = AsyncEqual::new(); + let udf = AsyncScalarUDF::new(Arc::new(async_equal)); + state.register_udf(udf.into_scalar_udf())?; + let ctx = SessionContext::new_with_state(state); + ctx.register_batch("animal", animal()?)?; + + // use Async UDF in the projection + // +---------------+----------------------------------------------------------------------------------------+ + // | plan_type | plan | + // +---------------+----------------------------------------------------------------------------------------+ + // | logical_plan | Projection: async_equal(a.id, Int64(1)) | + // | | SubqueryAlias: a | + // | | TableScan: animal projection=[id] | + // | physical_plan | ProjectionExec: expr=[__async_fn_0@1 as async_equal(a.id,Int64(1))] | + // | | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] | + // | | CoalesceBatchesExec: target_batch_size=8192 | + // | | DataSourceExec: partitions=1, partition_sizes=[1] | + // | | | + // +---------------+----------------------------------------------------------------------------------------+ + ctx.sql("explain select async_equal(a.id, 1) from animal a") + .await? + .show() + .await?; + + // +----------------------------+ + // | async_equal(a.id,Int64(1)) | + // +----------------------------+ + // | true | + // | false | + // | false | + // | false | + // | false | + // +----------------------------+ + ctx.sql("select async_equal(a.id, 1) from animal a") + .await? + .show() + .await?; + + // +---------------+--------------------------------------------------------------------------------------------+ + // | plan_type | plan | + // +---------------+--------------------------------------------------------------------------------------------+ + // | logical_plan | SubqueryAlias: a | + // | | Filter: async_equal(animal.id, Int64(1)) | + // | | TableScan: animal projection=[id, name] | + // | physical_plan | CoalesceBatchesExec: target_batch_size=8192 | + // | | FilterExec: __async_fn_0@2, projection=[id@0, name@1] | + // | | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 | + // | | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] | + // | | CoalesceBatchesExec: target_batch_size=8192 | + // | | DataSourceExec: partitions=1, partition_sizes=[1] | + // | | | + // +---------------+--------------------------------------------------------------------------------------------+ + ctx.sql("explain select * from animal a where async_equal(a.id, 1)") + .await? + .show() + .await?; + + // +----+------+ + // | id | name | + // +----+------+ + // | 1 | cat | + // +----+------+ + ctx.sql("select * from animal a where async_equal(a.id, 1)") + .await? + .show() + .await?; + + Ok(()) +} + +fn animal() -> Result { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + + let id_array = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])); + let name_array = Arc::new(StringArray::from(vec![ + "cat", "dog", "fish", "bird", "snake", + ])); + + Ok(RecordBatch::try_new(schema, vec![id_array, name_array])?) +} + +#[derive(Debug)] +pub struct AsyncUpper { + signature: Signature, +} + +impl Default for AsyncUpper { + fn default() -> Self { + Self::new() + } +} + +impl AsyncUpper { + pub fn new() -> Self { + Self { + signature: Signature::new( + TypeSignature::Coercible(vec![Coercion::Exact { + desired_type: TypeSignatureClass::Native(logical_string()), + }]), + Volatility::Volatile, + ), + } + } +} + +#[async_trait] +impl AsyncScalarUDFImpl for AsyncUpper { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "async_upper" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn ideal_batch_size(&self) -> Option { + Some(10) + } + + async fn invoke_async_with_args( + &self, + args: AsyncScalarFunctionArgs, + _option: &ConfigOptions, + ) -> Result { + trace!("Invoking async_upper with args: {:?}", args); + let value = &args.args[0]; + let result = match value { + ColumnarValue::Array(array) => { + let string_array = array.as_string::(); + let iter = ArrayIter::new(string_array); + let result = iter + .map(|string| string.map(|s| s.to_uppercase())) + .collect::(); + Arc::new(result) as ArrayRef + } + _ => return internal_err!("Expected a string argument, got {:?}", value), + }; + Ok(result) + } +} + +#[derive(Debug)] +struct AsyncEqual { + signature: Signature, +} + +impl Default for AsyncEqual { + fn default() -> Self { + Self::new() + } +} + +impl AsyncEqual { + pub fn new() -> Self { + Self { + signature: Signature::new( + TypeSignature::Coercible(vec![ + Coercion::Exact { + desired_type: TypeSignatureClass::Native(logical_int64()), + }, + Coercion::Exact { + desired_type: TypeSignatureClass::Native(logical_int64()), + }, + ]), + Volatility::Volatile, + ), + } + } +} + +#[async_trait] +impl AsyncScalarUDFImpl for AsyncEqual { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "async_equal" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + async fn invoke_async_with_args( + &self, + args: AsyncScalarFunctionArgs, + _option: &ConfigOptions, + ) -> Result { + let [arg1, arg2] = take_function_args(self.name(), &args.args)?; + apply_cmp(&arg1, &arg2, eq)?.to_array(args.number_rows) + } +} From 875d4e5083f6be6f2b9008f8770110676dfbe788 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 23 Feb 2025 21:06:45 +0800 Subject: [PATCH 07/20] enhance comment --- datafusion-examples/examples/async_udf.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion-examples/examples/async_udf.rs b/datafusion-examples/examples/async_udf.rs index 3e5cda2f1961..a94906be56e4 100644 --- a/datafusion-examples/examples/async_udf.rs +++ b/datafusion-examples/examples/async_udf.rs @@ -83,6 +83,7 @@ async fn main() -> Result<()> { .show() .await?; + // use Async UDF in the filter // +---------------+--------------------------------------------------------------------------------------------+ // | plan_type | plan | // +---------------+--------------------------------------------------------------------------------------------+ From 017b11107eb313b5d0bee004cb65c7cec6ec2851 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 23 Feb 2025 21:26:29 +0800 Subject: [PATCH 08/20] enhance doc and fix test --- datafusion/expr/src/async_udf.rs | 9 ++++----- datafusion/physical-expr/src/async_scalar_function.rs | 4 ++-- datafusion/physical-plan/src/async_func.rs | 5 +---- datafusion/sqllogictest/test_files/explain.slt | 3 +++ 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs index 13f95cd1d82f..2c72ddaa2bbf 100644 --- a/datafusion/expr/src/async_udf.rs +++ b/datafusion/expr/src/async_udf.rs @@ -33,7 +33,7 @@ use std::sync::Arc; /// Note this is less efficient than the ScalarUDFImpl, but it can be used /// to register remote functions in the context. /// -/// The name is chosen to mirror ScalarUDFImpl +/// The name is chosen to mirror ScalarUDFImpl #[async_trait] pub trait AsyncScalarUDFImpl: Debug + Send + Sync { /// the function cast as any @@ -49,6 +49,9 @@ pub trait AsyncScalarUDFImpl: Debug + Send + Sync { fn return_type(&self, _arg_types: &[DataType]) -> Result; /// The ideal batch size for this function. + /// + /// This is used to determine what size of data to be evaluated at once. + /// If None, the whole batch will be evaluated at once. fn ideal_batch_size(&self) -> Option { None } @@ -65,10 +68,6 @@ pub trait AsyncScalarUDFImpl: Debug + Send + Sync { /// /// Note this is not meant to be used directly, but is meant to be an implementation detail /// for AsyncUDFImpl. -/// -/// This is used to register remote functions in the context. The function -/// should not be invoked by DataFusion. It's only used to generate the logical -/// plan and unparsed them to SQL. #[derive(Debug)] pub struct AsyncScalarUDF { inner: Arc, diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs index 609799c4f9bf..8573efaf3996 100644 --- a/datafusion/physical-expr/src/async_scalar_function.rs +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -29,7 +29,7 @@ use std::fmt::Display; use std::hash::{Hash, Hasher}; use std::sync::Arc; -/// Wrapper for a Async function that can be used in a DataFusion query +/// Wrapper around a scalar function that can be evaluated asynchronously #[derive(Debug, Clone, Eq)] pub struct AsyncFuncExpr { /// The name of the output column this function will generate @@ -206,7 +206,7 @@ impl PhysicalExpr for AsyncFuncExpr { } fn evaluate(&self, _batch: &RecordBatch) -> Result { - // TODO: implement this + // TODO: implement this for scalar value input not_impl_err!("AsyncFuncExpr.evaluate") } diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index e6e59f495dcd..829f0b08d69b 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -35,11 +35,9 @@ use log::trace; use std::any::Any; use std::sync::Arc; -/// This structure evaluates a set of async expressions on a record +/// This structure evaluates a set of async expressions on a record /// batch producing a new record batch /// -/// This is similar to a ProjectionExec except that the functions can be async -/// /// The schema of the output of the AsyncFuncExec is: /// Input columns followed by one column for each async expression #[derive(Debug)] @@ -47,7 +45,6 @@ pub struct AsyncFuncExec { /// The async expressions to evaluate async_exprs: Vec>, input: Arc, - /// Cache holding plan properties like equivalences, output partitioning etc. cache: PlanProperties, metrics: ExecutionPlanMetricsSet, } diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index ae2ef67c041c..50575a3aba4d 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -238,6 +238,7 @@ physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after coalesce_async_exec_input SAME TEXT AS ABOVE physical_plan after OutputRequirements DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after LimitPushdown SAME TEXT AS ABOVE @@ -315,6 +316,7 @@ physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after coalesce_async_exec_input SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] @@ -358,6 +360,7 @@ physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after coalesce_async_exec_input SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet From da065ef221e77b3cd92e0632c3bf72db4e50499f Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 11 Mar 2025 21:29:48 +0800 Subject: [PATCH 09/20] fix clippy and fmt --- datafusion/expr/Cargo.toml | 1 - datafusion/physical-plan/src/async_func.rs | 19 ++++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 766b412bbe96..d77c59ff64e1 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -54,7 +54,6 @@ paste = "^1.0" recursive = { workspace = true, optional = true } serde_json = { workspace = true } sqlparser = { workspace = true } -async-trait = "0.1.86" [dev-dependencies] ctor = { workspace = true } diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index 829f0b08d69b..95c6b9f91d46 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -108,15 +108,20 @@ impl DisplayAs for AsyncFuncExec { t: DisplayFormatType, f: &mut std::fmt::Formatter, ) -> std::fmt::Result { + let expr: Vec = self + .async_exprs + .iter() + .map(|async_expr| async_expr.to_string()) + .collect(); + let exprs = expr.join(", "); match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - let expr: Vec = self - .async_exprs - .iter() - .map(|async_expr| async_expr.to_string()) - .collect(); - - write!(f, "AsyncFuncExec: async_expr=[{}]", expr.join(", ")) + write!(f, "AsyncFuncExec: async_expr=[{}]", exprs) + } + DisplayFormatType::TreeRender => { + writeln!(f, "format=async_expr")?; + writeln!(f, "async_expr={}", exprs)?; + Ok(()) } } } From 3a2e7ffd4af228b0bae92eff56245457b97f44ac Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 11 Mar 2025 21:38:26 +0800 Subject: [PATCH 10/20] add missing dependency --- datafusion/expr/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index d77c59ff64e1..812544587bf9 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -42,6 +42,7 @@ recursive_protection = ["dep:recursive"] [dependencies] arrow = { workspace = true } +async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true } datafusion-doc = { workspace = true } From cce1586442209621630f907bf590d9025f5bdd02 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 11 Mar 2025 22:04:27 +0800 Subject: [PATCH 11/20] fix clippy --- datafusion-examples/examples/async_udf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion-examples/examples/async_udf.rs b/datafusion-examples/examples/async_udf.rs index a94906be56e4..2644ac9697ec 100644 --- a/datafusion-examples/examples/async_udf.rs +++ b/datafusion-examples/examples/async_udf.rs @@ -251,6 +251,6 @@ impl AsyncScalarUDFImpl for AsyncEqual { _option: &ConfigOptions, ) -> Result { let [arg1, arg2] = take_function_args(self.name(), &args.args)?; - apply_cmp(&arg1, &arg2, eq)?.to_array(args.number_rows) + apply_cmp(arg1, arg2, eq)?.to_array(args.number_rows) } } From 5fbdd04ca5cd190ef2b2964599db24dc79930888 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sat, 22 Mar 2025 14:24:31 +0800 Subject: [PATCH 12/20] rename the symbol --- datafusion/physical-expr/src/async_scalar_function.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs index 8573efaf3996..c93ba1c5304b 100644 --- a/datafusion/physical-expr/src/async_scalar_function.rs +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -107,7 +107,7 @@ impl AsyncFuncExpr { batch: &RecordBatch, option: &ConfigOptions, ) -> Result { - let Some(llm_function) = self.func.as_any().downcast_ref::() + let Some(scalar_function_expr) = self.func.as_any().downcast_ref::() else { return internal_err!( "unexpected function type, expected ScalarFunctionExpr, got: {:?}", @@ -115,7 +115,7 @@ impl AsyncFuncExpr { ); }; - let Some(async_udf) = llm_function + let Some(async_udf) = scalar_function_expr .fun() .inner() .as_any() @@ -123,7 +123,7 @@ impl AsyncFuncExpr { else { return not_impl_err!( "Don't know how to evaluate async function: {:?}", - llm_function + scalar_function_expr ); }; @@ -139,7 +139,7 @@ impl AsyncFuncExpr { let current_batch = remainder.slice(0, size); // get next 10 rows remainder = remainder.slice(size, remainder.num_rows() - size); - let args = llm_function + let args = scalar_function_expr .args() .iter() .map(|e| e.evaluate(¤t_batch)) @@ -158,7 +158,7 @@ impl AsyncFuncExpr { ); } } else { - let args = llm_function + let args = scalar_function_expr .args() .iter() .map(|e| e.evaluate(batch)) From 92e8144a45e3f77d2d77e0e1575f5323e42aec34 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sat, 22 Mar 2025 14:26:07 +0800 Subject: [PATCH 13/20] cargo fmt --- datafusion/physical-expr/src/async_scalar_function.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs index c93ba1c5304b..304fa4305136 100644 --- a/datafusion/physical-expr/src/async_scalar_function.rs +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -107,7 +107,8 @@ impl AsyncFuncExpr { batch: &RecordBatch, option: &ConfigOptions, ) -> Result { - let Some(scalar_function_expr) = self.func.as_any().downcast_ref::() + let Some(scalar_function_expr) = + self.func.as_any().downcast_ref::() else { return internal_err!( "unexpected function type, expected ScalarFunctionExpr, got: {:?}", From ea0ce9874df6177c311127e12db1e60cc850587e Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sat, 22 Mar 2025 14:33:55 +0800 Subject: [PATCH 14/20] fix fmt and rebase --- datafusion/expr/src/async_udf.rs | 4 ++-- .../physical-expr/src/async_scalar_function.rs | 12 ++++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs index 2c72ddaa2bbf..3716ac565a85 100644 --- a/datafusion/expr/src/async_udf.rs +++ b/datafusion/expr/src/async_udf.rs @@ -112,8 +112,8 @@ impl ScalarUDFImpl for AsyncScalarUDF { self.inner.signature() } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - self.inner.return_type(_arg_types) + fn return_type(&self, arg_types: &[DataType]) -> Result { + self.inner.return_type(arg_types) } fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs index 304fa4305136..97ba642b06e3 100644 --- a/datafusion/physical-expr/src/async_scalar_function.rs +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -198,12 +198,12 @@ impl PhysicalExpr for AsyncFuncExpr { self } - fn data_type(&self, _input_schema: &Schema) -> Result { - self.func.data_type(_input_schema) + fn data_type(&self, input_schema: &Schema) -> Result { + self.func.data_type(input_schema) } - fn nullable(&self, _input_schema: &Schema) -> Result { - self.func.nullable(_input_schema) + fn nullable(&self, input_schema: &Schema) -> Result { + self.func.nullable(input_schema) } fn evaluate(&self, _batch: &RecordBatch) -> Result { @@ -225,4 +225,8 @@ impl PhysicalExpr for AsyncFuncExpr { func: new_func, })) } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.func) + } } From 4d4145b61e8302478b2e40b1cee98bdca5bb2b37 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 18 May 2025 15:04:00 +0800 Subject: [PATCH 15/20] add return_field_from_args for async scalar udf --- datafusion/expr/src/async_udf.rs | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs index 3716ac565a85..c65039c21264 100644 --- a/datafusion/expr/src/async_udf.rs +++ b/datafusion/expr/src/async_udf.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::{ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, SchemaRef}; +use arrow::datatypes::{DataType, Field, SchemaRef}; use async_trait::async_trait; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; @@ -48,6 +48,21 @@ pub trait AsyncScalarUDFImpl: Debug + Send + Sync { /// The return type of the function fn return_type(&self, _arg_types: &[DataType]) -> Result; + /// What type will be returned by this function, given the arguments? + /// + /// By default, this function calls [`Self::return_type`] with the + /// types of each argument. + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let data_types = args + .arg_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + let return_type = self.return_type(&data_types)?; + Ok(Field::new(self.name(), return_type, true)) + } + /// The ideal batch size for this function. /// /// This is used to determine what size of data to be evaluated at once. @@ -116,6 +131,10 @@ impl ScalarUDFImpl for AsyncScalarUDF { self.inner.return_type(arg_types) } + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + self.inner.return_field_from_args(args) + } + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { internal_err!("async functions should not be called directly") } From 68b22030c1819125ee47112ca349cf83a1d17764 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 18 May 2025 16:09:42 +0800 Subject: [PATCH 16/20] modified into_scalar_udf method --- datafusion-examples/examples/async_udf.rs | 8 +++----- datafusion/expr/src/async_udf.rs | 4 ++-- datafusion/physical-plan/src/async_func.rs | 4 ++-- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/datafusion-examples/examples/async_udf.rs b/datafusion-examples/examples/async_udf.rs index 2644ac9697ec..57925a0980a1 100644 --- a/datafusion-examples/examples/async_udf.rs +++ b/datafusion-examples/examples/async_udf.rs @@ -24,7 +24,6 @@ use datafusion::common::internal_err; use datafusion::common::types::{logical_int64, logical_string}; use datafusion::common::utils::take_function_args; use datafusion::config::ConfigOptions; -use datafusion::execution::{FunctionRegistry, SessionStateBuilder}; use datafusion::logical_expr::async_udf::{ AsyncScalarFunctionArgs, AsyncScalarUDF, AsyncScalarUDFImpl, }; @@ -40,15 +39,14 @@ use std::sync::Arc; #[tokio::main] async fn main() -> Result<()> { - let mut state = SessionStateBuilder::new().build(); + let ctx: SessionContext = SessionContext::new(); let async_upper = AsyncUpper::new(); let udf = AsyncScalarUDF::new(Arc::new(async_upper)); - state.register_udf(udf.into_scalar_udf())?; + ctx.register_udf(udf.into_scalar_udf()); let async_equal = AsyncEqual::new(); let udf = AsyncScalarUDF::new(Arc::new(async_equal)); - state.register_udf(udf.into_scalar_udf())?; - let ctx = SessionContext::new_with_state(state); + ctx.register_udf(udf.into_scalar_udf()); ctx.register_batch("animal", animal()?)?; // use Async UDF in the projection diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs index c65039c21264..118273cbebd0 100644 --- a/datafusion/expr/src/async_udf.rs +++ b/datafusion/expr/src/async_udf.rs @@ -100,8 +100,8 @@ impl AsyncScalarUDF { /// Turn this AsyncUDF into a ScalarUDF, suitable for /// registering in the context - pub fn into_scalar_udf(self) -> Arc { - Arc::new(ScalarUDF::new_from_impl(self)) + pub fn into_scalar_udf(self) -> ScalarUDF { + ScalarUDF::new_from_impl(self) } /// Invoke the function asynchronously with the async arguments diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index 95c6b9f91d46..5a91f5856422 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -116,11 +116,11 @@ impl DisplayAs for AsyncFuncExec { let exprs = expr.join(", "); match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "AsyncFuncExec: async_expr=[{}]", exprs) + write!(f, "AsyncFuncExec: async_expr=[{exprs}]") } DisplayFormatType::TreeRender => { writeln!(f, "format=async_expr")?; - writeln!(f, "async_expr={}", exprs)?; + writeln!(f, "async_expr={exprs}")?; Ok(()) } } From 6f05ec340e3bd7be2ce2bdc50ef7048df809b59f Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 18 May 2025 16:10:24 +0800 Subject: [PATCH 17/20] add the async scalar udf in udfs doc --- .../functions/adding-udfs.md | 134 ++++++++++++++++-- 1 file changed, 123 insertions(+), 11 deletions(-) diff --git a/docs/source/library-user-guide/functions/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md index cd40e664239a..6eb3d96c83b7 100644 --- a/docs/source/library-user-guide/functions/adding-udfs.md +++ b/docs/source/library-user-guide/functions/adding-udfs.md @@ -23,12 +23,13 @@ User Defined Functions (UDFs) are functions that can be used in the context of D This page covers how to add UDFs to DataFusion. In particular, it covers how to add Scalar, Window, and Aggregate UDFs. -| UDF Type | Description | Example | -| --------- | ---------------------------------------------------------------------------------------------------------- | ------------------- | -| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs][1] | -| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs][2] | -| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs][3] | -| Table | A function that takes parameters and returns a `TableProvider` to be used in an query plan. | [simple_udtf.rs][4] | +| UDF Type | Description | Example | +| ------------ | ---------------------------------------------------------------------------------------------------------- | ------------------- | +| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs][1] | +| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs][2] | +| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs][3] | +| Table | A function that takes parameters and returns a `TableProvider` to be used in an query plan. | [simple_udtf.rs][4] | +| Async Scalar | A scalar function that natively supports asynchronous execution, allowing you to perform async operations (such as network or I/O calls) within the UDF. | [async_udf.rs][5] | First we'll talk about adding an Scalar UDF end-to-end, then we'll talk about the differences between the different types of UDFs. @@ -344,6 +345,122 @@ async fn main() { } ``` +## Adding a Scalar Async UDF + +A Scalar Async UDF allows you to implement user-defined functions that support asynchronous execution, such as performing network or I/O operations within the UDF. + +To add a Scalar Async UDF, you need to: + +1. Implement the `AsyncScalarUDFImpl` trait to define your async function logic, signature, and types. +2. Wrap your implementation with `AsyncScalarUDF::new` and register it with the `SessionContext`. + +### Adding by `impl AsyncScalarUDFImpl` + +```rust +use arrow::array::{ArrayIter, ArrayRef, AsArray, StringArray}; +use arrow_schema::DataType; +use async_trait::async_trait; +use datafusion::common::error::Result; +use datafusion::common::internal_err; +use datafusion::common::types::logical_string; +use datafusion::config::ConfigOptions; +use datafusion::logical_expr::async_udf::{ + AsyncScalarFunctionArgs, AsyncScalarUDFImpl, +}; +use datafusion::logical_expr::{ + ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility, +}; +use datafusion::logical_expr_common::signature::Coercion; +use log::trace; +use std::any::Any; +use std::sync::Arc; + +#[derive(Debug)] +pub struct AsyncUpper { + signature: Signature, +} + +impl Default for AsyncUpper { + fn default() -> Self { + Self::new() + } +} + +impl AsyncUpper { + pub fn new() -> Self { + Self { + signature: Signature::new( + TypeSignature::Coercible(vec![Coercion::Exact { + desired_type: TypeSignatureClass::Native(logical_string()), + }]), + Volatility::Volatile, + ), + } + } +} + +#[async_trait] +impl AsyncScalarUDFImpl for AsyncUpper { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "async_upper" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn ideal_batch_size(&self) -> Option { + Some(10) + } + + async fn invoke_async_with_args( + &self, + args: AsyncScalarFunctionArgs, + _option: &ConfigOptions, + ) -> Result { + trace!("Invoking async_upper with args: {:?}", args); + let value = &args.args[0]; + let result = match value { + ColumnarValue::Array(array) => { + let string_array = array.as_string::(); + let iter = ArrayIter::new(string_array); + let result = iter + .map(|string| string.map(|s| s.to_uppercase())) + .collect::(); + Arc::new(result) as ArrayRef + } + _ => return internal_err!("Expected a string argument, got {:?}", value), + }; + Ok(result) + } +} +``` + +We can now transfer the async UDF into the normal scalar using `into_scalar_udf` to register the function with DataFusion so that it can be used in the context of a query. + +```rust +let async_upper = AsyncUpper::new(); +let udf = AsyncScalarUDF::new(Arc::new(async_upper)); +ctx.register_udf(udf.into_scalar_udf()); +``` + +After registration, you can use these async UDFs directly in SQL queries, for example: + +```sql +SELECT async_upper('datafusion'); +``` + +For async UDF implementation details, see [`async_udf.rs`](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/async_udf.rs). + + [`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html [`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html [`process_scalar_func_inputs`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html @@ -1244,8 +1361,3 @@ async fn main() -> Result<()> { Ok(()) } ``` - -[1]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udf.rs -[2]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs -[3]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs -[4]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udtf.rs From 68debd565b5af4ec4cee82f31fc15f91bd6d02ac Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 18 May 2025 16:59:20 +0800 Subject: [PATCH 18/20] pretty doc --- .../library-user-guide/functions/adding-udfs.md | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/docs/source/library-user-guide/functions/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md index 6eb3d96c83b7..717314d5bad7 100644 --- a/docs/source/library-user-guide/functions/adding-udfs.md +++ b/docs/source/library-user-guide/functions/adding-udfs.md @@ -23,12 +23,12 @@ User Defined Functions (UDFs) are functions that can be used in the context of D This page covers how to add UDFs to DataFusion. In particular, it covers how to add Scalar, Window, and Aggregate UDFs. -| UDF Type | Description | Example | -| ------------ | ---------------------------------------------------------------------------------------------------------- | ------------------- | -| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs][1] | -| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs][2] | -| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs][3] | -| Table | A function that takes parameters and returns a `TableProvider` to be used in an query plan. | [simple_udtf.rs][4] | +| UDF Type | Description | Example | +| ------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------- | +| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs][1] | +| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs][2] | +| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs][3] | +| Table | A function that takes parameters and returns a `TableProvider` to be used in an query plan. | [simple_udtf.rs][4] | | Async Scalar | A scalar function that natively supports asynchronous execution, allowing you to perform async operations (such as network or I/O calls) within the UDF. | [async_udf.rs][5] | First we'll talk about adding an Scalar UDF end-to-end, then we'll talk about the differences between the different @@ -460,7 +460,6 @@ SELECT async_upper('datafusion'); For async UDF implementation details, see [`async_udf.rs`](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/async_udf.rs). - [`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html [`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html [`process_scalar_func_inputs`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html From 98cf8e2b765ebbb55773df2be8c4876428595d04 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 18 May 2025 17:20:17 +0800 Subject: [PATCH 19/20] fix doc test --- .../functions/adding-udfs.md | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/docs/source/library-user-guide/functions/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md index 717314d5bad7..66ffd69b4545 100644 --- a/docs/source/library-user-guide/functions/adding-udfs.md +++ b/docs/source/library-user-guide/functions/adding-udfs.md @@ -447,8 +447,97 @@ impl AsyncScalarUDFImpl for AsyncUpper { We can now transfer the async UDF into the normal scalar using `into_scalar_udf` to register the function with DataFusion so that it can be used in the context of a query. ```rust +# use arrow::array::{ArrayIter, ArrayRef, AsArray, StringArray}; +# use arrow_schema::DataType; +# use async_trait::async_trait; +# use datafusion::common::error::Result; +# use datafusion::common::internal_err; +# use datafusion::common::types::logical_string; +# use datafusion::config::ConfigOptions; +# use datafusion::logical_expr::async_udf::{ +# AsyncScalarFunctionArgs, AsyncScalarUDFImpl, +# }; +# use datafusion::logical_expr::{ +# ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility, +# }; +# use datafusion::logical_expr_common::signature::Coercion; +# use log::trace; +# use std::any::Any; +# use std::sync::Arc; +# +# #[derive(Debug)] +# pub struct AsyncUpper { +# signature: Signature, +# } +# +# impl Default for AsyncUpper { +# fn default() -> Self { +# Self::new() +# } +# } +# +# impl AsyncUpper { +# pub fn new() -> Self { +# Self { +# signature: Signature::new( +# TypeSignature::Coercible(vec![Coercion::Exact { +# desired_type: TypeSignatureClass::Native(logical_string()), +# }]), +# Volatility::Volatile, +# ), +# } +# } +# } +# +# #[async_trait] +# impl AsyncScalarUDFImpl for AsyncUpper { +# fn as_any(&self) -> &dyn Any { +# self +# } +# +# fn name(&self) -> &str { +# "async_upper" +# } +# +# fn signature(&self) -> &Signature { +# &self.signature +# } +# +# fn return_type(&self, _arg_types: &[DataType]) -> Result { +# Ok(DataType::Utf8) +# } +# +# fn ideal_batch_size(&self) -> Option { +# Some(10) +# } +# +# async fn invoke_async_with_args( +# &self, +# args: AsyncScalarFunctionArgs, +# _option: &ConfigOptions, +# ) -> Result { +# trace!("Invoking async_upper with args: {:?}", args); +# let value = &args.args[0]; +# let result = match value { +# ColumnarValue::Array(array) => { +# let string_array = array.as_string::(); +# let iter = ArrayIter::new(string_array); +# let result = iter +# .map(|string| string.map(|s| s.to_uppercase())) +# .collect::(); +# Arc::new(result) as ArrayRef +# } +# _ => return internal_err!("Expected a string argument, got {:?}", value), +# }; +# Ok(result) +# } +# } +use datafusion::execution::context::SessionContext; +use datafusion::logical_expr::async_udf::AsyncScalarUDF; + let async_upper = AsyncUpper::new(); let udf = AsyncScalarUDF::new(Arc::new(async_upper)); +let mut ctx = SessionContext::new(); ctx.register_udf(udf.into_scalar_udf()); ``` From 5f55674afdb4168163e6ca6fd45218e0ba2c3f10 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Mon, 23 Jun 2025 20:57:59 +0800 Subject: [PATCH 20/20] fix merge conflict --- datafusion/core/src/physical_planner.rs | 1 - datafusion/expr/src/async_udf.rs | 8 ++++---- datafusion/physical-plan/src/async_func.rs | 2 +- datafusion/physical-plan/src/lib.rs | 2 +- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 70ad782609e4..8bf513a55a66 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -96,7 +96,6 @@ use datafusion_physical_plan::unnest::ListUnnest; use sqlparser::ast::NullTreatment; use async_trait::async_trait; -use datafusion_datasource::file_groups::FileGroup; use datafusion_physical_plan::async_func::{AsyncFuncExec, AsyncMapper}; use futures::{StreamExt, TryStreamExt}; use itertools::{multiunzip, Itertools}; diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs index 118273cbebd0..4f2b593b421a 100644 --- a/datafusion/expr/src/async_udf.rs +++ b/datafusion/expr/src/async_udf.rs @@ -17,7 +17,7 @@ use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field, SchemaRef}; +use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef}; use async_trait::async_trait; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; @@ -52,7 +52,7 @@ pub trait AsyncScalarUDFImpl: Debug + Send + Sync { /// /// By default, this function calls [`Self::return_type`] with the /// types of each argument. - fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { let data_types = args .arg_fields .iter() @@ -60,7 +60,7 @@ pub trait AsyncScalarUDFImpl: Debug + Send + Sync { .cloned() .collect::>(); let return_type = self.return_type(&data_types)?; - Ok(Field::new(self.name(), return_type, true)) + Ok(Arc::new(Field::new(self.name(), return_type, true))) } /// The ideal batch size for this function. @@ -131,7 +131,7 @@ impl ScalarUDFImpl for AsyncScalarUDF { self.inner.return_type(arg_types) } - fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { self.inner.return_field_from_args(args) } diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index 5a91f5856422..c808c5711755 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -73,7 +73,7 @@ impl AsyncFuncExec { .iter() .map(|expr| (Arc::clone(&expr.func), expr.name().to_string())) .collect::>(); - let async_expr_mapping = ProjectionMapping::try_new(&tuples, &input.schema())?; + let async_expr_mapping = ProjectionMapping::try_new(tuples, &input.schema())?; let cache = AsyncFuncExec::compute_properties(&input, schema, &async_expr_mapping)?; Ok(Self { diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index a9cca2cbc2ac..4d3adebb91c6 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -59,8 +59,8 @@ mod visitor; pub mod aggregates; pub mod analyze; -pub mod coalesce; pub mod async_func; +pub mod coalesce; pub mod coalesce_batches; pub mod coalesce_partitions; pub mod common;