diff --git a/Cargo.lock b/Cargo.lock
index 7147548b4b2e..2ce381711cd6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2215,6 +2215,7 @@ name = "datafusion-expr"
 version = "48.0.0"
 dependencies = [
  "arrow",
+ "async-trait",
  "chrono",
  "ctor",
  "datafusion-common",
diff --git a/datafusion-examples/examples/async_udf.rs b/datafusion-examples/examples/async_udf.rs
new file mode 100644
index 000000000000..57925a0980a1
--- /dev/null
+++ b/datafusion-examples/examples/async_udf.rs
@@ -0,0 +1,254 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{ArrayIter, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray};
+use arrow::compute::kernels::cmp::eq;
+use arrow_schema::{DataType, Field, Schema};
+use async_trait::async_trait;
+use datafusion::common::error::Result;
+use datafusion::common::internal_err;
+use datafusion::common::types::{logical_int64, logical_string};
+use datafusion::common::utils::take_function_args;
+use datafusion::config::ConfigOptions;
+use datafusion::logical_expr::async_udf::{
+    AsyncScalarFunctionArgs, AsyncScalarUDF, AsyncScalarUDFImpl,
+};
+use datafusion::logical_expr::{
+    ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility,
+};
+use datafusion::logical_expr_common::signature::Coercion;
+use datafusion::physical_expr_common::datum::apply_cmp;
+use datafusion::prelude::SessionContext;
+use log::trace;
+use std::any::Any;
+use std::sync::Arc;
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let ctx: SessionContext = SessionContext::new();
+
+    let async_upper = AsyncUpper::new();
+    let udf = AsyncScalarUDF::new(Arc::new(async_upper));
+    ctx.register_udf(udf.into_scalar_udf());
+    let async_equal = AsyncEqual::new();
+    let udf = AsyncScalarUDF::new(Arc::new(async_equal));
+    ctx.register_udf(udf.into_scalar_udf());
+    ctx.register_batch("animal", animal()?)?;
+
+    // use the async UDF in a projection
+    // +---------------+----------------------------------------------------------------------------------------+
+    // | plan_type     | plan                                                                                   |
+    // +---------------+----------------------------------------------------------------------------------------+
+    // | logical_plan  | Projection: async_equal(a.id, Int64(1))                                               |
+    // |               |   SubqueryAlias: a                                                                     |
+    // |               |     TableScan: animal projection=[id]                                                 |
+    // | physical_plan | ProjectionExec: expr=[__async_fn_0@1 as async_equal(a.id,Int64(1))]                   |
+    // |               |   AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] |
+    // |               |     CoalesceBatchesExec: target_batch_size=8192                                       |
+    // |               |       DataSourceExec: partitions=1, partition_sizes=[1]                               |
+    // |               |                                                                                        |
+    // +---------------+----------------------------------------------------------------------------------------+
+    ctx.sql("explain select async_equal(a.id, 1) from animal a")
+        .await?
+        .show()
+        .await?;
+
+    // +----------------------------+
+    // | async_equal(a.id,Int64(1)) |
+    // +----------------------------+
+    // | true                       |
+    // | false                      |
+    // | false                      |
+    // | false                      |
+    // | false                      |
+    // +----------------------------+
+    ctx.sql("select async_equal(a.id, 1) from animal a")
+        .await?
+        .show()
+        .await?;
+
+    // use the async UDF in a filter
+    // +---------------+--------------------------------------------------------------------------------------------+
+    // | plan_type     | plan                                                                                       |
+    // +---------------+--------------------------------------------------------------------------------------------+
+    // | logical_plan  | SubqueryAlias: a                                                                           |
+    // |               |   Filter: async_equal(animal.id, Int64(1))                                                 |
+    // |               |     TableScan: animal projection=[id, name]                                                |
+    // | physical_plan | CoalesceBatchesExec: target_batch_size=8192                                                |
+    // |               |   FilterExec: __async_fn_0@2, projection=[id@0, name@1]                                    |
+    // |               |     RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1                  |
+    // |               |       AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] |
+    // |               |         CoalesceBatchesExec: target_batch_size=8192                                        |
+    // |               |           DataSourceExec: partitions=1, partition_sizes=[1]                                |
+    // |               |                                                                                            |
+    // +---------------+--------------------------------------------------------------------------------------------+
+    ctx.sql("explain select * from animal a where async_equal(a.id, 1)")
+        .await?
+        .show()
+        .await?;
+
+    // +----+------+
+    // | id | name |
+    // +----+------+
+    // | 1  | cat  |
+    // +----+------+
+    ctx.sql("select * from animal a where async_equal(a.id, 1)")
+        .await?
+        .show()
+        .await?;
+
+    Ok(())
+}
+
+fn animal() -> Result<RecordBatch> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int64, false),
+        Field::new("name", DataType::Utf8, false),
+    ]));
+
+    let id_array = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
+    let name_array = Arc::new(StringArray::from(vec![
+        "cat", "dog", "fish", "bird", "snake",
+    ]));
+
+    Ok(RecordBatch::try_new(schema, vec![id_array, name_array])?)
+}
+
+#[derive(Debug)]
+pub struct AsyncUpper {
+    signature: Signature,
+}
+
+impl Default for AsyncUpper {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl AsyncUpper {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::new(
+                TypeSignature::Coercible(vec![Coercion::Exact {
+                    desired_type: TypeSignatureClass::Native(logical_string()),
+                }]),
+                Volatility::Volatile,
+            ),
+        }
+    }
+}
+
+#[async_trait]
+impl AsyncScalarUDFImpl for AsyncUpper {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "async_upper"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn ideal_batch_size(&self) -> Option<usize> {
+        Some(10)
+    }
+
+    async fn invoke_async_with_args(
+        &self,
+        args: AsyncScalarFunctionArgs,
+        _option: &ConfigOptions,
+    ) -> Result<ArrayRef> {
+        trace!("Invoking async_upper with args: {:?}", args);
+        let value = &args.args[0];
+        let result = match value {
+            ColumnarValue::Array(array) => {
+                let string_array = array.as_string::<i32>();
+                let iter = ArrayIter::new(string_array);
+                let result = iter
+                    .map(|string| string.map(|s| s.to_uppercase()))
+                    .collect::<StringArray>();
+                Arc::new(result) as ArrayRef
+            }
+            _ => return internal_err!("Expected a string argument, got {:?}", value),
+        };
+        Ok(result)
+    }
+}
+
+#[derive(Debug)]
+struct AsyncEqual {
+    signature: Signature,
+}
+
+impl Default for AsyncEqual {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl AsyncEqual {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::new(
+                TypeSignature::Coercible(vec![
+                    Coercion::Exact {
+                        desired_type: TypeSignatureClass::Native(logical_int64()),
+                    },
+                    Coercion::Exact {
+                        desired_type: TypeSignatureClass::Native(logical_int64()),
+                    },
+                ]),
+                Volatility::Volatile,
+            ),
+        }
+    }
+}
+
+#[async_trait]
+impl AsyncScalarUDFImpl for AsyncEqual {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "async_equal"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    async fn invoke_async_with_args(
+        &self,
+        args: AsyncScalarFunctionArgs,
+        _option: &ConfigOptions,
+    ) -> Result<ArrayRef> {
+        let [arg1, arg2] = take_function_args(self.name(), &args.args)?;
+        apply_cmp(arg1, arg2, eq)?.to_array(args.number_rows)
+    }
+}
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index c1121d59bb3f..8bf513a55a66 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -96,6 +96,7 @@ use datafusion_physical_plan::unnest::ListUnnest;
 use sqlparser::ast::NullTreatment;
 
 use async_trait::async_trait;
+use datafusion_physical_plan::async_func::{AsyncFuncExec, AsyncMapper};
 use futures::{StreamExt, TryStreamExt};
 use itertools::{multiunzip, Itertools};
 use log::{debug, trace};
@@ -777,12 +778,44 @@ impl DefaultPhysicalPlanner {
                 let runtime_expr =
                     self.create_physical_expr(predicate, input_dfschema, session_state)?;
+
+                let filter = match self.try_plan_async_exprs(
+                    input.schema().fields().len(),
+                    PlannedExprResult::Expr(vec![runtime_expr]),
+                )? {
+                    PlanAsyncExpr::Sync(PlannedExprResult::Expr(runtime_expr)) => {
+                        FilterExec::try_new(Arc::clone(&runtime_expr[0]), physical_input)?
+                    }
+                    PlanAsyncExpr::Async(
+                        async_map,
+                        PlannedExprResult::Expr(runtime_expr),
+                    ) => {
+                        let async_exec = AsyncFuncExec::try_new(
+                            async_map.async_exprs,
+                            physical_input,
+                        )?;
+                        FilterExec::try_new(
+                            Arc::clone(&runtime_expr[0]),
+                            Arc::new(async_exec),
+                        )?
+                        // project the output columns excluding the async functions
+                        // The async functions are always appended to the end of the schema.
+                        .with_projection(Some(
+                            (0..input.schema().fields().len()).collect(),
+                        ))?
+                    }
+                    _ => {
+                        return internal_err!(
+                            "Unexpected result from try_plan_async_exprs"
+                        )
+                    }
+                };
+
                 let selectivity = session_state
                     .config()
                     .options()
                     .optimizer
                     .default_filter_selectivity;
-                let filter = FilterExec::try_new(runtime_expr, physical_input)?;
                 Arc::new(filter.with_default_selectivity(selectivity)?)
             }
             LogicalPlan::Repartition(Repartition {
@@ -2044,13 +2077,89 @@ impl DefaultPhysicalPlanner {
             })
             .collect::<Result<Vec<_>>>()?;
 
-        Ok(Arc::new(ProjectionExec::try_new(
-            physical_exprs,
-            input_exec,
-        )?))
+        let num_input_columns = input_exec.schema().fields().len();
+
+        match self.try_plan_async_exprs(
+            num_input_columns,
+            PlannedExprResult::ExprWithName(physical_exprs),
+        )? {
+            PlanAsyncExpr::Sync(PlannedExprResult::ExprWithName(physical_exprs)) => Ok(
+                Arc::new(ProjectionExec::try_new(physical_exprs, input_exec)?),
+            ),
+            PlanAsyncExpr::Async(
+                async_map,
+                PlannedExprResult::ExprWithName(physical_exprs),
+            ) => {
+                let async_exec =
+                    AsyncFuncExec::try_new(async_map.async_exprs, input_exec)?;
+                let new_proj_exec =
+                    ProjectionExec::try_new(physical_exprs, Arc::new(async_exec))?;
+                Ok(Arc::new(new_proj_exec))
+            }
+            _ => internal_err!("Unexpected PlanAsyncExpr variant"),
+        }
     }
+
+    fn try_plan_async_exprs(
+        &self,
+        num_input_columns: usize,
+        physical_expr: PlannedExprResult,
+    ) -> Result<PlanAsyncExpr> {
+        let mut async_map = AsyncMapper::new(num_input_columns);
+        match &physical_expr {
+            PlannedExprResult::ExprWithName(exprs) => {
+                exprs
+                    .iter()
+                    .try_for_each(|(expr, _)| async_map.find_references(expr))?;
+            }
+            PlannedExprResult::Expr(exprs) => {
+                exprs
+                    .iter()
+                    .try_for_each(|expr| async_map.find_references(expr))?;
+            }
+        }
+
+        if async_map.is_empty() {
+            return Ok(PlanAsyncExpr::Sync(physical_expr));
+        }
+
+        let new_exprs = match physical_expr {
+            PlannedExprResult::ExprWithName(exprs) => PlannedExprResult::ExprWithName(
+                exprs
+                    .iter()
+                    .map(|(expr, column_name)| {
+                        let new_expr = Arc::clone(expr)
+                            .transform_up(|e| Ok(async_map.map_expr(e)))?;
+                        Ok((new_expr.data, column_name.to_string()))
+                    })
+                    .collect::<Result<Vec<_>>>()?,
+            ),
+            PlannedExprResult::Expr(exprs) => PlannedExprResult::Expr(
+                exprs
+                    .iter()
+                    .map(|expr| {
+                        let new_expr = Arc::clone(expr)
+                            .transform_up(|e| Ok(async_map.map_expr(e)))?;
+                        Ok(new_expr.data)
+                    })
+                    .collect::<Result<Vec<_>>>()?,
+            ),
+        };
+        // rewrite the expressions in terms of the columns holding the async results
+        Ok(PlanAsyncExpr::Async(async_map, new_exprs))
+    }
 }
 
+enum PlannedExprResult {
+    ExprWithName(Vec<(Arc<dyn PhysicalExpr>, String)>),
+    Expr(Vec<Arc<dyn PhysicalExpr>>),
+}
+
+enum PlanAsyncExpr {
+    Sync(PlannedExprResult),
+    Async(AsyncMapper, PlannedExprResult),
+}
+
 fn tuple_err<T, R>(value: (Result<T>, Result<R>)) -> Result<(T, R)> {
     match value {
         (Ok(e), Ok(e1)) => Ok((e, e1)),
diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml
index d77c59ff64e1..812544587bf9 100644
--- a/datafusion/expr/Cargo.toml
+++ b/datafusion/expr/Cargo.toml
@@ -42,6 +42,7 @@ recursive_protection = ["dep:recursive"]
 
 [dependencies]
 arrow = { workspace = true }
+async-trait = { workspace = true }
 chrono = { workspace = true }
 datafusion-common = { workspace = true }
 datafusion-doc = { workspace = true }
diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs
new file mode 100644
index 000000000000..4f2b593b421a
--- /dev/null
+++ b/datafusion/expr/src/async_udf.rs
@@ -0,0 +1,154 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
+use arrow::array::ArrayRef;
+use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef};
+use async_trait::async_trait;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::error::Result;
+use datafusion_common::internal_err;
+use datafusion_expr_common::columnar_value::ColumnarValue;
+use datafusion_expr_common::signature::Signature;
+use std::any::Any;
+use std::fmt::{Debug, Display};
+use std::sync::Arc;
+
+/// A scalar UDF that can be invoked using async methods
+///
+/// Note this is less efficient than [`ScalarUDFImpl`], but it can be used
+/// to register remote functions in the context.
+///
+/// The name is chosen to mirror [`ScalarUDFImpl`]
+#[async_trait]
+pub trait AsyncScalarUDFImpl: Debug + Send + Sync {
+    /// Returns this object as an [`Any`] trait object
+    fn as_any(&self) -> &dyn Any;
+
+    /// The name of the function
+    fn name(&self) -> &str;
+
+    /// The signature of the function
+    fn signature(&self) -> &Signature;
+
+    /// The return type of the function
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType>;
+
+    /// What type will be returned by this function, given the arguments?
+    ///
+    /// By default, this function calls [`Self::return_type`] with the
+    /// types of each argument.
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
+        let data_types = args
+            .arg_fields
+            .iter()
+            .map(|f| f.data_type())
+            .cloned()
+            .collect::<Vec<_>>();
+        let return_type = self.return_type(&data_types)?;
+        Ok(Arc::new(Field::new(self.name(), return_type, true)))
+    }
+
+    /// The ideal batch size for this function.
+    ///
+    /// This is used to determine the size of the data chunks evaluated at once.
+    /// If `None`, the whole batch will be evaluated at once.
+    fn ideal_batch_size(&self) -> Option<usize> {
+        None
+    }
+
+    /// Invoke the function asynchronously with the async arguments
+    async fn invoke_async_with_args(
+        &self,
+        args: AsyncScalarFunctionArgs,
+        option: &ConfigOptions,
+    ) -> Result<ArrayRef>;
+}
+
+/// A scalar UDF that must be invoked using async methods
+///
+/// Note this is not meant to be used directly, but is meant to be an implementation detail
+/// for [`AsyncScalarUDFImpl`].
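+///
+/// Wrap an [`AsyncScalarUDFImpl`] with [`AsyncScalarUDF::new`], then call
+/// [`AsyncScalarUDF::into_scalar_udf`] to obtain a [`ScalarUDF`] that can be
+/// registered with a session context (see the `async_udf.rs` example).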
+#[derive(Debug)]
+pub struct AsyncScalarUDF {
+    inner: Arc<dyn AsyncScalarUDFImpl>,
+}
+
+impl AsyncScalarUDF {
+    pub fn new(inner: Arc<dyn AsyncScalarUDFImpl>) -> Self {
+        Self { inner }
+    }
+
+    /// The ideal batch size for this function
+    pub fn ideal_batch_size(&self) -> Option<usize> {
+        self.inner.ideal_batch_size()
+    }
+
+    /// Turn this async UDF into a [`ScalarUDF`], suitable for
+    /// registering in the context
+    pub fn into_scalar_udf(self) -> ScalarUDF {
+        ScalarUDF::new_from_impl(self)
+    }
+
+    /// Invoke the function asynchronously with the async arguments
+    pub async fn invoke_async_with_args(
+        &self,
+        args: AsyncScalarFunctionArgs,
+        option: &ConfigOptions,
+    ) -> Result<ArrayRef> {
+        self.inner.invoke_async_with_args(args, option).await
+    }
+}
+
+impl ScalarUDFImpl for AsyncScalarUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        self.inner.name()
+    }
+
+    fn signature(&self) -> &Signature {
+        self.inner.signature()
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        self.inner.return_type(arg_types)
+    }
+
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
+        self.inner.return_field_from_args(args)
+    }
+
+    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        internal_err!("async functions should not be called directly")
+    }
+}
+
+impl Display for AsyncScalarUDF {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "AsyncScalarUDF: {}", self.inner.name())
+    }
+}
+
+#[derive(Debug)]
+pub struct AsyncScalarFunctionArgs {
+    pub args: Vec<ColumnarValue>,
+    pub number_rows: usize,
+    pub schema: SchemaRef,
+}
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 1f44f755b214..0c822bbb337b 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -63,6 +63,7 @@ pub mod simplify;
 pub mod sort_properties {
     pub use datafusion_expr_common::sort_properties::*;
 }
+pub mod async_udf;
 pub mod statistics {
     pub use datafusion_expr_common::statistics::*;
 }
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 7139816c10e7..8b6ffba04ff6 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -17,6 +17,7 @@
 
 //! [`ScalarUDF`]: Scalar User Defined Functions
 
+use crate::async_udf::AsyncScalarUDF;
 use crate::expr::schema_name_from_exprs_comma_separated_without_space;
 use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
 use crate::sort_properties::{ExprProperties, SortProperties};
@@ -280,6 +281,11 @@ impl ScalarUDF {
     pub fn documentation(&self) -> Option<&Documentation> {
         self.inner.documentation()
     }
+
+    /// Return the inner [`AsyncScalarUDF`] if this function is an async function
+    pub fn as_async(&self) -> Option<&AsyncScalarUDF> {
+        self.inner().as_any().downcast_ref::<AsyncScalarUDF>()
+    }
 }
 
 impl<F> From<F> for ScalarUDF
diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs
new file mode 100644
index 000000000000..97ba642b06e3
--- /dev/null
+++ b/datafusion/physical-expr/src/async_scalar_function.rs
@@ -0,0 +1,232 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::ScalarFunctionExpr;
+use arrow::array::{make_array, MutableArrayData, RecordBatch};
+use arrow::datatypes::{DataType, Field, Schema};
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::Result;
+use datafusion_common::{internal_err, not_impl_err};
+use datafusion_expr::async_udf::{AsyncScalarFunctionArgs, AsyncScalarUDF};
+use datafusion_expr_common::columnar_value::ColumnarValue;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use std::any::Any;
+use std::fmt::Display;
+use std::hash::{Hash, Hasher};
+use std::sync::Arc;
+
+/// Wrapper around a scalar function that can be evaluated asynchronously
+#[derive(Debug, Clone, Eq)]
+pub struct AsyncFuncExpr {
+    /// The name of the output column this function will generate
+    pub name: String,
+    /// The actual function (always a [`ScalarFunctionExpr`])
+    pub func: Arc<dyn PhysicalExpr>,
+}
+
+impl Display for AsyncFuncExpr {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "async_expr(name={}, expr={})", self.name, self.func)
+    }
+}
+
+impl PartialEq for AsyncFuncExpr {
+    fn eq(&self, other: &Self) -> bool {
+        self.name == other.name && self.func == other.func
+    }
+}
+
+impl Hash for AsyncFuncExpr {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.name.hash(state);
+        self.func.as_ref().hash(state);
+    }
+}
+
+impl AsyncFuncExpr {
+    /// create a new [`AsyncFuncExpr`]
+    pub fn try_new(
+        name: impl Into<String>,
+        func: Arc<dyn PhysicalExpr>,
+    ) -> Result<Self> {
+        let Some(_) = func.as_any().downcast_ref::<ScalarFunctionExpr>() else {
+            return internal_err!(
+                "unexpected function type, expected ScalarFunctionExpr, got: {:?}",
+                func
+            );
+        };
+
+        Ok(Self {
+            name: name.into(),
+            func,
+        })
+    }
+
+    /// return the name of the output column
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+
+    /// Return the output field generated by evaluating this function
+    pub fn field(&self, input_schema: &Schema) -> Result<Field> {
+        Ok(Field::new(
+            &self.name,
+            self.func.data_type(input_schema)?,
+            self.func.nullable(input_schema)?,
+        ))
+    }
+
+    /// Return the ideal batch size for this function
+    pub fn ideal_batch_size(&self) -> Result<Option<usize>> {
+        if let Some(expr) = self.func.as_any().downcast_ref::<ScalarFunctionExpr>() {
+            if let Some(udf) =
+                expr.fun().inner().as_any().downcast_ref::<AsyncScalarUDF>()
+            {
+                return Ok(udf.ideal_batch_size());
+            }
+        }
+        not_impl_err!("Can't get ideal_batch_size from {:?}", self.func)
+    }
+
+    /// This (async) function is called for each record batch to evaluate the async expressions
+    ///
+    /// The output is the result of evaluating the async expression against the input record batch
+    pub async fn invoke_with_args(
+        &self,
+        batch: &RecordBatch,
+        option: &ConfigOptions,
+    ) -> Result<ColumnarValue> {
+        let Some(scalar_function_expr) =
+            self.func.as_any().downcast_ref::<ScalarFunctionExpr>()
+        else {
+            return internal_err!(
+                "unexpected function type, expected ScalarFunctionExpr, got: {:?}",
+                self.func
+            );
+        };
+
+        let Some(async_udf) = scalar_function_expr
+            .fun()
+            .inner()
+            .as_any()
+            .downcast_ref::<AsyncScalarUDF>()
+        else {
+            return not_impl_err!(
+                "Don't know how to evaluate async function: {:?}",
+                scalar_function_expr
+            );
+        };
+
+        let mut result_batches = vec![];
+        if let Some(ideal_batch_size) = self.ideal_batch_size()? {
+            let mut remainder = batch.clone();
+            while remainder.num_rows() > 0 {
+                let size = if ideal_batch_size > remainder.num_rows() {
+                    remainder.num_rows()
+                } else {
+                    ideal_batch_size
+                };
+
+                // take the next `size` rows from the front of the remainder
+                let current_batch = remainder.slice(0, size);
+                remainder = remainder.slice(size, remainder.num_rows() - size);
+                let args = scalar_function_expr
+                    .args()
+                    .iter()
+                    .map(|e| e.evaluate(&current_batch))
+                    .collect::<Result<Vec<_>>>()?;
+                result_batches.push(
+                    async_udf
+                        .invoke_async_with_args(
+                            AsyncScalarFunctionArgs {
+                                args: args.to_vec(),
+                                number_rows: current_batch.num_rows(),
+                                schema: current_batch.schema(),
+                            },
+                            option,
+                        )
+                        .await?,
+                );
+            }
+        } else {
+            let args = scalar_function_expr
+                .args()
+                .iter()
+                .map(|e| e.evaluate(batch))
+                .collect::<Result<Vec<_>>>()?;
+
+            result_batches.push(
+                async_udf
+                    .invoke_async_with_args(
+                        AsyncScalarFunctionArgs {
+                            args: args.to_vec(),
+                            number_rows: batch.num_rows(),
+                            schema: batch.schema(),
+                        },
+                        option,
+                    )
+                    .await?,
+            );
+        }
+
+        // concatenate the per-chunk results into a single output array
+        let datas = result_batches
+            .iter()
+            .map(|b| b.to_data())
+            .collect::<Vec<_>>();
+        let total_len = datas.iter().map(|d| d.len()).sum();
+        let mut mutable = MutableArrayData::new(datas.iter().collect(), false, total_len);
+        datas.iter().enumerate().for_each(|(i, data)| {
+            mutable.extend(i, 0, data.len());
+        });
+        let array_ref = make_array(mutable.freeze());
+        Ok(ColumnarValue::Array(array_ref))
+    }
+}
+
+impl PhysicalExpr for AsyncFuncExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
+        self.func.data_type(input_schema)
+    }
+
+    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
+        self.func.nullable(input_schema)
+    }
+
+    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
+        // TODO: implement this for scalar value input
+        not_impl_err!("AsyncFuncExpr.evaluate")
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        self.func.children()
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        let new_func = Arc::clone(&self.func).with_new_children(children)?;
+        Ok(Arc::new(AsyncFuncExpr {
+            name: self.name.clone(),
+            func: new_func,
+        }))
+    }
+
+    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.func)
+    }
+}
diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs
index 6741f94c9545..be60e26cc2d2 100644
--- a/datafusion/physical-expr/src/lib.rs
+++ b/datafusion/physical-expr/src/lib.rs
@@ -30,6 +30,7 @@ pub mod analysis;
 pub mod binary_map {
     pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType};
 }
+pub mod async_scalar_function;
 pub mod equivalence;
 pub mod expressions;
 pub mod intervals;
diff --git a/datafusion/physical-optimizer/src/coalesce_async_exec_input.rs b/datafusion/physical-optimizer/src/coalesce_async_exec_input.rs
new file mode 100644
index 000000000000..0b46c68f2dae
--- /dev/null
+++ b/datafusion/physical-optimizer/src/coalesce_async_exec_input.rs
@@ -0,0 +1,71 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::PhysicalOptimizerRule;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::internal_err;
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
+use datafusion_physical_plan::async_func::AsyncFuncExec;
+use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec;
+use datafusion_physical_plan::ExecutionPlan;
+use std::sync::Arc;
+
+/// Optimizer rule that inserts a [`CoalesceBatchesExec`] below each [`AsyncFuncExec`]
+/// to reduce the number of async executions.
+#[derive(Default, Debug)]
+pub struct CoalesceAsyncExecInput {}
+
+impl CoalesceAsyncExecInput {
+    #[allow(missing_docs)]
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
+
+impl PhysicalOptimizerRule for CoalesceAsyncExecInput {
+    fn optimize(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+        config: &ConfigOptions,
+    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
+        let target_batch_size = config.execution.batch_size;
+        plan.transform(|plan| {
+            if let Some(async_exec) = plan.as_any().downcast_ref::<AsyncFuncExec>() {
+                if async_exec.children().len() != 1 {
+                    return internal_err!(
+                        "Expected AsyncFuncExec to have exactly one child"
+                    );
+                }
+                let child = Arc::clone(async_exec.children()[0]);
+                let coalesce_exec =
+                    Arc::new(CoalesceBatchesExec::new(child, target_batch_size));
+                let coalesce_async_exec = plan.with_new_children(vec![coalesce_exec])?;
+                Ok(Transformed::yes(coalesce_async_exec))
+            } else {
+                Ok(Transformed::no(plan))
+            }
+        })
+        .data()
+    }
+
+    fn name(&self) -> &str {
+        "coalesce_async_exec_input"
+    }
+
+    fn schema_check(&self) -> bool {
+        true
+    }
+}
diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs
index f8d7b3b74614..51037d18cb9d 100644
--- a/datafusion/physical-optimizer/src/lib.rs
+++ b/datafusion/physical-optimizer/src/lib.rs
@@ -25,6 +25,7 @@
 #![deny(clippy::clone_on_ref_ptr)]
 
 pub mod aggregate_statistics;
+pub mod coalesce_async_exec_input;
 pub mod coalesce_batches;
 pub mod combine_partial_final_agg;
 pub mod enforce_distribution;
diff --git a/datafusion/physical-optimizer/src/optimizer.rs b/datafusion/physical-optimizer/src/optimizer.rs
index 38ec92b7d116..f9ad521b4f7c 100644
--- a/datafusion/physical-optimizer/src/optimizer.rs
+++ b/datafusion/physical-optimizer/src/optimizer.rs
@@ -36,6 +36,7 @@ use crate::sanity_checker::SanityCheckPlan;
 use crate::topk_aggregation::TopKAggregation;
 use crate::update_aggr_exprs::OptimizeAggregateOrder;
 
+use crate::coalesce_async_exec_input::CoalesceAsyncExecInput;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::Result;
 use datafusion_physical_plan::ExecutionPlan;
@@ -121,6 +122,7 @@ impl PhysicalOptimizer {
             // The CoalesceBatches rule will not influence the distribution and ordering of the
             // whole plan tree. Therefore, to avoid influencing other rules, it should run last.
             Arc::new(CoalesceBatches::new()),
+            Arc::new(CoalesceAsyncExecInput::new()),
             // Remove the ancillary output requirement operator since we are done with the planning
             // phase.
             Arc::new(OutputRequirements::new_remove_mode()),
diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs
new file mode 100644
index 000000000000..c808c5711755
--- /dev/null
+++ b/datafusion/physical-plan/src/async_func.rs
@@ -0,0 +1,297 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
+use crate::stream::RecordBatchStreamAdapter;
+use crate::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
+};
+use arrow::array::RecordBatch;
+use arrow_schema::{Fields, Schema, SchemaRef};
+use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
+use datafusion_common::{internal_err, Result};
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
+use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr;
+use datafusion_physical_expr::equivalence::ProjectionMapping;
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::ScalarFunctionExpr;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use futures::stream::StreamExt;
+use log::trace;
+use std::any::Any;
+use std::sync::Arc;
+
+/// This structure evaluates a set of async expressions on a record
+/// batch producing a new record batch
+///
+/// The schema of the output of the AsyncFuncExec is:
+/// Input columns followed by one column for each async expression
+#[derive(Debug)]
+pub struct AsyncFuncExec {
+    /// The async expressions to evaluate
+    async_exprs: Vec<Arc<AsyncFuncExpr>>,
+    input: Arc<dyn ExecutionPlan>,
+    cache: PlanProperties,
+    metrics: ExecutionPlanMetricsSet,
+}
+
+impl AsyncFuncExec {
+    pub fn try_new(
+        async_exprs: Vec<Arc<AsyncFuncExpr>>,
+        input: Arc<dyn ExecutionPlan>,
+    ) -> Result<Self> {
+        let async_fields = async_exprs
+            .iter()
+            .map(|async_expr| async_expr.field(input.schema().as_ref()))
+            .collect::<Result<Vec<_>>>()?;
+
+        // compute the output schema: input schema then async expressions
+        let fields: Fields = input
+            .schema()
+            .fields()
+            .iter()
+            .cloned()
+            .chain(async_fields.into_iter().map(Arc::new))
+            .collect();
+
+        let schema = Arc::new(Schema::new(fields));
+        let tuples = async_exprs
+            .iter()
+            .map(|expr| (Arc::clone(&expr.func), expr.name().to_string()))
+            .collect::<Vec<_>>();
+        let async_expr_mapping = ProjectionMapping::try_new(tuples, &input.schema())?;
+        let cache =
+            AsyncFuncExec::compute_properties(&input, schema, &async_expr_mapping)?;
+        Ok(Self {
+            input,
+            async_exprs,
+            cache,
+            metrics: ExecutionPlanMetricsSet::new(),
+        })
+    }
+
+    /// This function creates the cache object that stores the plan properties
+    /// such as schema, equivalence properties, ordering, partitioning, etc.
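+    ///
+    /// The async expression columns are appended after the input columns, so the
+    /// input's partitioning, pipeline behavior, and boundedness carry over unchanged.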
+    fn compute_properties(
+        input: &Arc<dyn ExecutionPlan>,
+        schema: SchemaRef,
+        async_expr_mapping: &ProjectionMapping,
+    ) -> Result<PlanProperties> {
+        Ok(PlanProperties::new(
+            input
+                .equivalence_properties()
+                .project(async_expr_mapping, schema),
+            input.output_partitioning().clone(),
+            input.pipeline_behavior(),
+            input.boundedness(),
+        ))
+    }
+}
+
+impl DisplayAs for AsyncFuncExec {
+    fn fmt_as(
+        &self,
+        t: DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        let expr: Vec<String> = self
+            .async_exprs
+            .iter()
+            .map(|async_expr| async_expr.to_string())
+            .collect();
+        let exprs = expr.join(", ");
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(f, "AsyncFuncExec: async_expr=[{exprs}]")
+            }
+            DisplayFormatType::TreeRender => {
+                writeln!(f, "format=async_expr")?;
+                writeln!(f, "async_expr={exprs}")?;
+                Ok(())
+            }
+        }
+    }
+}
+
+impl ExecutionPlan for AsyncFuncExec {
+    fn name(&self) -> &str {
+        "async_func"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.cache
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![&self.input]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        if children.len() != 1 {
+            return internal_err!("AsyncFuncExec wrong number of children");
+        }
+        Ok(Arc::new(AsyncFuncExec::try_new(
+            self.async_exprs.clone(),
+            Arc::clone(&children[0]),
+        )?))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        trace!(
+            "Start AsyncFuncExpr::execute for partition {} of context session_id {} and task_id {:?}",
+            partition,
+            context.session_id(),
+            context.task_id()
+        );
+        // TODO figure out how to record metrics
+
+        // first execute the input stream
+        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
+
+        // now, for each record batch, evaluate the async expressions and add the columns to the result
+        let async_exprs_captured = Arc::new(self.async_exprs.clone());
+        let schema_captured = self.schema();
+        let config_option_ref = Arc::new(context.session_config().options().clone());
+
+        let stream_with_async_functions = input_stream.then(move |batch| {
+            // need to clone *again* to capture the async_exprs and schema in the
+            // stream and satisfy lifetime requirements.
+            let async_exprs_captured = Arc::clone(&async_exprs_captured);
+            let schema_captured = Arc::clone(&schema_captured);
+            let config_option = Arc::clone(&config_option_ref);
+
+            async move {
+                let batch = batch?;
+                // append the result of evaluating the async expressions to the output
+                let mut output_arrays = batch.columns().to_vec();
+                for async_expr in async_exprs_captured.iter() {
+                    let output =
+                        async_expr.invoke_with_args(&batch, &config_option).await?;
+                    output_arrays.push(output.to_array(batch.num_rows())?);
+                }
+                let batch = RecordBatch::try_new(schema_captured, output_arrays)?;
+                Ok(batch)
+            }
+        });
+
+        // Adapt the stream with the output schema
+        let adapter =
+            RecordBatchStreamAdapter::new(self.schema(), stream_with_async_functions);
+        Ok(Box::pin(adapter))
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+}
+
+const ASYNC_FN_PREFIX: &str = "__async_fn_";
+
+/// Maps async_expressions to new columns
+///
+/// The output of the async functions are appended, in order, to the end of the input schema
+#[derive(Debug)]
+pub struct AsyncMapper {
+    /// the number of columns in the input plan,
+    /// used to generate the output column names.
+    /// The first async expr is `__async_fn_0`, the second is `__async_fn_1`, etc.
+    num_input_columns: usize,
+    /// the expressions to map
+    pub async_exprs: Vec<Arc<AsyncFuncExpr>>,
+}
+
+impl AsyncMapper {
+    pub fn new(num_input_columns: usize) -> Self {
+        Self {
+            num_input_columns,
+            async_exprs: Vec::new(),
+        }
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.async_exprs.is_empty()
+    }
+
+    pub fn next_column_name(&self) -> String {
+        format!("{}{}", ASYNC_FN_PREFIX, self.async_exprs.len())
+    }
+
+    /// Finds any references to async functions in the expression and adds them to the map
+    pub fn find_references(
+        &mut self,
+        physical_expr: &Arc<dyn PhysicalExpr>,
+    ) -> Result<()> {
+        // recursively look for references to async functions
+        physical_expr.apply(|expr| {
+            if let Some(scalar_func_expr) =
+                expr.as_any().downcast_ref::<ScalarFunctionExpr>()
+            {
+                if scalar_func_expr.fun().as_async().is_some() {
+                    let next_name = self.next_column_name();
+                    self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new(
+                        next_name,
+                        Arc::clone(expr),
+                    )?));
+                }
+            }
+            Ok(TreeNodeRecursion::Continue)
+        })?;
+        Ok(())
+    }
+
+    /// If the expression matches any of the async functions, return the new column
+    pub fn map_expr(
+        &self,
+        expr: Arc<dyn PhysicalExpr>,
+    ) -> Transformed<Arc<dyn PhysicalExpr>> {
+        // find the first matching async function, if any
+        let Some(idx) =
+            self.async_exprs
+                .iter()
+                .enumerate()
+                .find_map(|(idx, async_expr)| {
+                    if async_expr.func == expr {
+                        Some(idx)
+                    } else {
+                        None
+                    }
+                })
+        else {
+            return Transformed::no(expr);
+        };
+        // rewrite in terms of the output column
+        Transformed::yes(self.output_column(idx))
+    }
+
+    /// return the output column for the async function at index `idx`
+    pub fn output_column(&self, idx: usize) -> Arc<dyn PhysicalExpr> {
+        let async_expr = &self.async_exprs[idx];
+        let output_idx = self.num_input_columns + idx;
+        Arc::new(Column::new(async_expr.name(), output_idx))
+    }
+}
diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs
index 22ae859e8c5b..4d3adebb91c6 100644
--- a/datafusion/physical-plan/src/lib.rs
+++ b/datafusion/physical-plan/src/lib.rs
@@ -59,6 +59,7 @@ mod visitor;
 
 pub mod aggregates;
 pub mod analyze;
+pub mod async_func;
 pub mod coalesce;
 pub mod coalesce_batches;
 pub mod coalesce_partitions;
diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt
index ae2ef67c041c..50575a3aba4d 100644
--- a/datafusion/sqllogictest/test_files/explain.slt
+++ b/datafusion/sqllogictest/test_files/explain.slt
@@ -238,6 +238,7 @@ physical_plan after EnforceSorting SAME TEXT AS ABOVE
 physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE
 physical_plan after ProjectionPushdown SAME TEXT AS ABOVE
 physical_plan after coalesce_batches SAME TEXT AS ABOVE
+physical_plan after coalesce_async_exec_input SAME TEXT AS ABOVE
 physical_plan after OutputRequirements DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true
 physical_plan after LimitAggregation SAME TEXT AS ABOVE
 physical_plan after LimitPushdown SAME TEXT AS ABOVE
@@ -315,6 +316,7 @@ physical_plan after EnforceSorting SAME TEXT AS ABOVE
 physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE
 physical_plan after ProjectionPushdown SAME TEXT AS ABOVE
 physical_plan after coalesce_batches SAME TEXT AS ABOVE
+physical_plan after coalesce_async_exec_input SAME TEXT AS ABOVE
 physical_plan after OutputRequirements
 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]]
 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]]
@@ -358,6 +360,7 @@ physical_plan after EnforceSorting SAME TEXT AS ABOVE
 physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE
 physical_plan after ProjectionPushdown SAME TEXT AS ABOVE
 physical_plan after coalesce_batches SAME TEXT AS ABOVE
+physical_plan after coalesce_async_exec_input SAME TEXT AS ABOVE
 physical_plan after OutputRequirements
 01)GlobalLimitExec: skip=0, fetch=10
 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet
diff --git a/docs/source/library-user-guide/functions/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md
index cd40e664239a..66ffd69b4545 100644
--- a/docs/source/library-user-guide/functions/adding-udfs.md
+++ b/docs/source/library-user-guide/functions/adding-udfs.md
@@ -23,12 +23,13 @@
 User Defined Functions (UDFs) are functions that can be used in the context of DataFusion execution.
 This page covers how to add UDFs to DataFusion. In particular, it covers how to add Scalar, Window, and Aggregate UDFs.
 
-| UDF Type  | Description                                                                                                 | Example             |
-| --------- | ----------------------------------------------------------------------------------------------------------- | ------------------- |
-| Scalar    | A function that takes a row of data and returns a single value.                                             | [simple_udf.rs][1]  |
-| Window    | A function that takes a row of data and returns a single value, but also has access to the rows around it.  | [simple_udwf.rs][2] |
-| Aggregate | A function that takes a group of rows and returns a single value.                                           | [simple_udaf.rs][3] |
-| Table     | A function that takes parameters and returns a `TableProvider` to be used in an query plan.                 | [simple_udtf.rs][4] |
+| UDF Type     | Description                                                                                                                                               | Example             |
+| ------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------- |
+| Scalar       | A function that takes a row of data and returns a single value.                                                                                          | [simple_udf.rs][1]  |
+| Window       | A function that takes a row of data and returns a single value, but also has access to the rows around it.                                               | [simple_udwf.rs][2] |
+| Aggregate    | A function that takes a group of rows and returns a single value.                                                                                        | [simple_udaf.rs][3] |
+| Table        | A function that takes parameters and returns a `TableProvider` to be used in a query plan.                                                               | [simple_udtf.rs][4] |
+| Async Scalar | A scalar function that natively supports asynchronous execution, allowing you to perform async operations (such as network or I/O calls) within the UDF. | [async_udf.rs][5]   |
 
 First we'll talk about adding an Scalar UDF end-to-end, then we'll talk about the differences between the different types of UDFs.
@@ -344,6 +345,210 @@
 }
 ```
 
+## Adding a Scalar Async UDF
+
+A Scalar Async UDF allows you to implement user-defined functions that support asynchronous execution, such as performing network or I/O operations within the UDF.
+
+To add a Scalar Async UDF, you need to:
+
+1. Implement the `AsyncScalarUDFImpl` trait to define your async function logic, signature, and types.
+2. Wrap your implementation with `AsyncScalarUDF::new` and register it with the `SessionContext`.
+
+### Implementing `AsyncScalarUDFImpl`
+
+```rust
+use arrow::array::{ArrayIter, ArrayRef, AsArray, StringArray};
+use arrow_schema::DataType;
+use async_trait::async_trait;
+use datafusion::common::error::Result;
+use datafusion::common::internal_err;
+use datafusion::common::types::logical_string;
+use datafusion::config::ConfigOptions;
+use datafusion::logical_expr::async_udf::{
+    AsyncScalarFunctionArgs, AsyncScalarUDFImpl,
+};
+use datafusion::logical_expr::{
+    ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility,
+};
+use datafusion::logical_expr_common::signature::Coercion;
+use log::trace;
+use std::any::Any;
+use std::sync::Arc;
+
+#[derive(Debug)]
+pub struct AsyncUpper {
+    signature: Signature,
+}
+
+impl Default for AsyncUpper {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl AsyncUpper {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::new(
+                TypeSignature::Coercible(vec![Coercion::Exact {
+                    desired_type: TypeSignatureClass::Native(logical_string()),
+                }]),
+                Volatility::Volatile,
+            ),
+        }
+    }
+}
+
+#[async_trait]
+impl AsyncScalarUDFImpl for AsyncUpper {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "async_upper"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn ideal_batch_size(&self) -> Option<usize> {
+        Some(10)
+    }
+
+    async fn invoke_async_with_args(
+        &self,
+        args: AsyncScalarFunctionArgs,
+        _option: &ConfigOptions,
+    ) -> Result<ArrayRef> {
+        trace!("Invoking async_upper with args: {:?}", args);
+        let value = &args.args[0];
+        let result = match value {
+            ColumnarValue::Array(array) => {
+                let string_array = array.as_string::<i32>();
+                let iter = ArrayIter::new(string_array);
+                let result = iter
+                    .map(|string| string.map(|s| s.to_uppercase()))
+                    .collect::<StringArray>();
+                Arc::new(result) as ArrayRef
+            }
+            _ => return internal_err!("Expected a string argument, got {:?}", value),
+        };
+        Ok(result)
+    }
+}
+```
+
+We can now convert the async UDF into a normal `ScalarUDF` with `into_scalar_udf` and register it with DataFusion so that it can be used in queries.
+
+```rust
+# use arrow::array::{ArrayIter, ArrayRef, AsArray, StringArray};
+# use arrow_schema::DataType;
+# use async_trait::async_trait;
+# use datafusion::common::error::Result;
+# use datafusion::common::internal_err;
+# use datafusion::common::types::logical_string;
+# use datafusion::config::ConfigOptions;
+# use datafusion::logical_expr::async_udf::{
+#     AsyncScalarFunctionArgs, AsyncScalarUDFImpl,
+# };
+# use datafusion::logical_expr::{
+#     ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility,
+# };
+# use datafusion::logical_expr_common::signature::Coercion;
+# use log::trace;
+# use std::any::Any;
+# use std::sync::Arc;
+#
+# #[derive(Debug)]
+# pub struct AsyncUpper {
+#     signature: Signature,
+# }
+#
+# impl Default for AsyncUpper {
+#     fn default() -> Self {
+#         Self::new()
+#     }
+# }
+#
+# impl AsyncUpper {
+#     pub fn new() -> Self {
+#         Self {
+#             signature: Signature::new(
+#                 TypeSignature::Coercible(vec![Coercion::Exact {
+#                     desired_type: TypeSignatureClass::Native(logical_string()),
+#                 }]),
+#                 Volatility::Volatile,
+#             ),
+#         }
+#     }
+# }
+#
+# #[async_trait]
+# impl AsyncScalarUDFImpl for AsyncUpper {
+#     fn as_any(&self) -> &dyn Any {
+#         self
+#     }
+#
+#     fn name(&self) -> &str {
+#         "async_upper"
+#     }
+#
+#     fn signature(&self) -> &Signature {
+#         &self.signature
+#     }
+#
+#     fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+#         Ok(DataType::Utf8)
+#     }
+#
+#     fn ideal_batch_size(&self) -> Option<usize> {
+#         Some(10)
+#     }
+#
+#     async fn invoke_async_with_args(
+#         &self,
+#         args: AsyncScalarFunctionArgs,
+#         _option: &ConfigOptions,
+#     ) -> Result<ArrayRef> {
+#         trace!("Invoking async_upper with args: {:?}", args);
+#         let value = &args.args[0];
+#         let result = match value {
+#             ColumnarValue::Array(array) => {
+#                 let string_array = array.as_string::<i32>();
+#                 let iter = ArrayIter::new(string_array);
+#                 let result = iter
+#                     .map(|string| string.map(|s| s.to_uppercase()))
+#                     .collect::<StringArray>();
+#                 Arc::new(result) as ArrayRef
+#             }
+#             _ => return internal_err!("Expected a string argument, got {:?}", value),
+#         };
+#         Ok(result)
+#     }
+# }
+use datafusion::execution::context::SessionContext;
+use datafusion::logical_expr::async_udf::AsyncScalarUDF;
+
+let async_upper = AsyncUpper::new();
+let udf = AsyncScalarUDF::new(Arc::new(async_upper));
+let mut ctx = SessionContext::new();
+ctx.register_udf(udf.into_scalar_udf());
+```
+
+After registration, you can use these async UDFs directly in SQL queries, for example:
+
+```sql
+SELECT async_upper('datafusion');
+```
+
+For async UDF implementation details, see [`async_udf.rs`](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/async_udf.rs).
+
 [`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html
 [`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html
 [`process_scalar_func_inputs`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html
@@ -1244,8 +1449,3 @@ async fn main() -> Result<()> {
     Ok(())
 }
 ```
-
-[1]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
-[2]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs
-[3]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
-[4]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udtf.rs
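+
+[1]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
+[2]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs
+[3]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
+[4]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udtf.rs
+[5]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/async_udf.rs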